Skip to content

Commit 4508c84

Browse files
authored
[proto][tests] Fix missing extra_dims in cxcywh (#6906)
1 parent cb4413a commit 4508c84

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

test/prototype_common_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,8 +373,8 @@ def fn(shape, dtype, device):
373373
h = randint_with_tensor_bounds(1, height - y)
374374
parts = (x, y, w, h)
375375
else: # format == features.BoundingBoxFormat.CXCYWH:
376-
cx = torch.randint(1, width - 1, ())
377-
cy = torch.randint(1, height - 1, ())
376+
cx = torch.randint(1, width - 1, extra_dims)
377+
cy = torch.randint(1, height - 1, extra_dims)
378378
w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1)
379379
h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1)
380380
parts = (cx, cy, w, h)

0 commit comments

Comments
 (0)