4 changes: 4 additions & 0 deletions python/paddle/__init__.py
@@ -31,6 +31,10 @@
import paddle.dataset
import paddle.batch
batch = batch.batch
from .fluid import monkey_patch_variable
from .fluid.dygraph import monkey_patch_math_varbase
monkey_patch_variable()
monkey_patch_math_varbase()
import paddle.framework
from .framework import VarBase as Tensor
from .framework import ComplexVariable as ComplexTensor
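These two patch calls have to run during `import paddle` so that plain Python operators already work on dygraph tensors before any user code touches them. A minimal sketch of the behaviour they enable, assuming a dygraph-capable build of this branch (`fluid.dygraph.guard` and `to_variable` are pre-existing dygraph entry points, not part of this change):

```python
import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(np.array([1.0, 2.0], dtype='float32'))
    y = x * 2.0 + 1.0            # patched __mul__ / __add__ (scalar fast path)
    print(y.numpy())             # [3. 5.]
    print((y > 4.0).numpy())     # [False  True] -> patched __gt__ / greater_than
    print(y.ndim, len(y))        # 1 2 -> patched ndim property and __len__
```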
2 changes: 2 additions & 0 deletions python/paddle/fluid/dygraph/__init__.py
@@ -59,6 +59,8 @@
from . import amp
from .amp import *

from .math_op_patch import monkey_patch_math_varbase

__all__ = []
__all__ += layers.__all__
__all__ += base.__all__
185 changes: 126 additions & 59 deletions python/paddle/fluid/dygraph/math_op_patch.py
@@ -17,6 +17,7 @@
from .. import core
from ..framework import Variable, convert_np_dtype_to_dtype_, _varbase_creator
from ..layers.layer_function_generator import OpProtoHolder
from ..layers import common_methods
from . import to_variable, no_grad

import numpy as np
@@ -30,6 +31,8 @@
core.VarDesc.VarType.INT64,
]

_already_patch_varbase = False


def monkey_patch_math_varbase():
"""
@@ -140,25 +143,30 @@ def _index_(var):
else:
return int(var.numpy().flatten()[0])

@property
def _ndim_(var):
    return len(var.shape)

# scalar fast paths (renamed from the old _scalar_elementwise_*_ helpers)
def _scalar_add_(var, value):
    return _scalar_elementwise_op_(var, 1.0, value)

def _scalar_sub_(var, value):
    return _scalar_elementwise_op_(var, 1.0, -value)

def _scalar_rsub_(var, value):
    return _scalar_elementwise_op_(var, -1.0, value)

def _scalar_mul_(var, value):
    return _scalar_elementwise_op_(var, value, 0.0)

def _scalar_div_(var, value):
    return _scalar_elementwise_op_(var, 1.0 / value, 0.0)
# for binary operators such as elementwise and compare ops
# (renamed from _elemwise_method_creator_)
def _binary_creator_(method_name,
                     op_type,
                     reverse=False,
                     scalar_method=None):
def __impl__(self, other_var):
# FIXME(zjl): elementwise_div between integers cannot be converted to scale,
# which may lose accuracy. This is a hot fix for release 1.6.
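For plain int/float operands, `_binary_creator_` can shortcut to the `_scalar_*_` helpers above instead of materialising a tensor and running a full elementwise kernel. Assuming `_scalar_elementwise_op_(var, scale, bias)` lowers to a single `scale` op computing `var * scale + bias` (its definition sits outside this hunk), the (scale, bias) choices can be checked with plain floats:

```python
# Hypothetical stand-in for _scalar_elementwise_op_ applied to a scalar,
# used only to illustrate the (scale, bias) encoding chosen by the helpers.
def scale_op_demo(x, scale, bias):
    return x * scale + bias

assert scale_op_demo(10.0, 1.0, 3.0) == 13.0     # _scalar_add_:  var + 3
assert scale_op_demo(10.0, 1.0, -3.0) == 7.0     # _scalar_sub_:  var - 3
assert scale_op_demo(10.0, -1.0, 3.0) == -7.0    # _scalar_rsub_: 3 - var
assert scale_op_demo(10.0, 4.0, 0.0) == 40.0     # _scalar_mul_:  var * 4
assert scale_op_demo(10.0, 1.0 / 4, 0.0) == 2.5  # _scalar_div_:  var / 4
```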
@@ -200,60 +208,119 @@ def __impl__(self, other_var):
__impl__.__doc__ = """
{0}
Args:
    self(Tensor): left hand Tensor
    other_var(Tensor|float|int): right hand Tensor

Returns:
    Tensor
""".format(comment)
__impl__.__name__ = method_name
return __impl__

# Todo(zhouwei): implement dygraph template to adapt to any function,
# receive('op_type', 'arg_template'). Such as
# _method_creator_('addmm', 'x, y, alpha=1.0, beta=1.0, name=None').
# It can reduce call time.
def _method_creator_(op_type, arg_template=None):
    def __impl__(self):
        op = getattr(core.ops, op_type)
        return op(self)

    __impl__.__doc__ = """

    See paddle.{}""".format(op_type)
    __impl__.__name__ = op_type

    return __impl__

varbase_methods = [
# Type1: From a custom function or lambda
## b=-a
('__neg__', _neg_),
('__float__', _float_),
('__long__', _long_),
('__int__', _int_),
('__len__', _len_),
('__index__', _index_),
('astype', astype),
('dim', lambda x: len(x.shape)),
('ndimension', lambda x: len(x.shape)),
('ndim', _ndim_),
('size', lambda x: x.shape),
# Type2: From a template that creates core.ops calls automatically. It's recommended.
('__add__',
_binary_creator_('__add__', 'elementwise_add', False, _scalar_add_)),
## a+b == b+a. Do not need to reverse explicitly
('__radd__',
_binary_creator_('__radd__', 'elementwise_add', False, _scalar_add_)),
('__sub__', _binary_creator_('__sub__', 'elementwise_sub', False,
_scalar_sub_)),
('__rsub__', _binary_creator_('__rsub__', 'elementwise_sub', True,
_scalar_rsub_)),
('__mul__', _binary_creator_('__mul__', 'elementwise_mul', False,
_scalar_mul_)),
## a*b == b*a. Do not need to reverse explicitly
('__rmul__',
_binary_creator_('__rmul__', 'elementwise_mul', False, _scalar_mul_)),
('__div__', _binary_creator_('__div__', 'elementwise_div', False,
_scalar_div_)),
('__truediv__', _binary_creator_('__truediv__', 'elementwise_div',
False, _scalar_div_)),
('__rdiv__', _binary_creator_('__rdiv__', 'elementwise_div', True,
None)),
('__rtruediv__', _binary_creator_('__rtruediv__', 'elementwise_div', True,
None)),
('__pow__', _binary_creator_('__pow__', 'elementwise_pow', False,
None)),
('__rpow__', _binary_creator_('__rpow__', 'elementwise_pow', True,
None)),
('__floordiv__', _binary_creator_('__floordiv__',
'elementwise_floordiv', False, None)),
('__mod__', _binary_creator_('__mod__', 'elementwise_mod', False,
None)),
## for logical compare
('__eq__', _binary_creator_('__eq__', 'equal', False, None)),
('__ne__', _binary_creator_('__ne__', 'not_equal', False, None)),
('__lt__', _binary_creator_('__lt__', 'less_than', False, None)),
('__le__', _binary_creator_('__le__', 'less_equal', False, None)),
('__gt__', _binary_creator_('__gt__', 'greater_than', False, None)),
('__ge__', _binary_creator_('__ge__', 'greater_equal', False, None)),
## Disable numpy's ufunc dispatch: for `y = np.pi * var`, numpy would otherwise
## call ndarray.__mul__ and trace var as an array (via __len__/__getitem__).
## With __array_ufunc__ set to None, var.__rmul__(np.pi) is called instead. See:
## https://docs.scipy.org/doc/numpy-1.13.0/neps/ufunc-overrides.html#behavior-in-combination-with-python-s-binary-operations
('__array_ufunc__', None),
('sigmoid', _method_creator_('sigmoid', 'name=None')),
('logsigmoid', _method_creator_('logsigmoid', 'name=None')),
('exp', _method_creator_('exp', 'name=None')),
('tanh', _method_creator_('tanh', 'name=None')),
('atan', _method_creator_('atan', 'name=None')),
('tanh_shrink', _method_creator_('tanh_shrink', 'name=None')),
('sqrt', _method_creator_('sqrt', 'name=None')),
('rsqrt', _method_creator_('rsqrt', 'name=None')),
('abs', _method_creator_('abs', 'name=None')),
('ceil', _method_creator_('ceil', 'name=None')),
('floor', _method_creator_('floor', 'name=None')),
('cos', _method_creator_('cos', 'name=None')),
('acos', _method_creator_('acos', 'name=None')),
('asin', _method_creator_('asin', 'name=None')),
('sin', _method_creator_('sin', 'name=None')),
('sinh', _method_creator_('sinh', 'name=None')),
('cosh', _method_creator_('cosh', 'name=None')),
('round', _method_creator_('round', 'name=None')),
('reciprocal', _method_creator_('reciprocal', 'name=None')),
('square', _method_creator_('square', 'name=None')),
('softplus', _method_creator_('softplus', 'name=None')),
('softsign', _method_creator_('softsign', 'name=None')),
# Type3: From module 'paddle.tensor' by default.
# It's not a good way, because it will increase call time.
]
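Once these entries are applied (see the loop below), the Type2 unary names behave as zero-argument tensor methods that call straight into `core.ops`. A small usage sketch, assuming the same dygraph setup as above and that generated bindings such as `core.ops.sqrt` exist for the listed op names:

```python
import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(np.array([0.0, 4.0], dtype='float32'))
    print(x.sqrt().numpy())        # [0. 2.] -> core.ops.sqrt via _method_creator_
    print(x.exp().tanh().numpy())  # chains, since each call returns a VarBase
    print(x.ndim, x.dim())         # 1 1     -> Type1 entries
```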

global _already_patch_varbase
if not _already_patch_varbase:
for method in varbase_methods:
method_name = method[0]
method_impl = method[1]
setattr(core.VarBase, method_name, method_impl)
else:
import paddle.tensor
for method_name in common_methods:
if hasattr(core.VarBase, method_name): continue
method_impl = getattr(paddle.tensor, method_name, None)
if method_impl: setattr(core.VarBase, method_name, method_impl)

_already_patch_varbase = True
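`_already_patch_varbase` makes the patching idempotent and two-phase: the first call (issued from `python/paddle/__init__.py` above) installs all of `varbase_methods`; any later call only pulls names listed in `common_methods` out of `paddle.tensor`, skipping anything already present on `core.VarBase`. A self-contained reduction of that guard pattern (`FakeTensor`, `fake_api`, and `COMMON` are stand-ins, not Paddle names):

```python
_patched = False

class FakeTensor(object):
    pass

def fake_api(x):                   # plays the role of a paddle.tensor function
    return 42

COMMON = ['mean']                  # plays the role of common_methods

def patch_fake():
    global _patched
    if not _patched:               # first call: install the core methods
        FakeTensor.__add__ = lambda a, b: 'add'
    else:                          # later calls: only fill the remaining gaps
        for name in COMMON:
            if not hasattr(FakeTensor, name):
                setattr(FakeTensor, name, fake_api)
    _patched = True

patch_fake(); patch_fake()
assert FakeTensor() + FakeTensor() == 'add' and FakeTensor().mean() == 42
```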