Skip to content

Commit 4c474a9

Browse files
apaszkesoumith
authored andcommitted
Improve prodall CUDA test
1 parent 7ea6ae5 commit 4c474a9

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

test/test_cuda.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,13 @@ def small_2d_scaled(t, scale=10):
5959
return make_tensor(t, S, S).mul(scale)
6060

6161

62+
def small_2d_oneish(t):
63+
if is_floating(t):
64+
return make_tensor(t, S, S).clamp(min=0.99, max=1.01)
65+
else:
66+
return t(S, S).fill_(1)
67+
68+
6269
def small_3d(t):
6370
return make_tensor(t, S, S, S)
6471

@@ -206,7 +213,7 @@ def tmp(t):
206213
('norm', small_3d, lambda t: [3, 0], '3_norm_dim'),
207214
('ones', small_3d, lambda t: [1, 2, 3, 4, 5],),
208215
('permute', new_t(1, 2, 3, 4), lambda t: [2, 1, 3, 0],),
209-
('prod', small_3d, lambda t: [],),
216+
('prod', small_2d_oneish, lambda t: [],),
210217
('prod', small_3d, lambda t: [1], 'dim'),
211218
('sum', small_2d, lambda t: [],),
212219
('sum', small_3d, lambda t: [1], 'dim'),

0 commit comments

Comments
 (0)