Skip to content

Commit e662d1e

Browse files
authored
Update paddle.clamp (#25906)
* Update `paddle.clamp` rename to `paddle.clip` add fast path for dygraph mode remove `out` rename `input` -> `x` update doc sample * Fix leftover `Variable` wording * Indent doc with spaces * Remove `:alias` in docs * Update `enable_imperative` -> `disable_static` * Remove `imperative` also trigger CI * Update tests for better coverage * Rebase to fix `cosine_similarity` * Fix `cosine_similarity` some more
1 parent 1a72a90 commit e662d1e

File tree

8 files changed

+122
-144
lines changed

8 files changed

+122
-144
lines changed

python/paddle/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@
181181
from .tensor.math import erf #DEFINE_ALIAS
182182
from .tensor.math import addcmul #DEFINE_ALIAS
183183
from .tensor.math import addmm #DEFINE_ALIAS
184-
from .tensor.math import clamp #DEFINE_ALIAS
184+
from .tensor.math import clip #DEFINE_ALIAS
185185
from .tensor.math import trace #DEFINE_ALIAS
186186
from .tensor.math import kron #DEFINE_ALIAS
187187
from .tensor.math import prod #DEFINE_ALIAS

python/paddle/fluid/layers/nn.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12205,8 +12205,6 @@ def logical_not(x, out=None, name=None):
1220512205
@templatedoc()
1220612206
def clip(x, min, max, name=None):
1220712207
"""
12208-
:alias_main: paddle.nn.clip
12209-
:alias: paddle.nn.clip,paddle.nn.clip.clip
1221012208
:old_api: paddle.fluid.layers.clip
1221112209

1221212210
${comment}

python/paddle/fluid/tests/unittests/test_clamp.py

Lines changed: 0 additions & 79 deletions
This file was deleted.

python/paddle/fluid/tests/unittests/test_clip_op.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import unittest
1818
import numpy as np
19+
import paddle
1920
import paddle.fluid as fluid
2021
from paddle.fluid import Program, program_guard
2122
from op_test import OpTest
@@ -109,5 +110,64 @@ def test_dtype():
109110
self.assertRaises(TypeError, test_dtype)
110111

111112

113+
class TestClipAPI(unittest.TestCase):
114+
def test_clip(self):
115+
data_shape = [1, 9, 9, 4]
116+
data = np.random.random(data_shape).astype('float32')
117+
images = fluid.data(name='image', shape=data_shape, dtype='float32')
118+
min = fluid.data(name='min', shape=[1], dtype='float32')
119+
max = fluid.data(name='max', shape=[1], dtype='float32')
120+
121+
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
122+
) else fluid.CPUPlace()
123+
exe = fluid.Executor(place)
124+
125+
out_1 = paddle.clip(images, min=min, max=max)
126+
out_2 = paddle.clip(images, min=0.2, max=0.9)
127+
out_3 = paddle.clip(images, min=0.3)
128+
out_4 = paddle.clip(images, max=0.7)
129+
out_5 = paddle.clip(images, min=min)
130+
out_6 = paddle.clip(images, max=max)
131+
132+
res1, res2, res3, res4, res5, res6 = exe.run(
133+
fluid.default_main_program(),
134+
feed={
135+
"image": data,
136+
"min": np.array([0.2]).astype('float32'),
137+
"max": np.array([0.8]).astype('float32')
138+
},
139+
fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6])
140+
141+
self.assertTrue(np.allclose(res1, data.clip(0.2, 0.8)))
142+
self.assertTrue(np.allclose(res2, data.clip(0.2, 0.9)))
143+
self.assertTrue(np.allclose(res3, data.clip(min=0.3)))
144+
self.assertTrue(np.allclose(res4, data.clip(max=0.7)))
145+
self.assertTrue(np.allclose(res5, data.clip(min=0.2)))
146+
self.assertTrue(np.allclose(res6, data.clip(max=0.8)))
147+
148+
def test_clip_dygraph(self):
149+
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
150+
) else fluid.CPUPlace()
151+
paddle.disable_static(place)
152+
data_shape = [1, 9, 9, 4]
153+
data = np.random.random(data_shape).astype('float32')
154+
images = paddle.to_variable(data, dtype='float32')
155+
156+
out_1 = paddle.clip(images, min=0.2, max=0.8)
157+
out_2 = paddle.clip(images, min=0.2, max=0.9)
158+
159+
self.assertTrue(np.allclose(out_1.numpy(), data.clip(0.2, 0.8)))
160+
self.assertTrue(np.allclose(out_2.numpy(), data.clip(0.2, 0.9)))
161+
162+
def test_errors(self):
163+
paddle.enable_static()
164+
x1 = fluid.data(name='x1', shape=[1], dtype="int16")
165+
x2 = fluid.data(name='x2', shape=[1], dtype="int8")
166+
x3 = fluid.data(name='x3', shape=[1], dtype="float32")
167+
self.assertRaises(TypeError, paddle.clip, x=x1, min=0.2, max=0.8)
168+
self.assertRaises(TypeError, paddle.clip, x=x2, min=0.2, max=0.8)
169+
self.assertRaises(Exception, paddle.clip, x=x3)
170+
171+
112172
if __name__ == '__main__':
113173
unittest.main()

python/paddle/fluid/tests/unittests/test_cosine_similarity_api.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@ def setUp(self):
2929
if core.is_compiled_with_cuda():
3030
self.places.append(paddle.CUDAPlace(0))
3131

32-
def _get_numpy_out(self, x1, x2, dim=1, eps=1e-8):
33-
w12 = np.sum(x1 * x2, axis=dim)
34-
w1 = np.sum(x1 * x1, axis=dim)
35-
w2 = np.sum(x2 * x2, axis=dim)
32+
def _get_numpy_out(self, x1, x2, axis=1, eps=1e-8):
33+
w12 = np.sum(x1 * x2, axis=axis)
34+
w1 = np.sum(x1 * x1, axis=axis)
35+
w2 = np.sum(x2 * x2, axis=axis)
3636
n12 = np.sqrt(np.clip(w1 * w2, eps * eps, None))
3737
cos_sim = w12 / n12
3838
return cos_sim
@@ -42,22 +42,22 @@ def check_static_result(self, place):
4242

4343
with program_guard(Program(), Program()):
4444
shape = [10, 15]
45-
dim = 1
45+
axis = 1
4646
eps = 1e-8
4747
np.random.seed(0)
4848
np_x1 = np.random.rand(*shape).astype(np.float32)
4949
np_x2 = np.random.rand(*shape).astype(np.float32)
5050

5151
x1 = paddle.data(name="x1", shape=shape)
5252
x2 = paddle.data(name="x2", shape=shape)
53-
result = F.cosine_similarity(x1, x2, dim=dim, eps=eps)
53+
result = F.cosine_similarity(x1, x2, axis=axis, eps=eps)
5454
exe = Executor(place)
5555
fetches = exe.run(default_main_program(),
5656
feed={"x1": np_x1,
5757
"x2": np_x2},
5858
fetch_list=[result])
5959

60-
np_out = self._get_numpy_out(np_x1, np_x2, dim=dim, eps=eps)
60+
np_out = self._get_numpy_out(np_x1, np_x2, axis=axis, eps=eps)
6161
self.assertTrue(np.allclose(fetches[0], np_out))
6262

6363
def test_static(self):
@@ -68,33 +68,33 @@ def test_dygraph_1(self):
6868
paddle.disable_static()
6969

7070
shape = [10, 15]
71-
dim = 1
71+
axis = 1
7272
eps = 1e-8
7373
np.random.seed(1)
7474
np_x1 = np.random.rand(*shape).astype(np.float32)
7575
np_x2 = np.random.rand(*shape).astype(np.float32)
76-
np_out = self._get_numpy_out(np_x1, np_x2, dim=dim, eps=eps)
76+
np_out = self._get_numpy_out(np_x1, np_x2, axis=axis, eps=eps)
7777

7878
tesnor_x1 = paddle.to_variable(np_x1)
7979
tesnor_x2 = paddle.to_variable(np_x2)
80-
y = F.cosine_similarity(tesnor_x1, tesnor_x2, dim=dim, eps=eps)
80+
y = F.cosine_similarity(tesnor_x1, tesnor_x2, axis=axis, eps=eps)
8181

8282
self.assertTrue(np.allclose(y.numpy(), np_out))
8383

8484
def test_dygraph_2(self):
8585
paddle.disable_static()
8686

8787
shape = [12, 13]
88-
dim = 0
88+
axis = 0
8989
eps = 1e-6
9090
np.random.seed(1)
9191
np_x1 = np.random.rand(*shape).astype(np.float32)
9292
np_x2 = np.random.rand(*shape).astype(np.float32)
93-
np_out = self._get_numpy_out(np_x1, np_x2, dim=dim, eps=eps)
93+
np_out = self._get_numpy_out(np_x1, np_x2, axis=axis, eps=eps)
9494

9595
tesnor_x1 = paddle.to_variable(np_x1)
9696
tesnor_x2 = paddle.to_variable(np_x2)
97-
y = F.cosine_similarity(tesnor_x1, tesnor_x2, dim=dim, eps=eps)
97+
y = F.cosine_similarity(tesnor_x1, tesnor_x2, axis=axis, eps=eps)
9898

9999
self.assertTrue(np.allclose(y.numpy(), np_out))
100100

@@ -103,16 +103,16 @@ def test_dygraph_3(self):
103103

104104
shape1 = [10, 12, 10]
105105
shape2 = [10, 1, 10]
106-
dim = 2
106+
axis = 2
107107
eps = 1e-6
108108
np.random.seed(1)
109109
np_x1 = np.random.rand(*shape1).astype(np.float32)
110110
np_x2 = np.random.rand(*shape2).astype(np.float32)
111-
np_out = self._get_numpy_out(np_x1, np_x2, dim=dim, eps=eps)
111+
np_out = self._get_numpy_out(np_x1, np_x2, axis=axis, eps=eps)
112112

113113
tesnor_x1 = paddle.to_variable(np_x1)
114114
tesnor_x2 = paddle.to_variable(np_x2)
115-
y = F.cosine_similarity(tesnor_x1, tesnor_x2, dim=dim, eps=eps)
115+
y = F.cosine_similarity(tesnor_x1, tesnor_x2, axis=axis, eps=eps)
116116

117117
self.assertTrue(np.allclose(y.numpy(), np_out))
118118

python/paddle/nn/functional/common.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@
2828
from ...fluid.layers import squeeze #DEFINE_ALIAS
2929
from ...fluid.layers import unsqueeze #DEFINE_ALIAS
3030
from ...fluid.layers import elementwise_mul #DEFINE_ALIAS
31-
from ...tensor import clamp #DEFINE_ALIAS
32-
from ...tensor import sum #DEFINE_ALIAS
33-
from ...tensor import sqrt #DEFINE_ALIAS
31+
from ...tensor import clip
32+
from ...tensor import sum
33+
from ...tensor import sqrt
3434

3535
#from ...fluid.layers import fc #DEFINE_ALIAS
3636
from ...fluid.layers import pad_constant_like #DEFINE_ALIAS
@@ -635,17 +635,17 @@ def pad(x, pad, mode='constant', value=0, data_format="NCHW", name=None):
635635
return out
636636

637637

638-
def cosine_similarity(x1, x2, dim=1, eps=1e-8):
638+
def cosine_similarity(x1, x2, axis=1, eps=1e-8):
639639
"""
640-
Compute cosine similarity between x1 and x2 along dim.
640+
Compute cosine similarity between x1 and x2 along axis.
641641
642642
Parameters:
643643
x1 (Tensor): First input. float32/double.
644644
x2 (Tensor): Second input. float32/double.
645-
dim (int): Dimension of vectors to compute cosine similarity. Default is 1.
645+
axis (int): Dimension of vectors to compute cosine similarity. Default is 1.
646646
eps(float): Small value to avoid division by zero. Default is 1e-8.
647647
648-
Returns: a Tensor representing cosine similarity between x1 and x2 along dim.
648+
Returns: a Tensor representing cosine similarity between x1 and x2 along axis.
649649
Return Type: Tensor
650650
651651
Examples:
@@ -659,7 +659,7 @@ def cosine_similarity(x1, x2, dim=1, eps=1e-8):
659659
[0.9098952 0.15715368 0.8671125 0.3156102 ]
660660
[0.4427798 0.54136837 0.5276275 0.32394758]
661661
[0.3769419 0.8535014 0.48041078 0.9256797 ]]
662-
dim = 1
662+
axis = 1
663663
eps = 1e-8
664664
Out: [0.5275037 0.8368967 0.75037485 0.9245899]
665665
@@ -675,14 +675,14 @@ def cosine_similarity(x1, x2, dim=1, eps=1e-8):
675675
x2 = np.random.rand(2,3)
676676
x1 = paddle.to_tensor(x1)
677677
x2 = paddle.to_tensor(x2)
678-
result = paddle.nn.functional.cosine_similarity(x1, x2, dim=0)
678+
result = paddle.nn.functional.cosine_similarity(x1, x2, axis=0)
679679
print(result.numpy())
680680
# [0.99806249 0.9817672 0.94987036]
681681
682682
"""
683-
w12 = sum(elementwise_mul(x1, x2), axis=dim)
684-
w1 = sum(elementwise_mul(x1, x1), axis=dim)
685-
w2 = sum(elementwise_mul(x2, x2), axis=dim)
686-
n12 = sqrt(clamp(w1 * w2, min=eps * eps))
683+
w12 = sum(elementwise_mul(x1, x2), axis=axis)
684+
w1 = sum(elementwise_mul(x1, x1), axis=axis)
685+
w2 = sum(elementwise_mul(x2, x2), axis=axis)
686+
n12 = sqrt(clip(w1 * w2, min=eps * eps))
687687
cos_sim = w12 / n12
688688
return cos_sim

python/paddle/tensor/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@
154154
from .math import erf #DEFINE_ALIAS
155155
from .math import addcmul #DEFINE_ALIAS
156156
from .math import addmm #DEFINE_ALIAS
157-
from .math import clamp #DEFINE_ALIAS
157+
from .math import clip #DEFINE_ALIAS
158158
from .math import trace #DEFINE_ALIAS
159159
from .math import kron #DEFINE_ALIAS
160160
from .math import prod #DEFINE_ALIAS

0 commit comments

Comments
 (0)