File tree Expand file tree Collapse file tree 1 file changed +9
-10
lines changed
dpctl/tensor/libtensor/include/kernels Expand file tree Collapse file tree 1 file changed +9
-10
lines changed Original file line number Diff line number Diff line change @@ -227,23 +227,22 @@ inclusive_scan_base_step(sycl::queue &exec_q,
227227
228228#pragma unroll
229229 for (nwiT m_wi = 0 ; m_wi < n_wi; ++m_wi) {
230+ const size_t i_m_wi = i + m_wi;
230231 if constexpr (!include_initial) {
231232 local_iscan[m_wi] =
232- (i + m_wi < acc_nelems)
233- ? transformer (
234- input[inp_iter_offset +
235- inp_indexer (s0 + s1 * (i + m_wi))])
233+ (i_m_wi < acc_nelems)
234+ ? transformer (input[inp_iter_offset +
235+ inp_indexer (s0 + s1 * i_m_wi)])
236236 : identity;
237237 }
238238 else {
239239 // shift input to the left by a single element relative to
240240 // output
241241 local_iscan[m_wi] =
242- (i + m_wi < acc_nelems && i + m_wi > 0 )
242+ (i_m_wi < acc_nelems && i_m_wi > 0 )
243243 ? transformer (
244244 input[inp_iter_offset +
245- inp_indexer ((s0 + s1 * (i + m_wi)) -
246- 1 )])
245+ inp_indexer ((s0 + s1 * i_m_wi) - 1 )])
247246 : identity;
248247 }
249248 }
@@ -280,9 +279,9 @@ inclusive_scan_base_step(sycl::queue &exec_q,
280279 local_iscan[m_wi] = scan_op (local_iscan[m_wi], addand);
281280 }
282281
283- for ( nwiT m_wi = 0 ; (m_wi < n_wi) && (i + m_wi < acc_nelems);
284- ++m_wi)
285- {
282+ const nwiT m_max =
283+ std::min<nwiT>(n_wi, std::max (i, acc_nelems) - i);
284+ for (nwiT m_wi = 0 ; m_wi < m_max; ++m_wi) {
286285 output[out_iter_offset + out_indexer (i + m_wi)] =
287286 local_iscan[m_wi];
288287 }
You can’t perform that action at this time.
0 commit comments