@@ -1108,6 +1108,18 @@ def elastic_image_pil(
11081108 return to_pil_image (output , mode = image .mode )
11091109
11101110
1111+ def _create_identity_grid (size : Tuple [int , int ], device : torch .device ) -> torch .Tensor :
1112+ sy , sx = size
1113+ base_grid = torch .empty (1 , sy , sx , 2 , device = device )
1114+ x_grid = torch .linspace ((- sx + 1 ) / sx , (sx - 1 ) / sx , sx , device = device )
1115+ base_grid [..., 0 ].copy_ (x_grid )
1116+
1117+ y_grid = torch .linspace ((- sy + 1 ) / sy , (sy - 1 ) / sy , sy , device = device ).unsqueeze_ (- 1 )
1118+ base_grid [..., 1 ].copy_ (y_grid )
1119+
1120+ return base_grid
1121+
1122+
11111123def elastic_bounding_box (
11121124 bounding_box : torch .Tensor ,
11131125 format : features .BoundingBoxFormat ,
@@ -1125,22 +1137,24 @@ def elastic_bounding_box(
11251137 # Or add spatial_size arg and check displacement shape
11261138 spatial_size = displacement .shape [- 3 ], displacement .shape [- 2 ]
11271139
1128- id_grid = _FT . _create_identity_grid (list ( spatial_size )). to ( bounding_box .device )
1140+ id_grid = _create_identity_grid (spatial_size , bounding_box .device )
11291141 # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
11301142 # This is not an exact inverse of the grid
1131- inv_grid = id_grid - displacement
1143+ inv_grid = id_grid . sub_ ( displacement )
11321144
11331145 # Get points from bboxes
11341146 points = bounding_box [:, [[0 , 1 ], [2 , 1 ], [2 , 3 ], [0 , 3 ]]].reshape (- 1 , 2 )
1135- index_x = torch .floor (points [:, 0 ] + 0.5 ).to (dtype = torch .long )
1136- index_y = torch .floor (points [:, 1 ] + 0.5 ).to (dtype = torch .long )
1147+ if points .is_floating_point ():
1148+ points = points .ceil_ ()
1149+ index_xy = points .to (dtype = torch .long )
1150+ index_x , index_y = index_xy [:, 0 ], index_xy [:, 1 ]
1151+
11371152 # Transform points:
11381153 t_size = torch .tensor (spatial_size [::- 1 ], device = displacement .device , dtype = displacement .dtype )
1139- transformed_points = ( inv_grid [0 , index_y , index_x , :] + 1 ) * 0.5 * t_size - 0.5
1154+ transformed_points = inv_grid [0 , index_y , index_x , :]. add_ ( 1 ). mul_ ( 0.5 * t_size ). sub_ ( 0.5 )
11401155
11411156 transformed_points = transformed_points .reshape (- 1 , 4 , 2 )
1142- out_bbox_mins , _ = torch .min (transformed_points , dim = 1 )
1143- out_bbox_maxs , _ = torch .max (transformed_points , dim = 1 )
1157+ out_bbox_mins , out_bbox_maxs = torch .aminmax (transformed_points , dim = 1 )
11441158 out_bboxes = torch .cat ([out_bbox_mins , out_bbox_maxs ], dim = 1 ).to (bounding_box .dtype )
11451159
11461160 return convert_format_bounding_box (
0 commit comments