Skip to content

Commit c76770f

Browse files
committed
Merge commit 'dfca8dfdc5988813ed5673589ffa4fdd1c4f3d2d'
2 parents da72583 + dfca8df commit c76770f

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

torch/lib/THC/THCTensorRandom.cuh

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ sampleMultinomialOnce(long* dest,
105105
T* sampled,
106106
T* dist) {
107107
extern __shared__ __align__(sizeof(AccT)) unsigned char my_smem[];
108+
__shared__ bool found;
108109

109110
// Shared Memory hold blockdim.x T for holding the cumulative sum,
110111
// blockDim.x AccT for normalizing the probabilities,
@@ -153,8 +154,9 @@ sampleMultinomialOnce(long* dest,
153154

154155
int chunks = THCCeilDiv(categories, (int) blockDim.x);
155156
T prevHighProb = zero;
157+
found = false;
156158

157-
for (int chunk = 0; chunk < chunks; ++chunk) {
159+
for (int chunk = 0; chunk < chunks && !found; ++chunk) {
158160
// All threads in bounds load a value
159161
int cat = chunk * blockDim.x + threadIdx.x;
160162

@@ -197,15 +199,30 @@ sampleMultinomialOnce(long* dest,
197199
if (inBucket) {
198200
// We're done; we have the sample
199201
// Torch indices are 1-based
200-
// FIXME: broadcast exit flag?
201202
dest[curDist] = cat + TH_INDEX_BASE;
203+
found = true;
202204
}
203205

204206
// Store the previous scan's high value for future use
205207
prevHighProb = THCNumerics<T>::add(prevHighProb, smem[blockDim.x - 1]);
206208

207209
__syncthreads();
208210
}
211+
212+
if (threadIdx.x == 0 && !found) {
213+
// This should address a rare bug where we don't select a valid index. This likely occurs when
214+
// due to floating point arithmetic rounding errors, our cumulative sum does not add up to 1, but
215+
// and our uniform sample is greater than this value. In this case we likely have unitialized memory
216+
// in dest[curDist]. So basically we will loop through the distribution and pick the largest index
217+
// where the distribution is non-zero. This is obviously terribly inefficient, but due to the
218+
// rarity in which this occurs, this should not be an issue.
219+
for (int cat = categories - 1; cat >= 0; --cat) {
220+
if (THCNumerics<T>::gt(dist[curDist * categories + cat], zero)) {
221+
dest[curDist] = cat + TH_INDEX_BASE;
222+
break;
223+
}
224+
}
225+
}
209226
}
210227
}
211228

0 commit comments

Comments
 (0)