Skip to content

Commit 1e9b10e

Browse files
committed
Binding methods for class Tensor and Variable
1 parent eabc0d4 commit 1e9b10e

File tree

5 files changed

+84
-146
lines changed

5 files changed

+84
-146
lines changed

python/paddle/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
import paddle.dataset
3232
import paddle.batch
3333
batch = batch.batch
34-
import paddle.tensor
3534
from .fluid import monkey_patch_variable
3635
from .fluid.dygraph import monkey_patch_math_varbase
3736
monkey_patch_variable()
@@ -42,7 +41,7 @@
4241
import paddle.compat
4342
import paddle.distributed
4443
import paddle.sysconfig
45-
44+
import paddle.tensor
4645
import paddle.nn
4746
import paddle.distributed.fleet
4847
import paddle.optimizer

python/paddle/fluid/dygraph/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@
5959
from . import amp
6060
from .amp import *
6161

62+
from .math_op_patch import monkey_patch_math_varbase
63+
6264
__all__ = []
6365
__all__ += layers.__all__
6466
__all__ += base.__all__

python/paddle/fluid/dygraph/math_op_patch.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -246,22 +246,36 @@ def __impl__(self):
246246
('ndim', _ndim_),
247247
('size', lambda x: x.shape),
248248
# Type2: From Template that create core.ops automatically. It's recommended.
249-
('__add__', _binary_creator_('__add__', 'elementwise_add', False, _scalar_add_)),
249+
('__add__',
250+
_binary_creator_('__add__', 'elementwise_add', False, _scalar_add_)),
250251
## a+b == b+a. Do not need to reverse explicitly
251-
('__radd__', _binary_creator_('__radd__', 'elementwise_add', False, _scalar_add_)),
252-
('__sub__', _binary_creator_('__sub__', 'elementwise_sub', False, _scalar_sub_)),
253-
('__rsub__', _binary_creator_('__rsub__', 'elementwise_sub', True, _scalar_rsub_)),
254-
('__mul__', _binary_creator_('__mul__', 'elementwise_mul', False, _scalar_mul_)),
252+
('__radd__',
253+
_binary_creator_('__radd__', 'elementwise_add', False, _scalar_add_)),
254+
('__sub__', _binary_creator_('__sub__', 'elementwise_sub', False,
255+
_scalar_sub_)),
256+
('__rsub__', _binary_creator_('__rsub__', 'elementwise_sub', True,
257+
_scalar_rsub_)),
258+
('__mul__', _binary_creator_('__mul__', 'elementwise_mul', False,
259+
_scalar_mul_)),
255260
## a*b == b*a. Do not need to reverse explicitly
256-
('__rmul__', _binary_creator_('__rmul__', 'elementwise_mul', False, _scalar_mul_)),
257-
('__div__', _binary_creator_('__div__', 'elementwise_div', False, _scalar_div_)),
258-
('__truediv__', _binary_creator_('__truediv__', 'elementwise_div', False, _scalar_div_)),
259-
('__rdiv__', _binary_creator_('__rdiv__', 'elementwise_div', True, None)),
260-
('__rtruediv__', _binary_creator_('rtruediv__', 'elementwise_div', True, None)),
261-
('__pow__', _binary_creator_('__pow__', 'elementwise_pow', False, None)),
262-
('__rpow__', _binary_creator_('__rpow__', 'elementwise_pow', True, None)),
263-
('__floordiv__', _binary_creator_('__floordiv__', 'elementwise_floordiv', False, None)),
264-
('__mod__', _binary_creator_('__mod__', 'elementwise_mod', False, None)),
261+
('__rmul__',
262+
_binary_creator_('__rmul__', 'elementwise_mul', False, _scalar_mul_)),
263+
('__div__', _binary_creator_('__div__', 'elementwise_div', False,
264+
_scalar_div_)),
265+
('__truediv__', _binary_creator_('__truediv__', 'elementwise_div',
266+
False, _scalar_div_)),
267+
('__rdiv__', _binary_creator_('__rdiv__', 'elementwise_div', True,
268+
None)),
269+
('__rtruediv__', _binary_creator_('rtruediv__', 'elementwise_div', True,
270+
None)),
271+
('__pow__', _binary_creator_('__pow__', 'elementwise_pow', False,
272+
None)),
273+
('__rpow__', _binary_creator_('__rpow__', 'elementwise_pow', True,
274+
None)),
275+
('__floordiv__', _binary_creator_('__floordiv__',
276+
'elementwise_floordiv', False, None)),
277+
('__mod__', _binary_creator_('__mod__', 'elementwise_mod', False,
278+
None)),
265279
## for logical compare
266280
('__eq__', _binary_creator_('__eq__', 'equal', False, None)),
267281
('__ne__', _binary_creator_('__ne__', 'not_equal', False, None)),
@@ -300,13 +314,13 @@ def __impl__(self):
300314
if not _already_patch_varbase:
301315
for method in varbase_methods:
302316
method_name = method[0]
303-
method_impl = method[0]
317+
method_impl = method[1]
304318
setattr(core.VarBase, method_name, method_impl)
305319
else:
306320
import paddle.tensor
307321
for method_name in common_methods:
308322
if hasattr(core.VarBase, method_name): continue
309-
method_impl = getattr(paddle.tensor, method, None)
323+
method_impl = getattr(paddle.tensor, method_name, None)
310324
if method_impl: setattr(core.VarBase, method_name, method_impl)
311325

312326
_already_patch_varbase = True

python/paddle/fluid/layers/math_op_patch.py

Lines changed: 49 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -54,118 +54,26 @@
5454
"__ge__": "A >= B"
5555
}
5656

57+
# method for Tensor from paddle.tensor
58+
# edit it when paddle.tensor has new method about Tensor operation
5759
common_methods = [
58-
'exp',
59-
'tanh',
60-
'atan',
61-
'sqrt',
62-
'rsqrt',
63-
'abs',
64-
'ceil',
65-
'floor',
66-
'cos',
67-
'acos',
68-
'asin',
69-
'sin',
70-
'sinh',
71-
'cosh',
72-
'round',
73-
'reciprocal',
74-
'square',
75-
'rank',
76-
'matmul',
77-
'dot',
78-
'norm',
79-
'transpose',
80-
'dist',
81-
't',
82-
'cross',
83-
'cholesky',
84-
'bmm',
85-
'histogram',
86-
'equal',
87-
'greater_equal',
88-
'greater_than',
89-
'is_empty',
90-
'isfinite',
91-
'less_equal',
92-
'less_than',
93-
'logical_and',
94-
'logical_not',
95-
'logical_or',
96-
'logical_xor',
97-
'not_equal',
98-
'reduce_all',
99-
'reduce_any',
100-
'allclose',
101-
'equal_all',
102-
'cast',
103-
'expand',
104-
'expand_as',
105-
'tile',
106-
'flatten',
107-
'gather',
108-
'gather_nd',
109-
'reshape',
110-
'reverse',
111-
'scatter',
112-
'scatter_nd_add',
113-
'scatter_nd',
114-
'shard_index',
115-
'slice',
116-
'split',
117-
'squeeze',
118-
'strided_slice',
119-
'unique',
120-
'unique_with_counts',
121-
'unsqueeze',
122-
'flip',
123-
'unbind',
124-
'roll',
125-
'cumsum',
126-
'increment',
127-
'log',
128-
'pow',
129-
'reciprocal',
130-
'round',
131-
'rsqrt',
132-
'scale',
133-
'sign',
134-
'stanh',
135-
'sum',
136-
'reduce_prod',
137-
'max',
138-
'min',
139-
'mm',
140-
'div',
141-
'multiply',
142-
'add',
143-
'logsumexp',
144-
'log1p',
145-
'erf',
146-
'addcmul',
147-
'addmm',
148-
'clamp',
149-
'trace',
150-
'kron',
151-
'argmax',
152-
'argmin',
153-
'argsort',
154-
'has_inf',
155-
'has_nan',
156-
'topk',
157-
'index_select',
158-
'nonzero',
159-
'sort',
160-
'index_sample',
161-
'mean',
162-
'std',
163-
'var',
164-
'elementwise_add',
165-
'elementwise_div',
166-
'elementwise_floordiv',
167-
'elementwise_mod',
168-
'elementwise_pow',
60+
'exp', 'tanh', 'atan', 'sqrt', 'rsqrt', 'abs', 'ceil', 'floor', 'cos',
61+
'acos', 'asin', 'sin', 'sinh', 'cosh', 'round', 'reciprocal', 'square',
62+
'rank', 'matmul', 'dot', 'norm', 'transpose', 'dist', 't', 'cross',
63+
'cholesky', 'bmm', 'histogram', 'equal', 'greater_equal', 'greater_than',
64+
'is_empty', 'isfinite', 'less_equal', 'less_than', 'logical_and',
65+
'logical_not', 'logical_or', 'logical_xor', 'not_equal', 'reduce_all',
66+
'reduce_any', 'allclose', 'equal_all', 'cast', 'expand', 'expand_as',
67+
'tile', 'flatten', 'gather', 'gather_nd', 'reshape', 'reverse', 'scatter',
68+
'scatter_nd_add', 'scatter_nd', 'shard_index', 'slice', 'split', 'squeeze',
69+
'strided_slice', 'unique', 'unique_with_counts', 'unsqueeze', 'flip',
70+
'unbind', 'roll', 'cumsum', 'increment', 'log', 'pow', 'reciprocal',
71+
'round', 'rsqrt', 'scale', 'sign', 'stanh', 'sum', 'reduce_prod', 'max',
72+
'min', 'mm', 'div', 'multiply', 'add', 'logsumexp', 'log1p', 'erf',
73+
'addcmul', 'addmm', 'clamp', 'trace', 'kron', 'argmax', 'argmin', 'argsort',
74+
'has_inf', 'has_nan', 'topk', 'index_select', 'nonzero', 'sort',
75+
'index_sample', 'mean', 'std', 'var', 'elementwise_add', 'elementwise_div',
76+
'elementwise_floordiv', 'elementwise_mod', 'elementwise_pow',
16977
'elementwise_sub'
17078
]
17179

@@ -417,22 +325,36 @@ def __impl__(self, other_var):
417325
# b=-a
418326
('__neg__', _neg_),
419327
('astype', astype),
420-
('__add__', _binary_creator_('__add__', 'elementwise_add', False, _scalar_add_)),
328+
('__add__', _binary_creator_('__add__', 'elementwise_add', False,
329+
_scalar_add_)),
421330
# a+b == b+a. Do not need to reverse explicitly
422-
('__radd__', _binary_creator_('__radd__', 'elementwise_add', False, _scalar_add_)),
423-
('__sub__', _binary_creator_('__sub__', 'elementwise_sub', False, _scalar_sub_)),
424-
('__rsub__', _binary_creator_('__rsub__', 'elementwise_sub', True, _scalar_rsub_)),
425-
('__mul__', _binary_creator_('__mul__', 'elementwise_mul', False, _scalar_mul_)),
331+
('__radd__',
332+
_binary_creator_('__radd__', 'elementwise_add', False, _scalar_add_)),
333+
('__sub__', _binary_creator_('__sub__', 'elementwise_sub', False,
334+
_scalar_sub_)),
335+
('__rsub__', _binary_creator_('__rsub__', 'elementwise_sub', True,
336+
_scalar_rsub_)),
337+
('__mul__', _binary_creator_('__mul__', 'elementwise_mul', False,
338+
_scalar_mul_)),
426339
# a*b == b*a. Do not need to reverse explicitly
427-
('__rmul__', _binary_creator_('__rmul__', 'elementwise_mul', False, _scalar_mul_)),
428-
('__div__', _binary_creator_('__div__', 'elementwise_div', False, _scalar_div_)),
429-
('__truediv__', _binary_creator_('__truediv__', 'elementwise_div', False, _scalar_div_)),
430-
('__rdiv__', _binary_creator_('__rdiv__', 'elementwise_div', True, None)),
431-
('__rtruediv__', _binary_creator_('rtruediv__', 'elementwise_div', True, None)),
432-
('__pow__', _binary_creator_('__pow__', 'elementwise_pow', False, None)),
433-
('__rpow__', _binary_creator_('__rpow__', 'elementwise_pow', True, None)),
434-
('__floordiv__', _binary_creator_('__floordiv__', 'elementwise_floordiv', False, None)),
435-
('__mod__', _binary_creator_('__mod__', 'elementwise_mod', False, None)),
340+
('__rmul__',
341+
_binary_creator_('__rmul__', 'elementwise_mul', False, _scalar_mul_)),
342+
('__div__', _binary_creator_('__div__', 'elementwise_div', False,
343+
_scalar_div_)),
344+
('__truediv__', _binary_creator_('__truediv__', 'elementwise_div',
345+
False, _scalar_div_)),
346+
('__rdiv__', _binary_creator_('__rdiv__', 'elementwise_div', True,
347+
None)),
348+
('__rtruediv__', _binary_creator_('__rtruediv__', 'elementwise_div',
349+
True, None)),
350+
('__pow__', _binary_creator_('__pow__', 'elementwise_pow', False,
351+
None)),
352+
('__rpow__', _binary_creator_('__rpow__', 'elementwise_pow', True,
353+
None)),
354+
('__floordiv__', _binary_creator_('__floordiv__',
355+
'elementwise_floordiv', False, None)),
356+
('__mod__', _binary_creator_('__mod__', 'elementwise_mod', False,
357+
None)),
436358
# for logical compare
437359
('__eq__', _binary_creator_('__eq__', 'equal', False, None)),
438360
('__ne__', _binary_creator_('__ne__', 'not_equal', False, None)),
@@ -451,8 +373,8 @@ def __impl__(self, other_var):
451373
else:
452374
import paddle.tensor
453375
for method_name in common_methods:
454-
if hasattr(Variable, method): continue
455-
method_impl = getattr(paddle.tensor, method, None)
376+
if hasattr(Variable, method_name): continue
377+
method_impl = getattr(paddle.tensor, method_name, None)
456378
if method_impl: setattr(Variable, method_name, method_impl)
457379

458380
_already_patch_variable = True

python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,9 +467,10 @@ def test_tensor_patch_method(self):
467467
self.assertTrue(
468468
np.array_equal(
469469
x.allclose(y).numpy(), paddle.allclose(x, y).numpy()))
470+
m = x.expand([2, 2, 3])
470471
self.assertTrue(
471472
np.array_equal(
472-
x.expand_as(z).numpy(), paddle.expand_as(x, z).numpy()))
473+
x.expand_as(m).numpy(), paddle.expand_as(x, m).numpy()))
473474
index = paddle.to_tensor([2, 1, 0])
474475
self.assertTrue(
475476
np.array_equal(

0 commit comments

Comments
 (0)