Skip to content

Commit ffc7911

Browse files
committed
Merge commit 'd8ae7893e056ebf4e7a5e96bab2c3b69f196ddfd'
2 parents ff1fde6 + d8ae789 commit ffc7911

File tree

1 file changed

+14
-29
lines changed

1 file changed

+14
-29
lines changed

torch/lib/THC/THCSortUtils.cuh

Lines changed: 14 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,9 @@ __device__ inline void bitonicSort(K keys[Power2SortSize],
6565

6666
#pragma unroll
6767
for (unsigned int stride = size / 2; stride > 0; stride /= 2) {
68-
69-
// Single warp per slice is completely synchronous
70-
if (Power2SortSize > 64) {
71-
__syncthreads();
72-
}
73-
68+
69+
__syncthreads();
70+
7471
unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
7572
bitonicSwap<Comparator, K, V>(
7673
keys[pos], values[pos], valid[pos],
@@ -81,22 +78,18 @@ __device__ inline void bitonicSort(K keys[Power2SortSize],
8178

8279
#pragma unroll
8380
for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) {
84-
// Single warp per slice is completely synchronous
85-
if (Power2SortSize > 64) {
86-
__syncthreads();
87-
}
88-
81+
82+
__syncthreads();
83+
8984
unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
9085
bitonicSwap<Comparator, K, V>(
9186
keys[pos], values[pos], valid[pos],
9287
keys[pos + stride], values[pos + stride], valid[pos + stride],
9388
false, comp);
9489
}
9590

96-
// Single warp per slice is completely synchronous
97-
if (Power2SortSize > 64) {
98-
__syncthreads();
99-
}
91+
__syncthreads();
92+
10093
}
10194

10295
template <typename Comparator, typename K,
@@ -111,11 +104,8 @@ __device__ inline void bitonicSortKeys(K keys[Power2SortSize],
111104
#pragma unroll
112105
for (unsigned int stride = size / 2; stride > 0; stride /= 2) {
113106

114-
// Single warp per slice is completely synchronous
115-
if (Power2SortSize > 64) {
116-
__syncthreads();
117-
}
118-
107+
__syncthreads();
108+
119109
unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
120110
bitonicSwapKeys<Comparator, K>(
121111
keys[pos], valid[pos],
@@ -126,22 +116,17 @@ __device__ inline void bitonicSortKeys(K keys[Power2SortSize],
126116

127117
#pragma unroll
128118
for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) {
129-
// Single warp per slice is completely synchronous
130-
if (Power2SortSize > 64) {
131-
__syncthreads();
132-
}
133-
119+
__syncthreads();
120+
134121
unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
135122
bitonicSwapKeys<Comparator, K>(
136123
keys[pos], valid[pos],
137124
keys[pos + stride], valid[pos + stride],
138125
false, comp);
139126
}
140127

141-
// Single warp per slice is completely synchronous
142-
if (Power2SortSize > 64) {
143-
__syncthreads();
144-
}
128+
__syncthreads();
129+
145130
}
146131

147132
// Sorts (key, value) pairs (in different tensors) in-place; i.e.,

0 commit comments

Comments
 (0)