@@ -99,7 +99,7 @@ def _script(obj):
9999 return torch .jit .script (obj )
100100 except Exception as error :
101101 name = getattr (obj , "__name__" , obj .__class__ .__name__ )
102- raise AssertionError (f"Trying to `torch.jit.script` ' { name } ' raised the error above." ) from error
102+ raise AssertionError (f"Trying to `torch.jit.script` ` { name } ` raised the error above." ) from error
103103
104104
105105def _check_kernel_scripted_vs_eager (kernel , input , * args , rtol , atol , ** kwargs ):
@@ -553,10 +553,12 @@ def affine_bounding_boxes(bounding_boxes):
553553
554554class TestResize :
555555 INPUT_SIZE = (17 , 11 )
556- OUTPUT_SIZES = [17 , [17 ], (17 ,), [12 , 13 ], (12 , 13 )]
556+ OUTPUT_SIZES = [17 , [17 ], (17 ,), None , [12 , 13 ], (12 , 13 )]
557557
558558 def _make_max_size_kwarg (self , * , use_max_size , size ):
559- if use_max_size :
559+ if size is None :
560+ max_size = min (list (self .INPUT_SIZE ))
561+ elif use_max_size :
560562 if not (isinstance (size , int ) or len (size ) == 1 ):
561563 # This would result in an `ValueError`
562564 return None
@@ -568,10 +570,13 @@ def _make_max_size_kwarg(self, *, use_max_size, size):
568570 return dict (max_size = max_size )
569571
570572 def _compute_output_size (self , * , input_size , size , max_size ):
571- if not (isinstance (size , int ) or len (size ) == 1 ):
573+ if size is None :
574+ size = max_size
575+
576+ elif not (isinstance (size , int ) or len (size ) == 1 ):
572577 return tuple (size )
573578
574- if not isinstance (size , int ):
579+ elif not isinstance (size , int ):
575580 size = size [0 ]
576581
577582 old_height , old_width = input_size
@@ -658,10 +663,13 @@ def test_kernel_video(self):
658663 [make_image_tensor , make_image_pil , make_image , make_bounding_boxes , make_segmentation_mask , make_video ],
659664 )
660665 def test_functional (self , size , make_input ):
666+ max_size_kwarg = self ._make_max_size_kwarg (use_max_size = size is None , size = size )
667+
661668 check_functional (
662669 F .resize ,
663670 make_input (self .INPUT_SIZE ),
664671 size = size ,
672+ ** max_size_kwarg ,
665673 antialias = True ,
666674 check_scripted_smoke = not isinstance (size , int ),
667675 )
@@ -695,11 +703,13 @@ def test_functional_signature(self, kernel, input_type):
695703 ],
696704 )
697705 def test_transform (self , size , device , make_input ):
706+ max_size_kwarg = self ._make_max_size_kwarg (use_max_size = size is None , size = size )
707+
698708 check_transform (
699- transforms .Resize (size = size , antialias = True ),
709+ transforms .Resize (size = size , ** max_size_kwarg , antialias = True ),
700710 make_input (self .INPUT_SIZE , device = device ),
701711 # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
702- check_v1_compatibility = dict (rtol = 0 , atol = 1 ),
712+ check_v1_compatibility = dict (rtol = 0 , atol = 1 ) if size is not None else False ,
703713 )
704714
705715 def _check_output_size (self , input , output , * , size , max_size ):
@@ -801,7 +811,11 @@ def test_functional_pil_antialias_warning(self):
801811 ],
802812 )
803813 def test_max_size_error (self , size , make_input ):
804- if isinstance (size , int ) or len (size ) == 1 :
814+ if size is None :
815+ # value can be anything other than an integer
816+ max_size = None
817+ match = "max_size must be an integer when size is None"
818+ elif isinstance (size , int ) or len (size ) == 1 :
805819 max_size = (size if isinstance (size , int ) else size [0 ]) - 1
806820 match = "must be strictly greater than the requested size"
807821 else :
@@ -812,6 +826,37 @@ def test_max_size_error(self, size, make_input):
812826 with pytest .raises (ValueError , match = match ):
813827 F .resize (make_input (self .INPUT_SIZE ), size = size , max_size = max_size , antialias = True )
814828
829+ if isinstance (size , list ) and len (size ) != 1 :
830+ with pytest .raises (ValueError , match = "max_size should only be passed if size is None or specifies" ):
831+ F .resize (make_input (self .INPUT_SIZE ), size = size , max_size = 500 )
832+
833+ @pytest .mark .parametrize (
834+ "input_size, max_size, expected_size" ,
835+ [
836+ ((10 , 10 ), 10 , (10 , 10 )),
837+ ((10 , 20 ), 40 , (20 , 40 )),
838+ ((20 , 10 ), 40 , (40 , 20 )),
839+ ((10 , 20 ), 10 , (5 , 10 )),
840+ ((20 , 10 ), 10 , (10 , 5 )),
841+ ],
842+ )
843+ @pytest .mark .parametrize (
844+ "make_input" ,
845+ [
846+ make_image_tensor ,
847+ make_image_pil ,
848+ make_image ,
849+ make_bounding_boxes ,
850+ make_segmentation_mask ,
851+ make_detection_masks ,
852+ make_video ,
853+ ],
854+ )
855+ def test_resize_size_none (self , input_size , max_size , expected_size , make_input ):
856+ img = make_input (input_size )
857+ out = F .resize (img , size = None , max_size = max_size )
858+ assert F .get_size (out )[- 2 :] == list (expected_size )
859+
815860 @pytest .mark .parametrize ("interpolation" , INTERPOLATION_MODES )
816861 @pytest .mark .parametrize (
817862 "make_input" ,
@@ -834,7 +879,7 @@ def test_interpolation_int(self, interpolation, make_input):
834879 assert_equal (actual , expected )
835880
836881 def test_transform_unknown_size_error (self ):
837- with pytest .raises (ValueError , match = "size can either be an integer or a sequence of one or two integers" ):
882+ with pytest .raises (ValueError , match = "size can be an integer, a sequence of one or two integers, or None " ):
838883 transforms .Resize (size = object ())
839884
840885 @pytest .mark .parametrize (
0 commit comments