-
Couldn't load subscription status.
- Fork 5.9k
[API compatibility] add dtype conversion method #74416
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
fa1c193 e0a5bec 9d315b9 393fdde 9fe87a8 1974480 1120f28 d850a5f a5972ee File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | @@ -65,6 +65,36 @@ | |
| _already_patch_eager_tensor = False | ||
| | ||
| | ||
# Maps the name of every generated dtype-conversion method (canonical dtype
# names plus alias names such as ``half`` or ``long``) to the dtype string the
# method casts to. ``uint8`` and ``byte`` are intentionally absent: that
# conversion is handled separately by the dedicated ``byte`` method.
_supported_dtype_conversions = {
    method_name: target_dtype
    for target_dtype, method_names in (
        # float
        ('float16', ('float16', 'half')),
        ('bfloat16', ('bfloat16',)),
        ('float32', ('float32', 'float')),
        ('float64', ('float64', 'double')),
        # int (no uint8 / byte entry here, it is handled separately)
        ('int8', ('int8', 'char')),
        ('int16', ('int16', 'short')),
        ('int32', ('int32', 'int')),
        ('int64', ('int64', 'long')),
        # other
        ('bool', ('bool',)),
        ('complex64', ('complex64',)),
        ('complex128', ('complex128',)),
        ('complex64', ('cfloat',)),
        ('complex128', ('cdouble',)),
    )
    for method_name in method_names
}
| | ||
| | ||
| def monkey_patch_math_tensor(): | ||
| """ | ||
| Similar to monkey_patch_variable. | ||
| | @@ -104,6 +134,44 @@ def astype(self: Tensor, dtype: DTypeLike) -> Tensor: | |
| | ||
| return _C_ops.cast(self, dtype) | ||
| | ||
def byte(self: Tensor) -> Tensor:
    """
    Cast this Tensor to ``uint8``.

    Floating-point inputs, and the real part of complex inputs, are cast to
    ``int8`` first and only then to ``uint8``. Every other dtype is cast to
    ``uint8`` in a single step.

    Returns:
        Tensor: the result of the final ``astype(..., 'uint8')`` call.
    """
    # Why two casts: according to the review discussion on this change, a
    # direct float -> uint8 cast turns negative values into 0, which does not
    # match torch.Tensor.byte, so the value is routed through int8 first.
    # NOTE(review): behaviour for values outside the int8 range is not
    # established by anything visible here -- confirm before relying on it.
    if self.is_floating_point():
        source = self
    elif self.is_complex():
        # Only the real part is converted; the imaginary part is discarded.
        source = self.real()
    else:
        return astype(self, 'uint8')
    return astype(astype(source, 'int8'), 'uint8')
| | ||
def _create_dtype_conversion_methods():
    """
    Batch create all data type conversion methods.

    One method is generated for every entry of
    ``_supported_dtype_conversions``; each generated method forwards to
    ``astype`` with the dtype mapped to its name.

    Returns:
        list: ``(method_name, method)`` tuples, in the iteration order of
        ``_supported_dtype_conversions``.
    """

    def make_conversion_method(method_name, target_dtype):
        # Defined once instead of once per loop iteration. The dtype is bound
        # through the factory argument, so every generated method keeps its
        # own target dtype rather than sharing the loop variable.
        def conversion_method(self: Tensor) -> Tensor:
            return astype(self, target_dtype)

        conversion_method.__name__ = method_name
        # Keep __qualname__ in sync with __name__; otherwise repr() and
        # tracebacks show the nested "<locals>" path of this factory.
        conversion_method.__qualname__ = method_name
        conversion_method.__doc__ = f"""
        Cast a Tensor to {target_dtype} data type if it differs from the current dtype;
        otherwise, return the original Tensor.
        Returns:
            Tensor: a new Tensor with {target_dtype} dtype
        """
        return conversion_method

    return [
        (method_name, make_conversion_method(method_name, target_dtype))
        for method_name, target_dtype in _supported_dtype_conversions.items()
    ]
| | ||
| def _scalar_elementwise_op_( | ||
| var: Tensor, scale: float, bias: float | ||
| ) -> Tensor: | ||
| | @@ -225,6 +293,8 @@ def _mT_(var: Tensor) -> Tensor: | |
| ('__len__', _len_), | ||
| ('__index__', _index_), | ||
| ('astype', astype), | ||
| ('byte', byte), | ||
| ('uint8', byte), | ||
| ('dim', dim), | ||
| ('ndimension', ndimension), | ||
| ('ndim', _ndim), | ||
| | @@ -235,6 +305,9 @@ def _mT_(var: Tensor) -> Tensor: | |
| ('__array_ufunc__', None), | ||
| ] | ||
| | ||
| dtype_conversion_methods = _create_dtype_conversion_methods() | ||
| eager_methods.extend(dtype_conversion_methods) | ||
| | ||
| eager_cpp_level_patch = [ | ||
| "__add__", | ||
| "__radd__", | ||
| | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | @@ -37,6 +37,35 @@ | |
| DataType.INT64, | ||
| ] | ||
| | ||
# Maps the name of every generated dtype-conversion method (canonical dtype
# names plus alias names such as ``half`` or ``long``) to the dtype string the
# method casts to. ``uint8`` and ``byte`` are intentionally absent: that
# conversion is handled separately by the dedicated ``byte`` method.
_supported_dtype_conversions = {
    method_name: target_dtype
    for target_dtype, method_names in (
        # float
        ('float16', ('float16', 'half')),
        ('bfloat16', ('bfloat16',)),
        ('float32', ('float32', 'float')),
        ('float64', ('float64', 'double')),
        # int (no uint8 / byte entry here, it is handled separately)
        ('int8', ('int8', 'char')),
        ('int16', ('int16', 'short')),
        ('int32', ('int32', 'int')),
        ('int64', ('int64', 'long')),
        # other
        ('bool', ('bool',)),
        ('complex64', ('complex64',)),
        ('complex128', ('complex128',)),
        ('complex64', ('cfloat',)),
        ('complex128', ('cdouble',)),
    )
    for method_name in method_names
}
| | ||
| SUPPORT_PROMOTION_OPS = [ | ||
| "__add__", | ||
| "__radd__", | ||
| | @@ -370,6 +399,41 @@ def astype(self, dtype): | |
| | ||
| return _C_ops.cast(self, dtype) | ||
| | ||
def byte(self):
    """
    Cast this Value to ``uint8``.

    Floating-point inputs, and the real part of complex inputs, are cast to
    ``int8`` first and only then to ``uint8``. Every other dtype is cast to
    ``uint8`` in a single step.

    Returns:
        Value: the result of the final ``astype(..., 'uint8')`` call.
    """
    # Why two casts: according to the review discussion on this change, a
    # direct float -> uint8 cast turns negative values into 0, which does not
    # match torch.Tensor.byte, so the value is routed through int8 first.
    # NOTE(review): behaviour for values outside the int8 range is not
    # established by anything visible here -- confirm before relying on it.
    if self.is_floating_point():
        source = self
    elif self.is_complex():
        # Only the real part is converted; the imaginary part is discarded.
        # Per the review thread, this path was compared against torch with a
        # PaConvert test (PaddlePaddle/PaConvert#617).
        source = self.real()
    else:
        return astype(self, 'uint8')
    return astype(astype(source, 'int8'), 'uint8')
| | ||
def _create_dtype_conversion_methods():
    """
    Batch create all data type conversion methods.

    One method is generated for every entry of
    ``_supported_dtype_conversions``; each generated method forwards to
    ``astype`` with the dtype mapped to its name.

    Returns:
        list: ``(method_name, method)`` tuples, in the iteration order of
        ``_supported_dtype_conversions``.
    """

    def make_conversion_method(method_name, target_dtype):
        # Defined once instead of once per loop iteration. The dtype is bound
        # through the factory argument, so every generated method keeps its
        # own target dtype rather than sharing the loop variable.
        def conversion_method(self):
            return astype(self, target_dtype)

        conversion_method.__name__ = method_name
        # Keep __qualname__ in sync with __name__; otherwise repr() and
        # tracebacks show the nested "<locals>" path of this factory.
        conversion_method.__qualname__ = method_name
        conversion_method.__doc__ = f"""
        Cast a Value to {target_dtype} data type if it differs from the current dtype;
        otherwise, return the original Value.
        Returns:
            Value: a new Value with {target_dtype} dtype
        """
        return conversion_method

    return [
        (method_name, make_conversion_method(method_name, target_dtype))
        for method_name, target_dtype in _supported_dtype_conversions.items()
    ]
| | ||
| def _scalar_add_(var, value): | ||
| return paddle.scale(var, 1.0, value) | ||
| | ||
| | @@ -1109,6 +1173,8 @@ def register_hook(self, hook): | |
| ('ndimension', ndimension), | ||
| ('ndim', _ndim), | ||
| ('astype', astype), | ||
| ('byte', byte), | ||
| ('uint8', byte), | ||
| ('size', _size_), | ||
| ('T', _T_), | ||
| ('mT', _mT_), | ||
| | @@ -1253,6 +1319,8 @@ def register_hook(self, hook): | |
| ('__bool__', _bool_), | ||
| ('__complex__', _complex_), | ||
| ] | ||
| dtype_conversion_methods = _create_dtype_conversion_methods() | ||
| value_methods.extend(dtype_conversion_methods) | ||
| | ||
| global _already_patch_value | ||
| if not _already_patch_value: | ||
| | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个在c++ cast kernel里能新增注册类型吗,c++里没注册上uint8的类型吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
看到cast支持转换到uint8,默认会将小于0的float数据转换为0,直接使用cast转换float 到 uint8的结果与torch.Tensor.byte不一致,所以用了两次转换暂时保证paddle.Tensor.byte与Torch一致。