19 changes: 10 additions & 9 deletions paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h
@@ -931,17 +931,17 @@ Tensor clip_decomp(const Tensor& x, const Tensor& min, const Tensor& max) {
   auto min_reshape = min;
   auto max_reshape = max;
 
-  if (x.shape().size() == 0) {
-    min_reshape = reshape<T>(min_reshape, {});
-    max_reshape = reshape<T>(max_reshape, {});
-  }
-
   if (has_dynamic_shape(x.shape())) {
     min_reshape = backend::expand<T>(min_reshape, shape64<T>(x));
     max_reshape = backend::expand<T>(max_reshape, shape64<T>(x));
   } else {
-    min_reshape = expand<T>(min_reshape, x.shape());
-    max_reshape = expand<T>(max_reshape, x.shape());
+    if (x.shape().size() == 0) {
+      min_reshape = reshape<T>(min_reshape, {});
+      max_reshape = reshape<T>(max_reshape, {});
+    } else {
+      min_reshape = expand<T>(min_reshape, x.shape());
+      max_reshape = expand<T>(max_reshape, x.shape());
+    }
   }
   if (min_reshape.dtype() != x.dtype()) {
     min_reshape = cast<T>(min_reshape, x.dtype());
@@ -950,8 +950,9 @@ Tensor clip_decomp(const Tensor& x, const Tensor& min, const Tensor& max) {
   if (max_reshape.dtype() != x.dtype()) {
     max_reshape = cast<T>(max_reshape, x.dtype());
   }
-
-  auto ans = maximum<T>(minimum<T>(x, max_reshape), min_reshape);
+  auto ans = where<T>(x <= max_reshape,
+                      where<T>(x >= min_reshape, x, min_reshape),
+                      max_reshape);
   return ans;
 }
 
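For reference, the decomposition now produces the clipped value with nested `where` selections instead of `maximum(minimum(x, max), min)`, which pins down which operand is selected when `x` equals a bound. A minimal NumPy sketch of the same forward logic (illustrative only; `clip_where_ref` and the sample values are not part of this patch):

```python
import numpy as np

def clip_where_ref(x, min_v, max_v):
    # Mirrors the where-based decomposition:
    # out = max_v where x > max_v, else min_v where x < min_v, else x.
    lower_clamped = np.where(x >= min_v, x, min_v)
    return np.where(x <= max_v, lower_clamped, max_v)

x = np.array([0.0, 0.1, 0.5, 0.8, 1.0], dtype=np.float32)
# Forward results agree with the previous maximum/minimum formulation.
assert np.allclose(clip_where_ref(x, 0.1, 0.8), np.clip(x, 0.1, 0.8))
```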
33 changes: 31 additions & 2 deletions test/legacy_test/test_clip_op.py
@@ -46,7 +46,7 @@ def setUp(self):
         else:
             max_v = self.attrs['max']
 
-        input = np.random.random(self.shape).astype(self.dtype)
+        input = self.generate_input()
         input[np.abs(input - min_v) < self.max_relative_error] = 0.5
         input[np.abs(input - max_v) < self.max_relative_error] = 0.5
         self.inputs['X'] = input
@@ -67,7 +67,7 @@ def test_check_output(self):
 
     def test_check_grad_normal(self):
         paddle.enable_static()
-        self.check_grad(['X'], 'Out', check_pir=True)
+        self.check_grad(['X'], 'Out', check_pir=True, check_prim_pir=True)
         paddle.disable_static()
 
     def initTestCase(self):
@@ -78,6 +78,9 @@ def initTestCase(self):
         self.inputs['Max'] = np.array([0.8]).astype(self.dtype)
         self.inputs['Min'] = np.array([0.1]).astype(self.dtype)
 
+    def generate_input(self):
+        return np.random.random(self.shape).astype(self.dtype)
+
 
 class TestCase1(TestClipOp):
     def initTestCase(self):
@@ -121,6 +124,19 @@ def initTestCase(self):
         self.min = 0.5
 
 
+class TestCase6(TestClipOp):
+    def initTestCase(self):
+        self.dtype = np.float32
+        self.shape = (4, 8, 16)
+        self.max = 1.0
+        self.min = 0.5
+
+    def generate_input(self):
+        return np.random.choice([self.min, self.max], self.shape).astype(
+            self.dtype
+        )
+
+
 class TestFP16Case1(TestClipOp):
     def initTestCase(self):
         self.dtype = np.float16
@@ -163,6 +179,19 @@ def initTestCase(self):
         self.min = 0.5
 
 
+class TestFP16Case6(TestClipOp):
+    def initTestCase(self):
+        self.dtype = np.float16
+        self.shape = (4, 8, 16)
+        self.max = 1.0
+        self.min = 0.5
+
+    def generate_input(self):
+        return np.random.choice([self.min, self.max], self.shape).astype(
+            self.dtype
+        )
+
+
 @unittest.skipIf(
     not core.is_compiled_with_cuda()
     or not core.is_bfloat16_supported(core.CUDAPlace(0)),
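The new `generate_input` override in `TestCase6` and `TestFP16Case6` draws every element from the clip bounds themselves (`min` and `max`), presumably to exercise the decomposition exactly at the boundary values. A small standalone sketch of that input pattern (shape and bounds copied from the tests; the `assert` is illustration only):

```python
import numpy as np

min_v, max_v = 0.5, 1.0
x = np.random.choice([min_v, max_v], (4, 8, 16)).astype(np.float32)

# Every element already equals one of the bounds, so clipping leaves it unchanged.
assert np.array_equal(np.clip(x, min_v, max_v), x)
```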