@@ -39,19 +39,30 @@ def __call__(self, pic):
3939 if isinstance (pic , np .ndarray ):
4040 # handle numpy array
4141 img = torch .from_numpy (pic .transpose ((2 , 0 , 1 )))
42+ # backard compability
43+ return img .float ().div (255 )
44+ # handle PIL Image
45+ if pic .mode == 'I' :
46+ img = torch .from_numpy (np .array (pic , np .int32 ))
47+ elif pic .mode == 'I;16' :
48+ img = torch .from_numpy (np .array (pic , np .int16 ))
4249 else :
43- # handle PIL Image
4450 img = torch .ByteTensor (torch .ByteStorage .from_buffer (pic .tobytes ()))
45- # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
46- if pic .mode == 'YCbCr' :
47- nchannel = 3
48- else :
49- nchannel = len (pic .mode )
50- img = img .view (pic .size [1 ], pic .size [0 ], nchannel )
51- # put it from HWC to CHW format
52- # yikes, this transpose takes 80% of the loading time/CPU
53- img = img .transpose (0 , 1 ).transpose (0 , 2 ).contiguous ()
54- return img .float ().div (255 )
51+ # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
52+ if pic .mode == 'YCbCr' :
53+ nchannel = 3
54+ elif pic .mode == 'I;16' :
55+ nchannel = 1
56+ else :
57+ nchannel = len (pic .mode )
58+ img = img .view (pic .size [1 ], pic .size [0 ], nchannel )
59+ # put it from HWC to CHW format
60+ # yikes, this transpose takes 80% of the loading time/CPU
61+ img = img .transpose (0 , 1 ).transpose (0 , 2 ).contiguous ()
62+ if isinstance (img , torch .ByteTensor ):
63+ return img .float ().div (255 )
64+ else :
65+ return img
5566
5667
5768class ToPILImage (object ):
@@ -67,7 +78,6 @@ def __call__(self, pic):
6778 if torch .is_tensor (pic ):
6879 npimg = np .transpose (pic .numpy (), (1 , 2 , 0 ))
6980 assert isinstance (npimg , np .ndarray ), 'pic should be Tensor or ndarray'
70-
7181 if npimg .shape [2 ] == 1 :
7282 npimg = npimg [:, :, 0 ]
7383
@@ -83,7 +93,6 @@ def __call__(self, pic):
8393 if npimg .dtype == np .uint8 :
8494 mode = 'RGB'
8595 assert mode is not None , '{} is not supported' .format (npimg .dtype )
86-
8796 return Image .fromarray (npimg , mode = mode )
8897
8998
0 commit comments