Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/paddle/base/dygraph/math_op_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ def conversion_method(self: Tensor) -> Tensor:

return methods

def type_as(self: Tensor, other: Tensor) -> Tensor:
return self.astype(other.dtype)

def _scalar_elementwise_op_(
var: Tensor, scale: float, bias: float
) -> Tensor:
Expand Down Expand Up @@ -295,6 +298,7 @@ def _mT_(var: Tensor) -> Tensor:
('astype', astype),
('byte', byte),
('uint8', byte),
('type_as', type_as),
('dim', dim),
('ndimension', ndimension),
('ndim', _ndim),
Expand Down
4 changes: 4 additions & 0 deletions python/paddle/base/layers/math_op_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,9 @@ def astype(self, dtype):
out.stop_gradient = self.stop_gradient
return out

def type_as(self, other):
return self.astype(other.dtype)

@static_only
def append(self, var):
"""
Expand Down Expand Up @@ -799,6 +802,7 @@ def to_dense(var):
('__neg__', _neg_),
('__abs__', _abs_),
('astype', astype),
('type_as', type_as),
('cpu', cpu),
('cuda', cuda),
('place', place),
Expand Down
4 changes: 4 additions & 0 deletions python/paddle/pir/math_op_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,9 @@ def conversion_method(self):
methods.append((method_name, method_impl))
return methods

def type_as(self, other):
return self.astype(other.dtype)

def _scalar_add_(var, value):
return paddle.scale(var, 1.0, value)

Expand Down Expand Up @@ -1175,6 +1178,7 @@ def register_hook(self, hook):
('astype', astype),
('byte', byte),
('uint8', byte),
('type_as', type_as),
('size', _size_),
('T', _T_),
('mT', _mT_),
Expand Down
153 changes: 153 additions & 0 deletions test/legacy_test/test_type_as.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle import base


def api_warpprt(x, y):
return x.type_as(y)


class TestTypeAsBase(unittest.TestCase):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

你这个单测里没看到调用Tensor.type_as呢

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

笔误,已修改~

def setUp(self):
self.input_dtype_1 = "float32"
self.input_dtype_2 = "float16"
self.input_shape = (2, 3)

self.input_np_1 = self.generate_data(
self.input_dtype_1, self.input_shape
)
self.input_np_2 = self.generate_data(
self.input_dtype_2, self.input_shape
)

self.input_shape_1 = self.input_np_1.shape
self.input_shape_2 = self.input_np_2.shape

self.op_static = api_warpprt
self.op_dygraph = api_warpprt
self.places = [None, paddle.CPUPlace()]

def generate_data(self, dtype, shape):
if "int" in dtype:
data = np.arange(1, np.prod(shape) + 1).reshape(shape)
else:
data = np.arange(1, np.prod(shape) + 1, dtype='float32').reshape(
shape
)
return data.astype(dtype)

def check_static_result(self, place):
paddle.enable_static()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
input_name_1 = 'input_1'
input_name_2 = 'input_2'
input_var_1 = paddle.static.data(
name=input_name_1,
shape=self.input_shape_1,
dtype=self.input_dtype_1,
)
input_var_2 = paddle.static.data(
name=input_name_2,
shape=self.input_shape_2,
dtype=self.input_dtype_2,
)
res = self.op_static(input_var_1, input_var_2)
exe = base.Executor(place)
fetches = exe.run(
main_prog,
feed={
input_name_1: self.input_np_1,
input_name_2: self.input_np_2,
},
fetch_list=[res],
)
self.assertEqual(fetches[0].dtype, np.dtype(self.input_dtype_2))

def test_static(self):
for place in self.places:
self.check_static_result(place=place)

def check_dygraph_result(self, place):
with base.dygraph.guard(place):
input_1 = paddle.to_tensor(self.input_np_1)
input_2 = paddle.to_tensor(self.input_np_2)
result = self.op_dygraph(input_1, input_2)
self.assertEqual(result.dtype, input_2.dtype)

def test_dygraph(self):
for place in self.places:
self.check_dygraph_result(place=place)


class TestTypeAsFloat32ToFloat16(TestTypeAsBase):
def setUp(self):
self.input_dtype_1 = "float32"
self.input_dtype_2 = "float16"
super().setUp()


class TestTypeAsFloat64ToFloat32(TestTypeAsBase):
def setUp(self):
self.input_dtype_1 = "float64"
self.input_dtype_2 = "float32"
super().setUp()


class TestTypeAsInt32ToInt64(TestTypeAsBase):
def setUp(self):
self.input_dtype_1 = "int32"
self.input_dtype_2 = "int64"
super().setUp()


class TestTypeAsInt32ToFloat32(TestTypeAsBase):
def setUp(self):
self.input_dtype_1 = "int32"
self.input_dtype_2 = "float32"
super().setUp()


class TestTypeAsFloat32ToInt64(TestTypeAsBase):
def setUp(self):
self.input_dtype_1 = "float32"
self.input_dtype_2 = "int64"
super().setUp()


class TestTypeAsInt8ToFloat64(TestTypeAsBase):
def setUp(self):
self.input_dtype_1 = "int8"
self.input_dtype_2 = "float64"
self.input_shape = (4, 2)
super().setUp()


class TestTypeAsUInt8ToInt32(TestTypeAsBase):
def setUp(self):
self.input_dtype_1 = "uint8"
self.input_dtype_2 = "int32"
self.input_shape = (3, 3)
super().setUp()


if __name__ == "__main__":
unittest.main()