@@ -661,7 +661,7 @@ def test_1_channel_float_tensor_to_pil_image(self):
661661 @pytest .mark .parametrize (
662662 "img_data, expected_mode" ,
663663 [
664- (torch .Tensor (4 , 4 , 1 ).uniform_ ().numpy (), "F " ),
664+ (torch .Tensor (4 , 4 , 1 ).uniform_ ().numpy (), "L " ),
665665 (torch .ByteTensor (4 , 4 , 1 ).random_ (0 , 255 ).numpy (), "L" ),
666666 (torch .ShortTensor (4 , 4 , 1 ).random_ ().numpy (), "I;16" ),
667667 (torch .IntTensor (4 , 4 , 1 ).random_ ().numpy (), "I" ),
@@ -671,6 +671,8 @@ def test_1_channel_ndarray_to_pil_image(self, with_mode, img_data, expected_mode
671671 transform = transforms .ToPILImage (mode = expected_mode ) if with_mode else transforms .ToPILImage ()
672672 img = transform (img_data )
673673 assert img .mode == expected_mode
674+ if np .issubdtype (img_data .dtype , np .floating ):
675+ img_data = (img_data * 255 ).astype (np .uint8 )
674676 # note: we explicitly convert img's dtype because pytorch doesn't support uint16
675677 # and otherwise assert_close wouldn't be able to construct a tensor from the uint16 array
676678 torch .testing .assert_close (img_data [:, :, 0 ], np .asarray (img ).astype (img_data .dtype ))
@@ -741,7 +743,7 @@ def test_2d_tensor_to_pil_image(self, with_mode, img_data, expected_output, expe
741743 @pytest .mark .parametrize (
742744 "img_data, expected_mode" ,
743745 [
744- (torch .Tensor (4 , 4 ).uniform_ ().numpy (), "F " ),
746+ (torch .Tensor (4 , 4 ).uniform_ ().numpy (), "L " ),
745747 (torch .ByteTensor (4 , 4 ).random_ (0 , 255 ).numpy (), "L" ),
746748 (torch .ShortTensor (4 , 4 ).random_ ().numpy (), "I;16" ),
747749 (torch .IntTensor (4 , 4 ).random_ ().numpy (), "I" ),
@@ -751,6 +753,8 @@ def test_2d_ndarray_to_pil_image(self, with_mode, img_data, expected_mode):
751753 transform = transforms .ToPILImage (mode = expected_mode ) if with_mode else transforms .ToPILImage ()
752754 img = transform (img_data )
753755 assert img .mode == expected_mode
756+ if np .issubdtype (img_data .dtype , np .floating ):
757+ img_data = (img_data * 255 ).astype (np .uint8 )
754758 np .testing .assert_allclose (img_data , img )
755759
756760 @pytest .mark .parametrize ("expected_mode" , [None , "RGB" , "HSV" , "YCbCr" ])
@@ -874,8 +878,6 @@ def test_ndarray_bad_types_to_pil_image(self):
874878 trans (np .ones ([4 , 4 , 1 ], np .uint16 ))
875879 with pytest .raises (TypeError , match = reg_msg ):
876880 trans (np .ones ([4 , 4 , 1 ], np .uint32 ))
877- with pytest .raises (TypeError , match = reg_msg ):
878- trans (np .ones ([4 , 4 , 1 ], np .float64 ))
879881
880882 with pytest .raises (ValueError , match = r"pic should be 2/3 dimensional. Got \d+ dimensions." ):
881883 transforms .ToPILImage ()(np .ones ([1 , 4 , 4 , 3 ]))
0 commit comments