5 changes: 5 additions & 0 deletions paddle/operators/while_op.cc
@@ -99,6 +99,9 @@ class WhileGradOp : public framework::OperatorBase {

void Run(const framework::Scope &scope,
const platform::Place &dev_place) const override {
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
framework::Executor executor(dev_place);
auto *block = Attr<framework::BlockDesc *>(kStepBlock);
auto *program = block->Program();
@@ -205,6 +208,8 @@ class WhileGradOp : public framework::OperatorBase {
sum_op->Run(cur_scope, dev_place);
cur_scope.Rename(new_inside_name, inside_grad_name);
}
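// wait until every kernel queued on this device context has finished;
// otherwise deleting the step scopes below could free memory that an
// in-flight op (e.g. the sum_op above) is still reading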
dev_ctx.Wait();

What is the cost of adding this dev_ctx.Wait()?

Member Author:

I have not made a detailed test yet. But if we do not delete the step scopes, we cannot train a larger RNN model because of OutOfMemory.

const_cast<framework::Scope &>(scope).DeleteScope(&cur_scope);
Contributor:

Nice catch!
This fix will solve the OOM headache.

}
}
};
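
For intuition about why the Wait() has to precede the scope deletion, here is a small, self-contained Python analogy (ToyStream and sum_grads are invented names for illustration, not Fluid code): asynchronously launched work may still be reading tensors owned by a step scope, so the scope can only be freed after the device queue drains. It also makes the reviewer's cost question concrete: wait() serializes the host against the device once per Run, trading a little latency for bounded memory.

import threading
import time

class ToyStream:
    # stands in for a device context whose kernels run asynchronously
    def __init__(self):
        self.pending = []

    def launch(self, fn, *args):
        t = threading.Thread(target=fn, args=args)
        t.start()
        self.pending.append(t)

    def wait(self):
        # analogous to dev_ctx.Wait(): block until queued work completes
        for t in self.pending:
            t.join()
        self.pending = []

def sum_grads(scope):
    time.sleep(0.01)  # pretend this is a device-side sum kernel
    scope["g_sum"] = sum(scope["grads"])

stream = ToyStream()
step_scope = {"grads": [1.0, 2.0, 3.0]}  # plays the role of cur_scope
stream.launch(sum_grads, step_scope)
stream.wait()        # without this, the clear below races the kernel
step_scope.clear()   # safe to release now, so step memory cannot pile up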
24 changes: 17 additions & 7 deletions python/paddle/v2/fluid/memory_optimization_transpiler.py
@@ -31,7 +31,7 @@


class ControlFlowGraph(object):
def __init__(self, Program, ops, forward_num):
def __init__(self, Program, ops, forward_num, skip_opt):
self._program = Program
self._ops = ops
self._forward_num = forward_num
@@ -41,6 +41,7 @@ def __init__(self, Program, ops, forward_num):
self._defs = defaultdict(set)
self._live_in = defaultdict(set)
self._live_out = defaultdict(set)
self._skip_opt = skip_opt

def _add_connections(self, connections):
for node1, node2 in connections:
@@ -130,6 +131,10 @@ def check_var_validity(block_desc, x, is_forward):
block_desc, x,
is_forward).type() != core.VarDesc.VarType.LOD_TENSOR:
return False
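# variables collected in skip_opt (outputs of while/while_grad ops)
# cross sub-block boundaries, so their buffers must not be reused;
# variables without a shape cannot be sized for reuse either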
if x in self._skip_opt:
return False
if not self._find_var(block_desc, x, is_forward).shape():
return False
return True

self._build_graph()
@@ -140,6 +145,7 @@ def check_var_validity(block_desc, x, is_forward):
if op.type() == "while" or op.type() == "while_grad":
continue
block_desc = op.block()
self.current_block_desc = block_desc
is_forward = i < self._forward_num
if self.pool:
defs_can_optimize = filter(
@@ -197,28 +203,32 @@ def get_cfgs(input_program):
block_desc = pdesc.block(0)
op_size = block_desc.op_size()
# Get global block ops
ops_list.append(([block_desc.op(i) for i in range(op_size)], op_size))
ops_list.append(
([block_desc.op(i) for i in range(op_size)], op_size, set()))

while_sub_block_ids = []
while_grad_sub_block_ids = []
while_pair = []
while_op_output = set()
while_block_id_pair = []
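# record every output of while/while_grad ops: these variables cross
# sub-block boundaries, so their buffers must never be reused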

for i in range(op_size):
op = block_desc.op(i)
if op.type() == "while":
while_sub_block_ids.append(op.attr("sub_block").id)
while_op_output.update(op.output_arg_names())
elif op.type() == "while_grad":
while_grad_sub_block_ids.append(op.attr("sub_block").id)
while_op_output.update(op.output_arg_names())

# Find while/while_grad block pair
for grad_id in while_grad_sub_block_ids:
parent_id = pdesc.block(grad_id).parent
if parent_id in while_sub_block_ids:
while_pair.append((parent_id, grad_id))
while_block_id_pair.append((parent_id, grad_id))
while_sub_block_ids.remove(parent_id)

# Get while/while_grad block ops
for parent_id, grad_id in while_pair:
for parent_id, grad_id in while_block_id_pair:
while_block_ops = []
while_block = pdesc.block(parent_id)
while_block_op_size = while_block.op_size()
@@ -230,7 +240,7 @@ def get_cfgs(input_program):
for i in range(while_grad_block_op_size):
while_block_ops.append(while_grad_block.op(i))

ops_list.append((while_block_ops, while_block_op_size))
ops_list.append((while_block_ops, while_block_op_size, while_op_output))

# Process the remaining while block ops (blocks without a while_grad pair)
for parent_id in while_sub_block_ids:
@@ -242,7 +252,7 @@

ops_list.append((while_block_ops, while_block_op_size, while_op_output))

cfgs = [ControlFlowGraph(input_program, i, j) for i, j in ops_list]
cfgs = [ControlFlowGraph(input_program, i, j, k) for i, j, k in ops_list]
return cfgs
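
To see what threading skip_opt through ControlFlowGraph buys, the sketch below shows liveness-based buffer reuse in miniature (plan_reuse and its (op_type, inputs, outputs) tuples are hypothetical; the real transpiler walks BlockDesc/OpDesc objects): a variable past its last use donates its buffer to a pool, and a later output may claim a pooled buffer, unless the variable is in the skip set, which is how while/while_grad outputs, whose lifetimes cross sub-block boundaries, are now protected.

def plan_reuse(ops, skip_opt):
    # ops: list of (op_type, input_names, output_names) tuples
    # returns a mapping: output name -> dead variable whose buffer it reuses
    last_use = {}
    for i, (_, ins, outs) in enumerate(ops):
        for v in ins + outs:
            last_use[v] = i

    pool, reuse = [], {}
    for i, (_, ins, outs) in enumerate(ops):
        for v in ins:
            # a dead variable's buffer becomes reusable, unless skipped
            if last_use[v] == i and v not in skip_opt:
                pool.append(v)
        for v in outs:
            if v in skip_opt or not pool:
                continue  # never hand a pooled buffer to a skipped variable
            reuse[v] = pool.pop()
    return reuse

ops = [("mul", ["x", "w"], ["h"]),
       ("while", ["h"], ["out"]),  # "out" lives across a sub-block boundary
       ("relu", ["out"], ["y"])]
print(plan_reuse(ops, skip_opt={"out"}))  # {'h': 'w', 'y': 'h'}

Without the skip set, "out" could claim or donate a buffer that the while sub-block still references, which is exactly the corruption this PR guards against.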

