@@ -65,12 +65,9 @@ __device__ inline void bitonicSort(K keys[Power2SortSize],
6565
6666#pragma unroll
6767 for (unsigned int stride = size / 2 ; stride > 0 ; stride /= 2 ) {
68-
69- // Single warp per slice is completely synchronous
70- if (Power2SortSize > 64 ) {
71- __syncthreads ();
72- }
73-
68+
69+ __syncthreads ();
70+
7471 unsigned int pos = 2 * threadIdx .x - (threadIdx .x & (stride - 1 ));
7572 bitonicSwap<Comparator, K, V>(
7673 keys[pos], values[pos], valid[pos],
@@ -81,22 +78,18 @@ __device__ inline void bitonicSort(K keys[Power2SortSize],
8178
8279#pragma unroll
8380 for (unsigned int stride = Power2SortSize / 2 ; stride > 0 ; stride /= 2 ) {
84- // Single warp per slice is completely synchronous
85- if (Power2SortSize > 64 ) {
86- __syncthreads ();
87- }
88-
81+
82+ __syncthreads ();
83+
8984 unsigned int pos = 2 * threadIdx .x - (threadIdx .x & (stride - 1 ));
9085 bitonicSwap<Comparator, K, V>(
9186 keys[pos], values[pos], valid[pos],
9287 keys[pos + stride], values[pos + stride], valid[pos + stride],
9388 false , comp);
9489 }
9590
96- // Single warp per slice is completely synchronous
97- if (Power2SortSize > 64 ) {
98- __syncthreads ();
99- }
91+ __syncthreads ();
92+
10093}
10194
10295template <typename Comparator, typename K,
@@ -111,11 +104,8 @@ __device__ inline void bitonicSortKeys(K keys[Power2SortSize],
111104#pragma unroll
112105 for (unsigned int stride = size / 2 ; stride > 0 ; stride /= 2 ) {
113106
114- // Single warp per slice is completely synchronous
115- if (Power2SortSize > 64 ) {
116- __syncthreads ();
117- }
118-
107+ __syncthreads ();
108+
119109 unsigned int pos = 2 * threadIdx .x - (threadIdx .x & (stride - 1 ));
120110 bitonicSwapKeys<Comparator, K>(
121111 keys[pos], valid[pos],
@@ -126,22 +116,17 @@ __device__ inline void bitonicSortKeys(K keys[Power2SortSize],
126116
127117#pragma unroll
128118 for (unsigned int stride = Power2SortSize / 2 ; stride > 0 ; stride /= 2 ) {
129- // Single warp per slice is completely synchronous
130- if (Power2SortSize > 64 ) {
131- __syncthreads ();
132- }
133-
119+ __syncthreads ();
120+
134121 unsigned int pos = 2 * threadIdx .x - (threadIdx .x & (stride - 1 ));
135122 bitonicSwapKeys<Comparator, K>(
136123 keys[pos], valid[pos],
137124 keys[pos + stride], valid[pos + stride],
138125 false , comp);
139126 }
140127
141- // Single warp per slice is completely synchronous
142- if (Power2SortSize > 64 ) {
143- __syncthreads ();
144- }
128+ __syncthreads ();
129+
145130}
146131
147132// Sorts (key, value) pairs (in different tensors) in-place; i.e.,
0 commit comments