5454 "__ge__" : "A >= B" 
5555}
5656
57+ # method for Tensor from paddle.tensor 
58+ # edit it when paddle.tensor has new method about Tensor operation 
5759common_methods  =  [
58-  'exp' ,
59-  'tanh' , 
60-  'atan' , 
61-  'sqrt' , 
62-  'rsqrt' , 
63-  'abs' , 
64-  'ceil' , 
65-  'floor' , 
66-  'cos' ,
67-  'acos' , 
68-  'asin' , 
69-  'sin' , 
70-  'sinh' , 
71-  'cosh' , 
72-  'round' , 
73-  'reciprocal' , 
74-  'square' ,
75-  'rank' , 
76-  'matmul' , 
77-  'dot' , 
78-  'norm' , 
79-  'transpose' , 
80-  'dist' , 
81-  't' , 
82-  'cross' ,
83-  'cholesky' , 
84-  'bmm' , 
85-  'histogram' , 
86-  'equal' , 
87-  'greater_equal' , 
88-  'greater_than' ,
89-  'is_empty' , 
90-  'isfinite' , 
91-  'less_equal' , 
92-  'less_than' , 
93-  'logical_and' ,
94-  'logical_not' , 
95-  'logical_or' , 
96-  'logical_xor' , 
97-  'not_equal' , 
98-  'reduce_all' ,
99-  'reduce_any' , 
100-  'allclose' , 
101-  'equal_all' , 
102-  'cast' , 
103-  'expand' , 
104-  'expand_as' ,
105-  'tile' , 
106-  'flatten' , 
107-  'gather' , 
108-  'gather_nd' , 
109-  'reshape' , 
110-  'reverse' , 
111-  'scatter' ,
112-  'scatter_nd_add' , 
113-  'scatter_nd' , 
114-  'shard_index' , 
115-  'slice' , 
116-  'split' , 
117-  'squeeze' ,
118-  'strided_slice' , 
119-  'unique' , 
120-  'unique_with_counts' , 
121-  'unsqueeze' , 
122-  'flip' ,
123-  'unbind' , 
124-  'roll' , 
125-  'cumsum' , 
126-  'increment' , 
127-  'log' , 
128-  'pow' , 
129-  'reciprocal' ,
130-  'round' , 
131-  'rsqrt' , 
132-  'scale' , 
133-  'sign' , 
134-  'stanh' , 
135-  'sum' , 
136-  'reduce_prod' , 
137-  'max' ,
138-  'min' , 
139-  'mm' , 
140-  'div' , 
141-  'multiply' , 
142-  'add' , 
143-  'logsumexp' , 
144-  'log1p' , 
145-  'erf' ,
146-  'addcmul' , 
147-  'addmm' , 
148-  'clamp' , 
149-  'trace' , 
150-  'kron' , 
151-  'argmax' , 
152-  'argmin' , 
153-  'argsort' ,
154-  'has_inf' , 
155-  'has_nan' , 
156-  'topk' , 
157-  'index_select' , 
158-  'nonzero' , 
159-  'sort' ,
160-  'index_sample' , 
161-  'mean' , 
162-  'std' , 
163-  'var' , 
164-  'elementwise_add' , 
165-  'elementwise_div' ,
166-  'elementwise_floordiv' , 
167-  'elementwise_mod' , 
168-  'elementwise_pow' ,
60+  'exp' , 'tanh' , 'atan' , 'sqrt' , 'rsqrt' , 'abs' , 'ceil' , 'floor' , 'cos' ,
61+  'acos' , 'asin' , 'sin' , 'sinh' , 'cosh' , 'round' , 'reciprocal' , 'square' ,
62+  'rank' , 'matmul' , 'dot' , 'norm' , 'transpose' , 'dist' , 't' , 'cross' ,
63+  'cholesky' , 'bmm' , 'histogram' , 'equal' , 'greater_equal' , 'greater_than' ,
64+  'is_empty' , 'isfinite' , 'less_equal' , 'less_than' , 'logical_and' ,
65+  'logical_not' , 'logical_or' , 'logical_xor' , 'not_equal' , 'reduce_all' ,
66+  'reduce_any' , 'allclose' , 'equal_all' , 'cast' , 'expand' , 'expand_as' ,
67+  'tile' , 'flatten' , 'gather' , 'gather_nd' , 'reshape' , 'reverse' , 'scatter' ,
68+  'scatter_nd_add' , 'scatter_nd' , 'shard_index' , 'slice' , 'split' , 'squeeze' ,
69+  'strided_slice' , 'unique' , 'unique_with_counts' , 'unsqueeze' , 'flip' ,
70+  'unbind' , 'roll' , 'cumsum' , 'increment' , 'log' , 'pow' , 'reciprocal' ,
71+  'round' , 'rsqrt' , 'scale' , 'sign' , 'stanh' , 'sum' , 'reduce_prod' , 'max' ,
72+  'min' , 'mm' , 'div' , 'multiply' , 'add' , 'logsumexp' , 'log1p' , 'erf' ,
73+  'addcmul' , 'addmm' , 'clamp' , 'trace' , 'kron' , 'argmax' , 'argmin' , 'argsort' ,
74+  'has_inf' , 'has_nan' , 'topk' , 'index_select' , 'nonzero' , 'sort' ,
75+  'index_sample' , 'mean' , 'std' , 'var' , 'elementwise_add' , 'elementwise_div' ,
76+  'elementwise_floordiv' , 'elementwise_mod' , 'elementwise_pow' ,
16977 'elementwise_sub' 
17078]
17179
@@ -417,22 +325,36 @@ def __impl__(self, other_var):
417325 # b=-a 
418326 ('__neg__' , _neg_ ),
419327 ('astype' , astype ),
420-  ('__add__' , _binary_creator_ ('__add__' , 'elementwise_add' , False , _scalar_add_ )),
328+  ('__add__' , _binary_creator_ ('__add__' , 'elementwise_add' , False ,
329+  _scalar_add_ )),
421330 # a+b == b+a. Do not need to reverse explicitly 
422-  ('__radd__' , _binary_creator_ ('__radd__' , 'elementwise_add' , False , _scalar_add_ )),
423-  ('__sub__' , _binary_creator_ ('__sub__' , 'elementwise_sub' , False , _scalar_sub_ )),
424-  ('__rsub__' , _binary_creator_ ('__rsub__' , 'elementwise_sub' , True , _scalar_rsub_ )),
425-  ('__mul__' , _binary_creator_ ('__mul__' , 'elementwise_mul' , False , _scalar_mul_ )),
331+  ('__radd__' ,
332+  _binary_creator_ ('__radd__' , 'elementwise_add' , False , _scalar_add_ )),
333+  ('__sub__' , _binary_creator_ ('__sub__' , 'elementwise_sub' , False ,
334+  _scalar_sub_ )),
335+  ('__rsub__' , _binary_creator_ ('__rsub__' , 'elementwise_sub' , True ,
336+  _scalar_rsub_ )),
337+  ('__mul__' , _binary_creator_ ('__mul__' , 'elementwise_mul' , False ,
338+  _scalar_mul_ )),
426339 # a*b == b*a. Do not need to reverse explicitly 
427-  ('__rmul__' , _binary_creator_ ('__rmul__' , 'elementwise_mul' , False , _scalar_mul_ )),
428-  ('__div__' , _binary_creator_ ('__div__' , 'elementwise_div' , False , _scalar_div_ )),
429-  ('__truediv__' , _binary_creator_ ('__truediv__' , 'elementwise_div' , False , _scalar_div_ )),
430-  ('__rdiv__' , _binary_creator_ ('__rdiv__' , 'elementwise_div' , True , None )),
431-  ('__rtruediv__' , _binary_creator_ ('rtruediv__' , 'elementwise_div' , True , None )),
432-  ('__pow__' , _binary_creator_ ('__pow__' , 'elementwise_pow' , False , None )),
433-  ('__rpow__' , _binary_creator_ ('__rpow__' , 'elementwise_pow' , True , None )),
434-  ('__floordiv__' , _binary_creator_ ('__floordiv__' , 'elementwise_floordiv' , False , None )),
435-  ('__mod__' , _binary_creator_ ('__mod__' , 'elementwise_mod' , False , None )),
340+  ('__rmul__' ,
341+  _binary_creator_ ('__rmul__' , 'elementwise_mul' , False , _scalar_mul_ )),
342+  ('__div__' , _binary_creator_ ('__div__' , 'elementwise_div' , False ,
343+  _scalar_div_ )),
344+  ('__truediv__' , _binary_creator_ ('__truediv__' , 'elementwise_div' ,
345+  False , _scalar_div_ )),
346+  ('__rdiv__' , _binary_creator_ ('__rdiv__' , 'elementwise_div' , True ,
347+  None )),
348+  ('__rtruediv__' , _binary_creator_ ('__rtruediv__' , 'elementwise_div' ,
349+  True , None )),
350+  ('__pow__' , _binary_creator_ ('__pow__' , 'elementwise_pow' , False ,
351+  None )),
352+  ('__rpow__' , _binary_creator_ ('__rpow__' , 'elementwise_pow' , True ,
353+  None )),
354+  ('__floordiv__' , _binary_creator_ ('__floordiv__' ,
355+  'elementwise_floordiv' , False , None )),
356+  ('__mod__' , _binary_creator_ ('__mod__' , 'elementwise_mod' , False ,
357+  None )),
436358 # for logical compare 
437359 ('__eq__' , _binary_creator_ ('__eq__' , 'equal' , False , None )),
438360 ('__ne__' , _binary_creator_ ('__ne__' , 'not_equal' , False , None )),
@@ -451,8 +373,8 @@ def __impl__(self, other_var):
451373 else :
452374 import  paddle .tensor 
453375 for  method_name  in  common_methods :
454-  if  hasattr (Variable , method ): continue 
455-  method_impl  =  getattr (paddle .tensor , method , None )
376+  if  hasattr (Variable , method_name ): continue 
377+  method_impl  =  getattr (paddle .tensor , method_name , None )
456378 if  method_impl : setattr (Variable , method_name , method_impl )
457379
458380 _already_patch_variable  =  True 
0 commit comments