@@ -223,20 +223,16 @@ def __init__(
223223 _check_padding_arg (padding )
224224 _check_padding_mode_arg (padding_mode )
225225
226+ # This cast does Sequence[int] -> List[int] and is required to make mypy happy
227+ if not isinstance (padding , int ):
228+ padding = list (padding )
226229 self .padding = padding
227230 self .fill = _setup_fill_arg (fill )
228231 self .padding_mode = padding_mode
229232
230233 def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
231234 fill = self .fill [type (inpt )]
232-
233- # This cast does Sequence[int] -> List[int] and is required to make mypy happy
234- padding = self .padding
235- if not isinstance (padding , int ):
236- padding = list (padding )
237-
238- fill = F ._geometry ._convert_fill_arg (fill )
239- return F .pad (inpt , padding = padding , fill = fill , padding_mode = self .padding_mode )
235+ return F .pad (inpt , padding = self .padding , fill = fill , padding_mode = self .padding_mode )
240236
241237
242238class RandomZoomOut (_RandomApplyTransform ):
@@ -274,7 +270,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
274270
275271 def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
276272 fill = self .fill [type (inpt )]
277- fill = F ._geometry ._convert_fill_arg (fill )
278273 return F .pad (inpt , ** params , fill = fill )
279274
280275
@@ -300,12 +295,11 @@ def __init__(
300295 self .center = center
301296
302297 def _get_params (self , flat_inputs : List [Any ]) -> Dict [str , Any ]:
303- angle = float ( torch .empty (1 ).uniform_ (float ( self .degrees [0 ]), float ( self .degrees [1 ])) .item () )
298+ angle = torch .empty (1 ).uniform_ (self .degrees [0 ], self .degrees [1 ]).item ()
304299 return dict (angle = angle )
305300
306301 def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
307302 fill = self .fill [type (inpt )]
308- fill = F ._geometry ._convert_fill_arg (fill )
309303 return F .rotate (
310304 inpt ,
311305 ** params ,
@@ -358,7 +352,7 @@ def __init__(
358352 def _get_params (self , flat_inputs : List [Any ]) -> Dict [str , Any ]:
359353 height , width = query_spatial_size (flat_inputs )
360354
361- angle = float ( torch .empty (1 ).uniform_ (float ( self .degrees [0 ]), float ( self .degrees [1 ])) .item () )
355+ angle = torch .empty (1 ).uniform_ (self .degrees [0 ], self .degrees [1 ]).item ()
362356 if self .translate is not None :
363357 max_dx = float (self .translate [0 ] * width )
364358 max_dy = float (self .translate [1 ] * height )
@@ -369,22 +363,21 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
369363 translate = (0 , 0 )
370364
371365 if self .scale is not None :
372- scale = float ( torch .empty (1 ).uniform_ (self .scale [0 ], self .scale [1 ]).item () )
366+ scale = torch .empty (1 ).uniform_ (self .scale [0 ], self .scale [1 ]).item ()
373367 else :
374368 scale = 1.0
375369
376370 shear_x = shear_y = 0.0
377371 if self .shear is not None :
378- shear_x = float ( torch .empty (1 ).uniform_ (self .shear [0 ], self .shear [1 ]).item () )
372+ shear_x = torch .empty (1 ).uniform_ (self .shear [0 ], self .shear [1 ]).item ()
379373 if len (self .shear ) == 4 :
380- shear_y = float ( torch .empty (1 ).uniform_ (self .shear [2 ], self .shear [3 ]).item () )
374+ shear_y = torch .empty (1 ).uniform_ (self .shear [2 ], self .shear [3 ]).item ()
381375
382376 shear = (shear_x , shear_y )
383377 return dict (angle = angle , translate = translate , scale = scale , shear = shear )
384378
385379 def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
386380 fill = self .fill [type (inpt )]
387- fill = F ._geometry ._convert_fill_arg (fill )
388381 return F .affine (
389382 inpt ,
390383 ** params ,
@@ -478,8 +471,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
478471 def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
479472 if params ["needs_pad" ]:
480473 fill = self .fill [type (inpt )]
481- fill = F ._geometry ._convert_fill_arg (fill )
482-
483474 inpt = F .pad (inpt , padding = params ["padding" ], fill = fill , padding_mode = self .padding_mode )
484475
485476 if params ["needs_crop" ]:
@@ -512,21 +503,23 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
512503
513504 half_height = height // 2
514505 half_width = width // 2
506+ bound_height = int (distortion_scale * half_height ) + 1
507+ bound_width = int (distortion_scale * half_width ) + 1
515508 topleft = [
516- int (torch .randint (0 , int ( distortion_scale * half_width ) + 1 , size = (1 ,)). item ( )),
517- int (torch .randint (0 , int ( distortion_scale * half_height ) + 1 , size = (1 ,)). item ( )),
509+ int (torch .randint (0 , bound_width , size = (1 ,))),
510+ int (torch .randint (0 , bound_height , size = (1 ,))),
518511 ]
519512 topright = [
520- int (torch .randint (width - int ( distortion_scale * half_width ) - 1 , width , size = (1 ,)). item ( )),
521- int (torch .randint (0 , int ( distortion_scale * half_height ) + 1 , size = (1 ,)). item ( )),
513+ int (torch .randint (width - bound_width , width , size = (1 ,))),
514+ int (torch .randint (0 , bound_height , size = (1 ,))),
522515 ]
523516 botright = [
524- int (torch .randint (width - int ( distortion_scale * half_width ) - 1 , width , size = (1 ,)). item ( )),
525- int (torch .randint (height - int ( distortion_scale * half_height ) - 1 , height , size = (1 ,)). item ( )),
517+ int (torch .randint (width - bound_width , width , size = (1 ,))),
518+ int (torch .randint (height - bound_height , height , size = (1 ,))),
526519 ]
527520 botleft = [
528- int (torch .randint (0 , int ( distortion_scale * half_width ) + 1 , size = (1 ,)). item ( )),
529- int (torch .randint (height - int ( distortion_scale * half_height ) - 1 , height , size = (1 ,)). item ( )),
521+ int (torch .randint (0 , bound_width , size = (1 ,))),
522+ int (torch .randint (height - bound_height , height , size = (1 ,))),
530523 ]
531524 startpoints = [[0 , 0 ], [width - 1 , 0 ], [width - 1 , height - 1 ], [0 , height - 1 ]]
532525 endpoints = [topleft , topright , botright , botleft ]
@@ -535,7 +528,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
535528
536529 def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
537530 fill = self .fill [type (inpt )]
538- fill = F ._geometry ._convert_fill_arg (fill )
539531 return F .perspective (
540532 inpt ,
541533 ** params ,
@@ -584,7 +576,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
584576
585577 def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
586578 fill = self .fill [type (inpt )]
587- fill = F ._geometry ._convert_fill_arg (fill )
588579 return F .elastic (
589580 inpt ,
590581 ** params ,
@@ -855,7 +846,6 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
855846
856847 if params ["needs_pad" ]:
857848 fill = self .fill [type (inpt )]
858- fill = F ._geometry ._convert_fill_arg (fill )
859849 inpt = F .pad (inpt , params ["padding" ], fill = fill , padding_mode = self .padding_mode )
860850
861851 return inpt
0 commit comments