[proto] Speed-up crop on bboxes and tests
vfdev-5 committed Nov 1, 2022
commit 6a618e2afaad0f755674623e0e6be6419925e945
25 changes: 25 additions & 0 deletions test/prototype_transforms_kernel_infos.py
@@ -862,6 +862,29 @@ def sample_inputs_crop_video():
        yield ArgsKwargs(video_loader, top=4, left=3, height=7, width=8)


def reference_crop_bounding_box(bounding_box, *, format, top, left, height, width):
    affine_matrix = np.array(
        [
            [1, 0, -left],
            [0, 1, -top],
        ],
        dtype="float32",
    )

    expected_bboxes = reference_affine_bounding_box_helper(
        bounding_box, format=format, affine_matrix=affine_matrix
    )
    return expected_bboxes, (height, width)
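
# Standalone sketch, not part of the PR: the 2x3 matrix above is a pure
# translation, so applying it to a box's corners in homogeneous coordinates
# shifts every coordinate by (-left, -top). The helper name below is
# hypothetical and exists only for this illustration.
def _illustrate_reference_translation():
    top, left = 4, 3
    matrix = np.array([[1, 0, -left], [0, 1, -top]], dtype="float32")
    corners = np.array([[10, 20, 1], [30, 40, 1]], dtype="float32")  # (x1, y1, 1), (x2, y2, 1)
    cropped = (matrix @ corners.T).T.ravel()
    assert cropped.tolist() == [7.0, 16.0, 27.0, 36.0]  # each corner shifted by (-3, -4)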


def reference_inputs_crop_bounding_box():
    for bounding_box_loader, params in itertools.product(
        make_bounding_box_loaders(extra_dims=((), (4,))), [_CROP_PARAMS[0], _CROP_PARAMS[-1]]
    ):
        yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **params)


KERNEL_INFOS.extend(
    [
        KernelInfo(
@@ -875,6 +898,8 @@ def sample_inputs_crop_video():
        KernelInfo(
            F.crop_bounding_box,
            sample_inputs_fn=sample_inputs_crop_bounding_box,
            reference_fn=reference_crop_bounding_box,
            reference_inputs_fn=reference_inputs_crop_bounding_box,
        ),
        KernelInfo(
            F.crop_mask,
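For context, the KernelInfo machinery compares the output of F.crop_bounding_box against reference_crop_bounding_box over the inputs yielded by reference_inputs_crop_bounding_box. A minimal hand-rolled version of that check, assuming the torchvision prototype API as of this commit, could look like the following; since the reference is a pure translation, the expected value can be written out directly:

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0]])
actual, spatial_size = F.crop_bounding_box(
    boxes, format=features.BoundingBoxFormat.XYXY, top=4, left=3, height=7, width=8
)
# The reference is a pure translation, so the expected result is just a shift:
expected = boxes - torch.tensor([3.0, 4.0, 3.0, 4.0])  # (left, top, left, top)
torch.testing.assert_close(actual, expected)
assert spatial_size == (7, 8)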
23 changes: 22 additions & 1 deletion torchvision/prototype/transforms/functional/_geometry.py
@@ -801,7 +801,7 @@ def pad(
crop_image_pil = _FP.crop


def crop_bounding_box(
def crop_bounding_box_old(
    bounding_box: torch.Tensor,
    format: features.BoundingBoxFormat,
    top: int,
@@ -827,6 +827,27 @@ def crop_bounding_box(
    )


def crop_bounding_box(
    bounding_box: torch.Tensor,
    format: features.BoundingBoxFormat,
    top: int,
    left: int,
    height: int,
    width: int,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    bounding_box = bounding_box.clone()

    # Cropping translates the boxes by (-left, -top); negative left and/or top
    # values act as an implicit pad. Only XYXY carries the offset in all four fields.
    if format == features.BoundingBoxFormat.XYXY:
        sub = torch.tensor([left, top, left, top])
    else:
        sub = torch.tensor([left, top, 0, 0])
    bounding_box.sub_(sub)

    return bounding_box, (height, width)


def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    return crop_image_tensor(mask, top, left, height, width)

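The rewrite sidesteps the generic affine path for this kernel: cropping a bounding box is just a translation, done here as one cloned in-place subtraction with no format conversion. A quick behavioral sketch, again assuming the prototype API at this commit:

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

# XYWH: only the x/y origin moves; width and height are untouched.
box_xywh = torch.tensor([[10.0, 20.0, 5.0, 6.0]])
out, spatial_size = F.crop_bounding_box(
    box_xywh, format=features.BoundingBoxFormat.XYWH, top=4, left=3, height=7, width=8
)
print(out)           # tensor([[ 7., 16.,  5.,  6.]])
print(spatial_size)  # (7, 8)

# Negative top/left act as an implicit pad: the boxes shift the other way.
out, _ = F.crop_bounding_box(
    box_xywh, format=features.BoundingBoxFormat.XYWH, top=-2, left=-1, height=7, width=8
)
print(out)  # tensor([[11., 22.,  5.,  6.]])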