11import torch
22import math
3+ irange = range
34
45
5- def make_grid (tensor , nrow = 8 , padding = 2 ):
6+ def make_grid (tensor , nrow = 8 , padding = 2 ,
7+ normalize = False , range = None , scale_each = False ):
68 """
79 Given a 4D mini-batch Tensor of shape (B x C x H x W),
810 or a list of images all of the same size,
911 makes a grid of images
12+
13+ normalize=True will shift the image to the range (0, 1),
14+ by subtracting the minimum and dividing by the maximum pixel value.
15+
16+ if range=(min, max) where min and max are numbers, then these numbers are used to
17+ normalize the image.
18+
19+ scale_each=True will scale each image in the batch of images separately rather than
20+ computing the (min, max) over all images.
21+
22+ [Example usage is given in this notebook](https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91)
1023 """
11- tensorlist = None
24+ # if list of tensors, convert to a 4D mini-batch Tensor
1225 if isinstance (tensor , list ):
1326 tensorlist = tensor
1427 numImages = len (tensorlist )
1528 size = torch .Size (torch .Size ([long (numImages )]) + tensorlist [0 ].size ())
1629 tensor = tensorlist [0 ].new (size )
17- for i in range (numImages ):
30+ for i in irange (numImages ):
1831 tensor [i ].copy_ (tensorlist [i ])
32+
1933 if tensor .dim () == 2 : # single image H x W
2034 tensor = tensor .view (1 , tensor .size (0 ), tensor .size (1 ))
2135 if tensor .dim () == 3 : # single image
22- if tensor .size (0 ) == 1 :
36+ if tensor .size (0 ) == 1 : # if single-channel, convert to 3-channel
2337 tensor = torch .cat ((tensor , tensor , tensor ), 0 )
2438 return tensor
2539 if tensor .dim () == 4 and tensor .size (1 ) == 1 : # single-channel images
2640 tensor = torch .cat ((tensor , tensor , tensor ), 1 )
41+
42+ if normalize is True :
43+ if range is not None :
44+ assert isinstance (range , tuple ), \
45+ "range has to be a tuple (min, max) if specified. min and max are numbers"
46+
47+ def norm_ip (img , min , max ):
48+ img .clamp_ (min = min , max = max )
49+ img .add_ (- min ).div_ (max - min )
50+
51+ def norm_range (t , range ):
52+ if range is not None :
53+ norm_ip (t , range [0 ], range [1 ])
54+ else :
55+ norm_ip (t , t .min (), t .max ())
56+
57+ if scale_each is True :
58+ for t in tensor : # loop over mini-batch dimension
59+ norm_range (t , range )
60+ else :
61+ norm_range (tensor , range )
62+
2763 # make the mini-batch of images into a grid
2864 nmaps = tensor .size (0 )
2965 xmaps = min (nrow , nmaps )
3066 ymaps = int (math .ceil (float (nmaps ) / xmaps ))
3167 height , width = int (tensor .size (2 ) + padding ), int (tensor .size (3 ) + padding )
32- grid = tensor .new (3 , height * ymaps , width * xmaps ).fill_ (tensor . max () )
68+ grid = tensor .new (3 , height * ymaps , width * xmaps ).fill_ (0 )
3369 k = 0
34- for y in range (ymaps ):
35- for x in range (xmaps ):
70+ for y in irange (ymaps ):
71+ for x in irange (xmaps ):
3672 if k >= nmaps :
3773 break
3874 grid .narrow (1 , y * height + 1 + padding // 2 , height - padding )\
@@ -42,14 +78,18 @@ def make_grid(tensor, nrow=8, padding=2):
4278 return grid
4379
4480
45- def save_image (tensor , filename , nrow = 8 , padding = 2 ):
81+ def save_image (tensor , filename , nrow = 8 , padding = 2 ,
82+ normalize = False , range = None , scale_each = False ):
4683 """
4784 Saves a given Tensor into an image file.
48- If given a mini-batch tensor, will save the tensor as a grid of images.
85+ If given a mini-batch tensor, will save the tensor as a grid of images by calling `make_grid`.
86+ All options after `filename` are passed through to `make_grid`. Refer to it's documentation for
87+ more details
4988 """
5089 from PIL import Image
5190 tensor = tensor .cpu ()
52- grid = make_grid (tensor , nrow = nrow , padding = padding )
91+ grid = make_grid (tensor , nrow = nrow , padding = padding ,
92+ normalize = normalize , range = range , scale_each = scale_each )
5393 ndarr = grid .mul (255 ).byte ().transpose (0 , 2 ).transpose (0 , 1 ).numpy ()
5494 im = Image .fromarray (ndarr )
5595 im .save (filename )
0 commit comments