Skip to content

Commit 79ca506

Browse files
authored
[proto] Optimized functional pad op for bboxes + tests (#6890)
* [proto] Speed-up crop on bboxes and tests * Fix linter * Update _geometry.py * Fixed device issue * Revert changes in test/prototype_transforms_kernel_infos.py * Fixed failing correctness tests * [proto] Optimized functional pad op for bboxes + tests * Renamed copy-pasted variable name * Code update * Fixes according to the review
1 parent d8cec34 commit 79ca506

File tree

2 files changed

+43
-13
lines changed

2 files changed

+43
-13
lines changed

test/prototype_transforms_kernel_infos.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
)
2626
from torch.utils._pytree import tree_map
2727
from torchvision.prototype import features
28-
from torchvision.transforms.functional_tensor import _max_value as get_max_value
28+
from torchvision.transforms.functional_tensor import _max_value as get_max_value, _parse_pad_padding
2929

3030
__all__ = ["KernelInfo", "KERNEL_INFOS"]
3131

@@ -1078,6 +1078,38 @@ def sample_inputs_pad_video():
10781078
yield ArgsKwargs(video_loader, padding=[1])
10791079

10801080

1081+
def reference_pad_bounding_box(bounding_box, *, format, spatial_size, padding, padding_mode):
    """Reference implementation of ``pad_bounding_box`` for correctness tests.

    Padding an image only translates box coordinates by the (left, top)
    padding amounts, so the expected result is obtained by applying a pure
    translation through the shared affine reference helper. Returns the
    shifted boxes together with the padded spatial size.
    """
    pad_left, pad_right, pad_top, pad_bottom = _parse_pad_padding(padding)

    # A 2x3 affine matrix encoding a translation by (pad_left, pad_top).
    translation_matrix = np.array(
        [[1, 0, pad_left], [0, 1, pad_top]],
        dtype="float32",
    )

    # The canvas grows by the padding added on each side.
    padded_height = spatial_size[0] + pad_top + pad_bottom
    padded_width = spatial_size[1] + pad_left + pad_right

    shifted_boxes = reference_affine_bounding_box_helper(
        bounding_box, format=format, affine_matrix=translation_matrix
    )
    return shifted_boxes, (padded_height, padded_width)
1098+
1099+
1100+
def reference_inputs_pad_bounding_box():
    """Yield ``ArgsKwargs`` pairing every bounding-box loader with each
    supported padding spec (scalar, tuples, and lists of 1/2/4 values),
    always using constant padding mode.
    """
    padding_candidates = [1, (1,), (1, 2), (1, 2, 3, 4), [1], [1, 2], [1, 2, 3, 4]]
    for loader in make_bounding_box_loaders(extra_dims=((), (4,))):
        for padding in padding_candidates:
            yield ArgsKwargs(
                loader,
                format=loader.format,
                spatial_size=loader.spatial_size,
                padding=padding,
                padding_mode="constant",
            )
1111+
1112+
10811113
KERNEL_INFOS.extend(
10821114
[
10831115
KernelInfo(
@@ -1097,6 +1129,8 @@ def sample_inputs_pad_video():
10971129
KernelInfo(
10981130
F.pad_bounding_box,
10991131
sample_inputs_fn=sample_inputs_pad_bounding_box,
1132+
reference_fn=reference_pad_bounding_box,
1133+
reference_inputs_fn=reference_inputs_pad_bounding_box,
11001134
test_marks=[
11011135
xfail_jit_python_scalar_arg("padding"),
11021136
xfail_jit_tuple_instead_of_list("padding"),

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -768,14 +768,11 @@ def pad_bounding_box(
768768

769769
left, right, top, bottom = _parse_pad_padding(padding)
770770

771-
bounding_box = bounding_box.clone()
772-
773-
# this works without conversion since padding only affects xy coordinates
774-
bounding_box[..., 0] += left
775-
bounding_box[..., 1] += top
776771
if format == features.BoundingBoxFormat.XYXY:
777-
bounding_box[..., 2] += left
778-
bounding_box[..., 3] += top
772+
pad = [left, top, left, top]
773+
else:
774+
pad = [left, top, 0, 0]
775+
bounding_box = bounding_box + torch.tensor(pad, dtype=bounding_box.dtype, device=bounding_box.device)
779776

780777
height, width = spatial_size
781778
height += top + bottom
@@ -821,14 +818,13 @@ def crop_bounding_box(
821818
width: int,
822819
) -> Tuple[torch.Tensor, Tuple[int, int]]:
823820

824-
bounding_box = bounding_box.clone()
825-
826821
# Crop or implicit pad if left and/or top have negative values:
827822
if format == features.BoundingBoxFormat.XYXY:
828-
sub = torch.tensor([left, top, left, top], device=bounding_box.device)
823+
sub = [left, top, left, top]
829824
else:
830-
sub = torch.tensor([left, top, 0, 0], device=bounding_box.device)
831-
bounding_box = bounding_box.sub_(sub)
825+
sub = [left, top, 0, 0]
826+
827+
bounding_box = bounding_box - torch.tensor(sub, dtype=bounding_box.dtype, device=bounding_box.device)
832828

833829
return bounding_box, (height, width)
834830

0 commit comments

Comments
 (0)