@@ -105,6 +105,7 @@ sampleMultinomialOnce(long* dest,
105105 T* sampled,
106106 T* dist) {
107107 extern __shared__ __align__ (sizeof (AccT)) unsigned char my_smem[];
108+ __shared__ bool found;
108109
109110 // Shared memory holds blockDim.x T for holding the cumulative sum,
110111 // blockDim.x AccT for normalizing the probabilities,
@@ -153,8 +154,9 @@ sampleMultinomialOnce(long* dest,
153154
154155 int chunks = THCCeilDiv (categories, (int ) blockDim .x );
155156 T prevHighProb = zero;
157+ found = false ;
156158
157- for (int chunk = 0 ; chunk < chunks; ++chunk) {
159+ for (int chunk = 0 ; chunk < chunks && !found ; ++chunk) {
158160 // All threads in bounds load a value
159161 int cat = chunk * blockDim .x + threadIdx .x ;
160162
@@ -197,15 +199,30 @@ sampleMultinomialOnce(long* dest,
197199 if (inBucket) {
198200 // We're done; we have the sample
199201 // Torch indices are 1-based
200- // FIXME: broadcast exit flag?
201202 dest[curDist] = cat + TH_INDEX_BASE;
203+ found = true ;
202204 }
203205
204206 // Store the previous scan's high value for future use
205207 prevHighProb = THCNumerics<T>::add (prevHighProb, smem[blockDim .x - 1 ]);
206208
207209 __syncthreads ();
208210 }
211+
212+ if (threadIdx .x == 0 && !found) {
213+ // This should address a rare bug where we don't select a valid index. This likely occurs when,
214+ // due to floating point arithmetic rounding errors, our cumulative sum does not add up to 1, and
215+ // our uniform sample is greater than this value. In this case we likely have uninitialized memory
216+ // in dest[curDist]. So basically we will loop through the distribution and pick the largest index
217+ // where the distribution is non-zero. This is obviously terribly inefficient, but due to the
218+ // rarity with which this occurs, this should not be an issue.
219+ for (int cat = categories - 1 ; cat >= 0 ; --cat) {
220+ if (THCNumerics<T>::gt (dist[curDist * categories + cat], zero)) {
221+ dest[curDist] = cat + TH_INDEX_BASE;
222+ break ;
223+ }
224+ }
225+ }
209226 }
210227}
211228
0 commit comments