Skip to content

Commit 2e78366

Browse files
committed
Enable to output LoD in fetch_op and check output LoD in the op unit test.
1 parent fa72e54 commit 2e78366

File tree

3 files changed

+21
-5
lines changed

3 files changed

+21
-5
lines changed

paddle/operators/fetch_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class FetchOp : public framework::OperatorBase {
5252
// FIXME(yuyang18): Should we assume the fetch operator always generate
5353
// CPU outputs?
5454
dst_item.CopyFrom(src_item, platform::CPUPlace(), dev_ctx);
55+
dst_item.set_lod(src_item.lod());
5556

5657
VLOG(3) << "Fetch variable " << fetch_var_name << " to " << out_name;
5758
}

python/paddle/v2/framework/tests/op_test.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -333,20 +333,31 @@ def find_actual(target_name, fetch_list):
333333
type(sub_out))
334334
for sub_out_name, expect in sub_out:
335335
idx = find_actual(sub_out_name, fetch_list)
336-
actual = outs[idx]
336+
actual_t = np.array(outs[idx])
337+
expect_t = expect[0] \
338+
if isinstance(expect, tuple) else expect
337339
self.assertTrue(
338340
np.allclose(
339-
actual, expect, atol=atol),
341+
actual_t, expect_t, atol=atol),
340342
"Output (" + sub_out_name + ") has diff at " +
341343
str(place))
344+
if isinstance(expect, tuple):
345+
self.assertListEqual(
346+
actual_t.lod(), expect[1], "Output (" + sub_out_name
347+
+ ") has different lod at " + str(place))
342348
else:
343349
idx = find_actual(out_name, fetch_list)
344-
actual = outs[idx]
350+
actual_t = outs[idx]
345351
expect = self.outputs[out_name]
352+
expect_t = expect[0] if isinstance(expect, tuple) else expect
346353
self.assertTrue(
347354
np.allclose(
348-
actual, expect, atol=atol),
355+
actual_t, expect_t, atol=atol),
349356
"Output (" + out_name + ") has diff at " + str(place))
357+
if isinstance(expect, tuple):
358+
self.assertListEqual(actual_t.lod(), expect[1],
359+
"Output (" + out_name +
360+
") has different lod at " + str(place))
350361

351362
def check_output(self, atol=1e-5):
352363
places = [core.CPUPlace()]

python/paddle/v2/framework/tests/test_lstm_op.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,11 @@ def setUp(self):
155155
'Weight': w,
156156
'Bias': b
157157
}
158-
self.outputs = {'Hidden': h, 'Cell': c, 'BatchGate': g_sort}
158+
self.outputs = {
159+
'Hidden': (h, self.lod),
160+
'Cell': (c, self.lod),
161+
'BatchGate': g_sort
162+
}
159163
self.attrs = {
160164
'usePeepholes': True,
161165
'isReverse': self.is_reverse,

0 commit comments

Comments
 (0)