@@ -49,3 +49,80 @@ def __iter__(self):

     def __len__(self):
         return self.num_samples
+
+
+class RepeatAugSampler(Sampler):
+    """Sampler that restricts data loading to a subset of the dataset for distributed training,
+    with repeated augmentation.
+    It ensures that each augmented version of a sample will be visible to a
+    different process (GPU). Heavily based on torch.utils.data.DistributedSampler.
+
+    This sampler was taken from https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py
+    Copyright (c) 2015-present, Facebook, Inc.
+    """
+
+    def __init__(
+            self,
+            dataset,
+            num_replicas=None,
+            rank=None,
+            shuffle=True,
+            num_repeats=3,
+            selected_round=256,
+            selected_ratio=0,
+    ):
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank()
+        self.dataset = dataset
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.shuffle = shuffle
+        self.num_repeats = num_repeats
+        self.epoch = 0
+        self.num_samples = int(math.ceil(len(self.dataset) * num_repeats / self.num_replicas))
+        self.total_size = self.num_samples * self.num_replicas
+        # Determine the number of samples to select per epoch for each rank.
+        # The num_selected logic defaults to match the original RASampler impl, but it can be
+        # tweaked via the selected_ratio and selected_round args.
+        selected_ratio = selected_ratio or num_replicas  # ratio to reduce selected samples by, num_replicas if 0
+        if selected_round:
+            self.num_selected_samples = int(math.floor(
+                len(self.dataset) // selected_round * selected_round / selected_ratio))
+        else:
+            self.num_selected_samples = int(math.ceil(len(self.dataset) / selected_ratio))
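+        # Worked example (illustrative numbers, not from this change): with
+        # len(dataset)=1000, num_replicas=4, num_repeats=3 and the defaults
+        # selected_round=256, selected_ratio=0 (which falls back to num_replicas=4):
+        #   num_samples          = ceil(1000 * 3 / 4)    = 750 indices drawn per rank
+        #   num_selected_samples = 1000 // 256 * 256 / 4 = 192 samples yielded per rank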
+
+    def __iter__(self):
+        # deterministically shuffle based on epoch
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+        if self.shuffle:
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = list(range(len(self.dataset)))
+
+        # produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2, ...]
+        indices = [x for x in indices for _ in range(self.num_repeats)]
+        # add extra samples to make it evenly divisible
+        padding_size = self.total_size - len(indices)
+        indices += indices[:padding_size]
+        assert len(indices) == self.total_size
+
+        # subsample per rank
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        assert len(indices) == self.num_samples
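+        # Trace (illustrative, not from this change): len(dataset)=4, num_repeats=3,
+        # num_replicas=3, shuffle=False gives the repeated list
+        # [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3] (total_size=12, no padding needed);
+        # each rank then strides out [0, 1, 2, 3], so all ranks see the same samples
+        # but each applies its own random augmentation draw.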
+
+        # return up to num selected samples
+        return iter(indices[:self.num_selected_samples])
+
+    def __len__(self):
+        return self.num_selected_samples
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
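
For context, a minimal usage sketch (not part of the commit) showing how this sampler could plug into a DataLoader under torch.distributed. The import path assumes the hunk lands in timm/data/distributed_sampler.py, as the surrounding context suggests; train_dataset and num_epochs are placeholders.

import torch.distributed as dist
from torch.utils.data import DataLoader
from timm.data.distributed_sampler import RepeatAugSampler

# dist.init_process_group(...) is assumed to have run already;
# RepeatAugSampler then reads rank/world size from torch.distributed.
sampler = RepeatAugSampler(train_dataset, num_repeats=3)
loader = DataLoader(train_dataset, batch_size=128, sampler=sampler, num_workers=8)

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # re-seeds the deterministic shuffle each epoch
    for images, targets in loader:
        ...  # forward/backward as usual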