python/paddle/base/dygraph/math_op_patch.py (73 additions, 0 deletions)
@@ -65,6 +65,36 @@
_already_patch_eager_tensor = False


_supported_dtype_conversions = {
# float
'float16': 'float16',
'half': 'float16',
'bfloat16': 'bfloat16',
'float32': 'float32',
'float': 'float32',
'float64': 'float64',
'double': 'float64',
# int
'int8': 'int8',
'char': 'int8',
# We handle uint8 conversion separately
# 'uint8': 'uint8',
# 'byte': 'uint8',
'int16': 'int16',
'short': 'int16',
'int32': 'int32',
'int': 'int32',
'int64': 'int64',
'long': 'int64',
# other
'bool': 'bool',
'complex64': 'complex64',
'complex128': 'complex128',
'cfloat': 'complex64',
'cdouble': 'complex128',
}


def monkey_patch_math_tensor():
"""
Similar to monkey_patch_variable.
@@ -104,6 +134,44 @@ def astype(self: Tensor, dtype: DTypeLike) -> Tensor:

return _C_ops.cast(self, dtype)

def byte(self: Tensor) -> Tensor:
# Paddle's cast clamps negative floats to 0 when converting to uint8, which differs from torch.Tensor.byte, so cast to int8 first
Contributor:

Could the missing type be registered in the C++ cast kernel instead? It looks like uint8 is not registered there.

Contributor Author:

cast does support converting to uint8, but by default it turns float values below 0 into 0, so casting float to uint8 directly is inconsistent with torch.Tensor.byte. The two-step conversion is used for now to keep paddle.Tensor.byte consistent with Torch:

# Direct conversion
>>> paddle.cast(paddle.to_tensor([-1, -2, -3, 12, 2, 3], dtype='float32'), 'uint8')
Tensor(shape=[6], dtype=uint8, place=Place(gpu:0), stop_gradient=True,
       [0 , 0 , 0 , 12, 2 , 3 ])

# Convert to int first, then to uint8
>>> paddle.cast(paddle.to_tensor([-1., -2., -3., 12., 2., 3.], dtype='int32'), 'uint8')
Tensor(shape=[6], dtype=uint8, place=Place(gpu:0), stop_gradient=True,
       [255, 254, 253, 12 , 2 , 3 ])

# Torch
>>> torch.Tensor([-1., -2., -3., 12., 2., 3.]).byte()
tensor([255, 254, 253, 12, 2, 3], dtype=torch.uint8)
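For reference, a minimal sketch of how this would surface through the patched methods once the change lands (the expected values follow the example above, and the complex branch keeps only the real part, per the implementation below):

import paddle

x = paddle.to_tensor([-1., -2., -3., 12., 2., 3.], dtype='float32')
print(x.byte())   # expected: [255, 254, 253, 12, 2, 3], matching torch.Tensor.byte()
print(x.uint8())  # alias registered for the same method

c = paddle.to_tensor([1 + 2j, -3 - 4j], dtype='complex64')
print(c.byte())   # real part only: expected [1, 253]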
if self.is_floating_point():
tensor = astype(self, 'int8')
return astype(tensor, 'uint8')
elif self.is_complex():
real = astype(self.real(), 'int8')
Contributor:

Same as above.

return astype(real, 'uint8')
else:
return astype(self, 'uint8')

def _create_dtype_conversion_methods():
"""
Batch create all data type conversion methods
"""
methods = []

for method_name, target_dtype in _supported_dtype_conversions.items():

def make_conversion_method(dtype):
def conversion_method(self: Tensor) -> Tensor:
return astype(self, dtype)

return conversion_method

method_impl = make_conversion_method(target_dtype)
method_impl.__name__ = method_name
method_impl.__doc__ = f"""
Cast a Tensor to {target_dtype} data type if it differs from the current dtype;
otherwise, return the original Tensor.
Returns:
Tensor: a new Tensor with {target_dtype} dtype
"""

methods.append((method_name, method_impl))

return methods

def _scalar_elementwise_op_(
var: Tensor, scale: float, bias: float
) -> Tensor:
@@ -225,6 +293,8 @@ def _mT_(var: Tensor) -> Tensor:
('__len__', _len_),
('__index__', _index_),
('astype', astype),
('byte', byte),
('uint8', byte),
('dim', dim),
('ndimension', ndimension),
('ndim', _ndim),
@@ -235,6 +305,9 @@ def _mT_(var: Tensor) -> Tensor:
('__array_ufunc__', None),
]

dtype_conversion_methods = _create_dtype_conversion_methods()
eager_methods.extend(dtype_conversion_methods)

eager_cpp_level_patch = [
"__add__",
"__radd__",
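For reference, a minimal dygraph sketch of the batch-generated conversion methods, assuming this patch is applied; the target dtypes follow directly from _supported_dtype_conversions, and each alias resolves to the same cast as its canonical name:

import paddle

x = paddle.to_tensor([1, 2, 3], dtype='int32')

print(x.float().dtype)   # paddle.float32
print(x.double().dtype)  # paddle.float64 ('double' is an alias of 'float64')
print(x.half().dtype)    # paddle.float16
print(x.long().dtype)    # paddle.int64
print(x.short().dtype)   # paddle.int16
print(x.char().dtype)    # paddle.int8
print(x.bool().dtype)    # paddle.bool
print(x.cfloat().dtype)  # paddle.complex64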
python/paddle/pir/math_op_patch.py (68 additions, 0 deletions)
@@ -37,6 +37,35 @@
DataType.INT64,
]

_supported_dtype_conversions = {
# float
'float16': 'float16',
'half': 'float16',
'bfloat16': 'bfloat16',
'float32': 'float32',
'float': 'float32',
'float64': 'float64',
'double': 'float64',
# int
'int8': 'int8',
'char': 'int8',
# We handle uint8 conversion separately
# 'uint8': 'uint8',
# 'byte': 'uint8',
'int16': 'int16',
'short': 'int16',
'int32': 'int32',
'int': 'int32',
'int64': 'int64',
'long': 'int64',
# other
'bool': 'bool',
'complex64': 'complex64',
'complex128': 'complex128',
'cfloat': 'complex64',
'cdouble': 'complex128',
}

SUPPORT_PROMOTION_OPS = [
"__add__",
"__radd__",
@@ -370,6 +399,41 @@ def astype(self, dtype):

return _C_ops.cast(self, dtype)

def byte(self):
Contributor:

Same as above.

# Paddle's cast clamps negative floats to 0 when converting to uint8, which differs from torch.Tensor.byte, so cast to int8 first
if self.is_floating_point():
tensor = astype(self, 'int8')
return astype(tensor, 'uint8')
elif self.is_complex():
Contributor:

For the complex case, compare against torch in PaConvert to check whether torch handles it the same way.

Contributor Author:

Tests have been added in PaConvert: PaddlePaddle/PaConvert#617, and they pass locally.

def test_complex64_to_byte():
    # for complex --> uint8, we used the new test data
    pytorch_code = textwrap.dedent(
        """
        import torch
        src = torch.tensor([0.+3.5j, -1+4.2j, 2.34-5.2j, -3.45+7.9j, -0.34-8.2j, 0.23+9.2j, 1.+1.j, 2.+0.5j, 3.-1.j,], dtype=torch.complex64)
        result = src.byte()
        """
    )
    obj.run(pytorch_code, ["result"])


def test_complex128_to_byte():
    # for complex --> uint8, we used the new test data
    pytorch_code = textwrap.dedent(
        """
        import torch
        src = torch.tensor([0.+3.5j, -1+4.2j, 2.34-5.2j, -3.45+7.9j, -0.34-8.2j, 0.23+9.2j, 1.+1.j, 2.+0.5j, 3.-1.j,], dtype=torch.complex128)
        result = src.byte()
        """
    )
    obj.run(pytorch_code, ["result"])
real = astype(self.real(), 'int8')
return astype(real, 'uint8')
else:
return astype(self, 'uint8')

def _create_dtype_conversion_methods():
"""
Batch create all data type conversion methods
"""
methods = []
for method_name, target_dtype in _supported_dtype_conversions.items():

def make_conversion_method(dtype):
def conversion_method(self):
return astype(self, dtype)

return conversion_method

method_impl = make_conversion_method(target_dtype)
method_impl.__name__ = method_name
method_impl.__doc__ = f"""
Cast a Value to {target_dtype} data type if it differs from the current dtype;
otherwise, return the original Value.
Returns:
Value: a new Value with {target_dtype} dtype
"""
methods.append((method_name, method_impl))
return methods

def _scalar_add_(var, value):
return paddle.scale(var, 1.0, value)

@@ -1109,6 +1173,8 @@ def register_hook(self, hook):
('ndimension', ndimension),
('ndim', _ndim),
('astype', astype),
('byte', byte),
('uint8', byte),
('size', _size_),
('T', _T_),
('mT', _mT_),
@@ -1253,6 +1319,8 @@ def register_hook(self, hook):
('__bool__', _bool_),
('__complex__', _complex_),
]
dtype_conversion_methods = _create_dtype_conversion_methods()
value_methods.extend(dtype_conversion_methods)

global _already_patch_value
if not _already_patch_value:
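And a corresponding static-graph sketch for the pir patch, assuming pir mode is active; the Program/data setup here is ordinary paddle.static usage and not part of this change:

import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data(name='x', shape=[3], dtype='float32')
    y = x.long()  # adds a cast op; y is an int64 Value
    z = x.byte()  # float -> int8 -> uint8, mirroring the dygraph byte()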