There was an error while loading. Please reload this page.
1 parent 7ea6ae5 commit 4c474a9Copy full SHA for 4c474a9
test/test_cuda.py
@@ -59,6 +59,13 @@ def small_2d_scaled(t, scale=10):
59
return make_tensor(t, S, S).mul(scale)
60
61
62
+def small_2d_oneish(t):
63
+ if is_floating(t):
64
+ return make_tensor(t, S, S).clamp(min=0.99, max=1.01)
65
+ else:
66
+ return t(S, S).fill_(1)
67
+
68
69
def small_3d(t):
70
return make_tensor(t, S, S, S)
71
@@ -206,7 +213,7 @@ def tmp(t):
206
213
('norm', small_3d, lambda t: [3, 0], '3_norm_dim'),
207
214
('ones', small_3d, lambda t: [1, 2, 3, 4, 5],),
208
215
('permute', new_t(1, 2, 3, 4), lambda t: [2, 1, 3, 0],),
209
- ('prod', small_3d, lambda t: [],),
216
+ ('prod', small_2d_oneish, lambda t: [],),
210
217
('prod', small_3d, lambda t: [1], 'dim'),
211
218
('sum', small_2d, lambda t: [],),
212
219
('sum', small_3d, lambda t: [1], 'dim'),
0 commit comments