@@ -66,7 +66,7 @@ def __init__(
         return_indices = True,
         force_quantization_f32 = True,
         preserve_symmetry: bool = False,
-        noise_approx_prob = 0.0,
+        noise_dropout = 0.0,
     ):
         super().__init__()
 
@@ -79,7 +79,7 @@ def __init__(
         self.scale = scale
 
         self.preserve_symmetry = preserve_symmetry
-        self.noise_approx_prob = noise_approx_prob
+        self.noise_dropout = noise_dropout
 
         codebook_dim = len(levels)
         self.codebook_dim = codebook_dim
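For reference, a hedged usage sketch of the renamed flag. This assumes the diff is against the FSQ class of vector-quantize-pytorch and that the rest of the constructor is unchanged; the value 0.25 is purely illustrative.

```python
# Hypothetical usage of the renamed hyperparameter (assumed class name and
# constructor call; only `noise_dropout`, formerly `noise_approx_prob`,
# is confirmed by this diff).
from vector_quantize_pytorch import FSQ

quantizer = FSQ(
    levels = [8, 5, 5, 5],
    preserve_symmetry = False,
    noise_dropout = 0.25,   # per-sample quantizer-dropout probability during training
)
```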
@@ -129,24 +129,40 @@ def symmetry_preserving_bound(self, z):
         bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.0) + 0.5
         return scale * bracket - 1.0
 
-    def noise_approx_bound(self, z):
-        """
-        simulates quantization using noise -> Q_L(x) ~= tanh(x) + U{-1,1} / (L-1)
-        """
-        noise = torch.empty_like(z).uniform_(-1, 1)
-        return torch.tanh(z) + noise / (self._levels - 1)
-
     def quantize(self, z, preserve_symmetry = False):
         """ Quantizes z, returns quantized zhat, same shape as z. """
-        if self.training and random.random() < self.noise_approx_prob:
-            bounded = self.noise_approx_bound(z)
+
+        half_width = self._levels // 2
+
+        if self.training:
+            unquantized = z
+
+            # determine where to quantize elementwise
+
+            quantize_mask = torch.bernoulli(
+                torch.full([z.shape[0], 1, 1, 1], self.noise_dropout, device = z.device)
+            ).bool().expand_as(z)
+
+            if preserve_symmetry:
+                quantized = round_ste(self.symmetry_preserving_bound(z)) / half_width
+            else:
+                quantized = round_ste(self.bound(z)) / half_width
+            quantized = torch.where(quantize_mask, unquantized, quantized)
+
+            # determine where to add a random offset elementwise
+
+            offset_mask = torch.bernoulli(
+                torch.full([z.shape[0], 1, 1, 1], self.noise_dropout, device = z.device)
+            ).bool().expand_as(z)
+
+            offset = (torch.rand_like(z) - 0.5) / half_width
+            quantized = torch.where(offset_mask, unquantized + offset, quantized)
         elif preserve_symmetry:
-            bounded = self.symmetry_preserving_bound(z)
+            quantized = round_ste(self.symmetry_preserving_bound(z)) / half_width
         else:
-            bounded = self.bound(z)
-        quantized = round_ste(bounded)
-        half_width = self._levels // 2 # Renormalize to [-1, 1].
-        return quantized / half_width
+            quantized = round_ste(self.bound(z)) / half_width
+
+        return quantized
 
     def _scale_and_shift(self, zhat_normalized):
         half_width = self._levels // 2
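To make the new training-time behaviour concrete: with probability `noise_dropout` a whole sample in the batch bypasses quantization, and, independently with the same probability, a sample is instead replaced by its unquantized value plus uniform noise one quantization bin wide (the noisy pass-through takes precedence where both masks fire). Below is a minimal standalone sketch of that logic; the `toy_quantize` helper, the toy shapes, and the plain `torch.round` in place of the straight-through `round_ste` are assumptions for illustration, not the library's code.

```python
import torch

def toy_quantize(
    z: torch.Tensor,
    levels: int = 5,
    noise_dropout: float = 0.25,
    training: bool = True,
) -> torch.Tensor:
    half_width = levels // 2

    # plain FSQ-style rounding onto `levels` values in [-1, 1]
    # (the real code uses round_ste to keep gradients via straight-through)
    quantized = torch.round(z.clamp(-1, 1) * half_width) / half_width

    if not training:
        return quantized

    # per-sample Bernoulli masks, broadcast over all non-batch dimensions
    mask_shape = [z.shape[0]] + [1] * (z.ndim - 1)

    # with prob `noise_dropout`, pass the sample through unquantized
    quantize_mask = torch.bernoulli(torch.full(mask_shape, noise_dropout, device = z.device)).bool()
    quantized = torch.where(quantize_mask, z, quantized)

    # independently, with prob `noise_dropout`, use the unquantized sample
    # plus uniform noise spanning one quantization bin
    offset_mask = torch.bernoulli(torch.full(mask_shape, noise_dropout, device = z.device)).bool()
    offset = (torch.rand_like(z) - 0.5) / half_width
    quantized = torch.where(offset_mask, z + offset, quantized)

    return quantized

z = torch.randn(8, 4, 16, 16).tanh()
print(toy_quantize(z).shape)  # torch.Size([8, 4, 16, 16])
```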