1616
1717from einops import rearrange , pack , unpack
1818
19+ import random
20+
1921# helper functions
2022
2123def exists (v ):
@@ -62,9 +64,12 @@ def __init__(
6264 channel_first : bool = False ,
6365 projection_has_bias : bool = True ,
6466 return_indices = True ,
65- force_quantization_f32 = True
67+ force_quantization_f32 = True ,
68+ preserve_symmetry : bool = False ,
69+ noise_approx_prob = 0.0 ,
6670 ):
6771 super ().__init__ ()
72+
6873 _levels = torch .tensor (levels , dtype = int32 )
6974 self .register_buffer ("_levels" , _levels , persistent = False )
7075
@@ -73,6 +78,9 @@ def __init__(
7378
7479 self .scale = scale
7580
81+ self .preserve_symmetry = preserve_symmetry
82+ self .noise_approx_prob = noise_approx_prob
83+
7684 codebook_dim = len (levels )
7785 self .codebook_dim = codebook_dim
7886
@@ -110,12 +118,36 @@ def bound(self, z, eps: float = 1e-3):
110118 shift = (offset / half_l ).atanh ()
111119 return (z + shift ).tanh () * half_l - offset
112120
113- def quantize (self , z ):
121+ # symmetry-preserving and noise-approximated quantization, section 3.2 in https://arxiv.org/abs/2411.19842
122+
123+ def symmetry_preserving_bound (self , z ):
124+ """
125+ QL(x) = 2 / (L - 1) * [(L - 1) * (tanh(x) + 1) / 2 + 0.5] - 1
126+ """
127+ levels_minus_1 = (self ._levels - 1 )
128+ scale = 2.0 / levels_minus_1
129+ bracket = (levels_minus_1 * (torch .tanh (z ) + 1 ) / 2.0 ) + 0.5
130+ return scale * bracket - 1.0
131+
132+ def noise_approx_bound (self , z ):
133+ """
134+ simulates quantization using noise -> Q_L(x) ~= tanh(x) + U{-1,1} / (L-1)
135+ """
136+ noise = torch .empty_like (z ).uniform_ (- 1 , 1 )
137+ return torch .tanh (z ) + noise / (self ._levels - 1 )
138+
139+ def quantize (self , z , preserve_symmetry = False ):
114140 """ Quantizes z, returns quantized zhat, same shape as z. """
115- quantized = round_ste (self .bound (z ))
141+ if self .training and random .random () < self .noise_approx_prob :
142+ bounded = self .noise_approx_bound (z )
143+ elif preserve_symmetry :
144+ bounded = self .symmetry_preserving_bound (z )
145+ else :
146+ bounded = self .bound (z )
147+ quantized = round_ste (bounded )
116148 half_width = self ._levels // 2 # Renormalize to [-1, 1].
117149 return quantized / half_width
118-
150+
119151 def _scale_and_shift (self , zhat_normalized ):
120152 half_width = self ._levels // 2
121153 return (zhat_normalized * half_width ) + half_width
@@ -194,7 +226,7 @@ def forward(self, z):
194226 if force_f32 and orig_dtype not in self .allowed_dtypes :
195227 z = z .float ()
196228
197- codes = self .quantize (z )
229+ codes = self .quantize (z , preserve_symmetry = self . preserve_symmetry )
198230
199231 # returning indices could be optional
200232
@@ -205,7 +237,7 @@ def forward(self, z):
205237
206238 codes = rearrange (codes , 'b n c d -> b n (c d)' )
207239
208- codes = codes .type (orig_dtype )
240+ codes = codes .to (orig_dtype )
209241
210242 # project out
211243
0 commit comments