Skip to content

Commit 7320648

Browse files
authored
[proto] Small optims for elastic op on bboxes (#6897)
* [proto] Small optims for elastic op on bboxes * More inplace ops according to the review * Create grid on device directly. This should be faster * PR Review update. Apply ceil on float input
1 parent 9b0da0c commit 7320648

File tree

1 file changed

+21
-7
lines changed

1 file changed

+21
-7
lines changed

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,6 +1108,18 @@ def elastic_image_pil(
11081108
return to_pil_image(output, mode=image.mode)
11091109

11101110

1111+
def _create_identity_grid(size: Tuple[int, int], device: torch.device) -> torch.Tensor:
1112+
sy, sx = size
1113+
base_grid = torch.empty(1, sy, sx, 2, device=device)
1114+
x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device)
1115+
base_grid[..., 0].copy_(x_grid)
1116+
1117+
y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device).unsqueeze_(-1)
1118+
base_grid[..., 1].copy_(y_grid)
1119+
1120+
return base_grid
1121+
1122+
11111123
def elastic_bounding_box(
11121124
bounding_box: torch.Tensor,
11131125
format: features.BoundingBoxFormat,
@@ -1125,22 +1137,24 @@ def elastic_bounding_box(
11251137
# Or add spatial_size arg and check displacement shape
11261138
spatial_size = displacement.shape[-3], displacement.shape[-2]
11271139

1128-
id_grid = _FT._create_identity_grid(list(spatial_size)).to(bounding_box.device)
1140+
id_grid = _create_identity_grid(spatial_size, bounding_box.device)
11291141
# We construct an approximation of inverse grid as inv_grid = id_grid - displacement
11301142
# This is not an exact inverse of the grid
1131-
inv_grid = id_grid - displacement
1143+
inv_grid = id_grid.sub_(displacement)
11321144

11331145
# Get points from bboxes
11341146
points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
1135-
index_x = torch.floor(points[:, 0] + 0.5).to(dtype=torch.long)
1136-
index_y = torch.floor(points[:, 1] + 0.5).to(dtype=torch.long)
1147+
if points.is_floating_point():
1148+
points = points.ceil_()
1149+
index_xy = points.to(dtype=torch.long)
1150+
index_x, index_y = index_xy[:, 0], index_xy[:, 1]
1151+
11371152
# Transform points:
11381153
t_size = torch.tensor(spatial_size[::-1], device=displacement.device, dtype=displacement.dtype)
1139-
transformed_points = (inv_grid[0, index_y, index_x, :] + 1) * 0.5 * t_size - 0.5
1154+
transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)
11401155

11411156
transformed_points = transformed_points.reshape(-1, 4, 2)
1142-
out_bbox_mins, _ = torch.min(transformed_points, dim=1)
1143-
out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
1157+
out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
11441158
out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)
11451159

11461160
return convert_format_bounding_box(

0 commit comments

Comments
 (0)