@@ -44,8 +44,8 @@ namespace copy_as_contig

 template <typename T,
           typename IndexerT,
-          int vec_sz = 4,
-          int n_vecs = 2,
+          std::uint32_t vec_sz = 4u,
+          std::uint32_t n_vecs = 2u,
           bool enable_sg_loadstore = true>
 class CopyAsCContigFunctor
 {
@@ -66,53 +66,63 @@ class CopyAsCContigFunctor

     void operator()(sycl::nd_item<1> ndit) const
     {
+        static_assert(vec_sz > 0);
+        static_assert(n_vecs > 0);
+        static_assert(vec_sz * n_vecs < (std::uint32_t(1) << 8));
+
+        constexpr std::uint8_t elems_per_wi =
+            static_cast<std::uint8_t>(vec_sz * n_vecs);
+
         using dpctl::tensor::type_utils::is_complex;
         if constexpr (!enable_sg_loadstore || is_complex<T>::value) {
-            const std::uint32_t sgSize =
+            const std::uint16_t sgSize =
                 ndit.get_sub_group().get_local_range()[0];
             const std::size_t gid = ndit.get_global_linear_id();

-            const std::size_t base =
-                (gid / sgSize) * sgSize * n_vecs * vec_sz + (gid % sgSize);
-            for (size_t offset = base;
-                 offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz));
-                 offset += sgSize)
-            {
+            // base = (gid / sgSize) * sgSize * elems_per_wi + (gid % sgSize)
+            // gid % sgSize == gid - (gid / sgSize) * sgSize
+            const std::size_t elems_per_sg = sgSize * (elems_per_wi - 1);
+            const std::size_t base = (gid / sgSize) * elems_per_sg + gid;
+            const std::size_t offset_max =
+                std::min(nelems, base + sgSize * elems_per_wi);
+
+            for (size_t offset = base; offset < offset_max; offset += sgSize) {
                 auto src_offset = src_indexer(offset);
                 dst_p[offset] = src_p[src_offset];
             }
         }
         else {
             auto sg = ndit.get_sub_group();
-            const std::uint32_t sgSize = sg.get_local_range()[0];
-            const size_t base = n_vecs * vec_sz *
-                                (ndit.get_group(0) * ndit.get_local_range(0) +
-                                 sg.get_group_id()[0] * sgSize);
+            const std::uint16_t sgSize = sg.get_max_local_range()[0];
+            const size_t base =
+                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
+                                sg.get_group_id()[0] * sgSize);

-            if (base + n_vecs * vec_sz * sgSize < nelems) {
+            if (base + elems_per_wi * sgSize < nelems) {
                 sycl::vec<T, vec_sz> dst_vec;

 #pragma unroll
-                for (std::uint32_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
+                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
+                    const size_t block_start_id = base + it * sgSize;
                     auto dst_multi_ptr = sycl::address_space_cast<
                         sycl::access::address_space::global_space,
-                        sycl::access::decorated::yes>(
-                        &dst_p[base + it * sgSize]);
+                        sycl::access::decorated::yes>(&dst_p[block_start_id]);

+                    const size_t elem_id0 = block_start_id + sg.get_local_id();
 #pragma unroll
-                    for (std::uint32_t k = 0; k < vec_sz; k++) {
-                        ssize_t src_offset = src_indexer(
-                            base + (it + k) * sgSize + sg.get_local_id());
+                    for (std::uint8_t k = 0; k < vec_sz; k++) {
+                        const size_t elem_id = elem_id0 + k * sgSize;
+                        const ssize_t src_offset = src_indexer(elem_id);
                         dst_vec[k] = src_p[src_offset];
                     }
                     sg.store<vec_sz>(dst_multi_ptr, dst_vec);
                 }
             }
             else {
-                for (size_t k = base + sg.get_local_id()[0]; k < nelems;
-                     k += sgSize)
-                {
-                    ssize_t src_offset = src_indexer(k);
+                const size_t lane_id = sg.get_local_id()[0];
+                const size_t k0 = base + lane_id;
+                for (size_t k = k0; k < nelems; k += sgSize) {
+                    const ssize_t src_offset = src_indexer(k);
                     dst_p[k] = src_p[src_offset];
                 }
             }
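A note on the index arithmetic in the hunk above: the rewritten scalar branch folds `gid % sgSize` into `base` via the identity spelled out in the new comments. A minimal host-side check of that algebra (standalone C++; the names `sgSize`, `elems_per_wi`, and `gid` mirror the kernel's, and the concrete values are arbitrary):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

int main()
{
    // Values chosen for illustration; any sub-group size works.
    constexpr std::uint16_t sgSize = 32;
    constexpr std::uint8_t elems_per_wi = 8; // n_vecs * vec_sz == 2 * 4

    for (std::size_t gid = 0; gid < 10'000; ++gid) {
        // Original formulation: start of this sub-group's block, plus lane id.
        const std::size_t base_old =
            (gid / sgSize) * sgSize * elems_per_wi + (gid % sgSize);

        // Refactored formulation: gid % sgSize == gid - (gid / sgSize) * sgSize,
        // so the lane id and one sgSize term fold into a single "+ gid".
        const std::size_t elems_per_sg = sgSize * (elems_per_wi - 1);
        const std::size_t base_new = (gid / sgSize) * elems_per_sg + gid;

        assert(base_old == base_new); // identical for every gid
    }
    return 0;
}
```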
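Likewise for the vectorized branch: each lane gathers `vec_sz` elements strided by `sgSize` into a `sycl::vec`, and `sg.store<vec_sz>` writes component `k` of lane `l` to offset `k * sgSize + l`, so one store covers a contiguous block of `vec_sz * sgSize` destination elements. A host-side simulation of that coverage, assuming those oneAPI sub-group store semantics (this sketch is illustrative, not dpctl code):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main()
{
    // Illustration only: one sub-group's worth of the vectorized branch.
    constexpr std::uint16_t sgSize = 16; // sub-group size
    constexpr std::uint32_t vec_sz = 4;  // components per sycl::vec
    constexpr std::uint32_t n_vecs = 2;  // vec stores per work-item

    std::vector<int> hits(n_vecs * vec_sz * sgSize, 0);

    for (std::uint32_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
        const std::size_t block_start_id = it * sgSize; // base == 0 here
        for (std::uint16_t lane = 0; lane < sgSize; ++lane) {
            for (std::uint32_t k = 0; k < vec_sz; ++k) {
                // Lane `lane` gathers element block_start_id + lane + k*sgSize
                // into dst_vec[k]; sg.store<vec_sz> writes component k of that
                // lane back to the same index.
                ++hits[block_start_id + lane + k * sgSize];
            }
        }
    }

    // Every element of the sub-group's block is written exactly once,
    // i.e. the stores tile the destination contiguously with no gaps.
    for (int h : hits)
        assert(h == 1);
    return 0;
}
```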
@@ -121,36 +131,23 @@ class CopyAsCContigFunctor
 };

 template <typename T,
-          typename IndexT,
-          int vec_sz,
-          int n_vecs,
-          bool enable_sgload>
-class as_contig_krn;
-
-template <typename T>
-sycl::event
-as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
-                                   size_t nelems,
-                                   int nd,
-                                   const ssize_t *shape_and_strides,
-                                   const char *src_p,
-                                   char *dst_p,
-                                   const std::vector<sycl::event> &depends)
+          typename IndexerT,
+          std::uint32_t vec_sz,
+          std::uint32_t n_vecs,
+          bool enable_sg_load,
+          typename KernelName>
+sycl::event submit_c_contiguous_copy(sycl::queue &exec_q,
+                                     size_t nelems,
+                                     const T *src,
+                                     T *dst,
+                                     const IndexerT &src_indexer,
+                                     const std::vector<sycl::event> &depends)
 {
-    dpctl::tensor::type_utils::validate_type_for_device<T>(exec_q);
-
-    const T *src_tp = reinterpret_cast<const T *>(src_p);
-    T *dst_tp = reinterpret_cast<T *>(dst_p);
-
-    using IndexerT = dpctl::tensor::offset_utils::StridedIndexer;
-    const IndexerT src_indexer(nd, ssize_t(0), shape_and_strides);
+    static_assert(vec_sz > 0);
+    static_assert(n_vecs > 0);
+    static_assert(vec_sz * n_vecs < (std::uint32_t(1) << 8));

     constexpr std::size_t preferred_lws = 256;
-    constexpr std::uint32_t n_vecs = 2;
-    constexpr std::uint32_t vec_sz = 4;
-    constexpr bool enable_sg_load = true;
-    using KernelName =
-        as_contig_krn<T, IndexerT, vec_sz, n_vecs, enable_sg_load>;

     const auto &kernel_id = sycl::get_kernel_id<KernelName>();

@@ -167,9 +164,11 @@ as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
     const std::size_t lws =
         ((preferred_lws + max_sg_size - 1) / max_sg_size) * max_sg_size;

-    constexpr std::uint32_t nelems_per_wi = n_vecs * vec_sz;
-    size_t n_groups =
-        (nelems + nelems_per_wi * lws - 1) / (nelems_per_wi * lws);
+    constexpr std::uint8_t nelems_per_wi = n_vecs * vec_sz;
+
+    const size_t nelems_per_group = nelems_per_wi * lws;
+    const size_t n_groups =
+        (nelems + nelems_per_group - 1) / (nelems_per_group);

     sycl::event copy_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);
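The sizing math in this hunk is a round-up of the work-group size to whole sub-groups, followed by a ceiling division over the elements one group processes. A small worked check with assumed example values:

```cpp
#include <cassert>
#include <cstddef>

int main()
{
    // Example values: preferred_lws from the diff, a 32-lane sub-group,
    // 8 elements per work-item (n_vecs * vec_sz == 2 * 4).
    constexpr std::size_t preferred_lws = 256;
    constexpr std::size_t max_sg_size = 32;
    constexpr std::size_t nelems_per_wi = 8;

    // Round the work-group size up to a whole number of sub-groups.
    const std::size_t lws =
        ((preferred_lws + max_sg_size - 1) / max_sg_size) * max_sg_size;
    assert(lws == 256); // already a multiple of 32

    // Ceiling division: enough groups to cover every element.
    const std::size_t nelems = 10'000;
    const std::size_t nelems_per_group = nelems_per_wi * lws; // 2048
    const std::size_t n_groups =
        (nelems + nelems_per_group - 1) / nelems_per_group;
    assert(n_groups == 5); // 4 * 2048 = 8192 < 10000 <= 5 * 2048 = 10240
    return 0;
}
```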
@@ -181,8 +180,62 @@ as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
         cgh.parallel_for<KernelName>(
             sycl::nd_range<1>(gRange, lRange),
             CopyAsCContigFunctor<T, IndexerT, vec_sz, n_vecs, enable_sg_load>(
-                nelems, src_tp, dst_tp, src_indexer));
+                nelems, src, dst, src_indexer));
     });
+    return copy_ev;
+}
+
+template <typename T,
+          typename IndexT,
+          std::uint32_t vec_sz,
+          std::uint32_t n_vecs,
+          bool enable_sgload>
+class as_contig_krn;
+
+template <typename T>
+sycl::event
+as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
+                                   size_t nelems,
+                                   int nd,
+                                   const ssize_t *shape_and_strides,
+                                   const char *src_p,
+                                   char *dst_p,
+                                   const std::vector<sycl::event> &depends)
+{
+    dpctl::tensor::type_utils::validate_type_for_device<T>(exec_q);
+
+    const T *src_tp = reinterpret_cast<const T *>(src_p);
+    T *dst_tp = reinterpret_cast<T *>(dst_p);
+
+    using IndexerT = dpctl::tensor::offset_utils::StridedIndexer;
+    const IndexerT src_indexer(nd, ssize_t(0), shape_and_strides);
+
+    constexpr std::uint32_t vec_sz = 4u;
+    constexpr std::uint32_t n_vecs = 2u;
+
+    using dpctl::tensor::kernels::alignment_utils::
+        disabled_sg_loadstore_wrapper_krn;
+    using dpctl::tensor::kernels::alignment_utils::is_aligned;
+    using dpctl::tensor::kernels::alignment_utils::required_alignment;
+
+    sycl::event copy_ev;
+    if (is_aligned<required_alignment>(dst_p)) {
+        constexpr bool enable_sg_load = true;
+        using KernelName =
+            as_contig_krn<T, IndexerT, vec_sz, n_vecs, enable_sg_load>;
+        copy_ev = submit_c_contiguous_copy<T, IndexerT, vec_sz, n_vecs,
+                                           enable_sg_load, KernelName>(
+            exec_q, nelems, src_tp, dst_tp, src_indexer, depends);
+    }
+    else {
+        constexpr bool disable_sg_load = false;
+        using InnerKernelName =
+            as_contig_krn<T, IndexerT, vec_sz, n_vecs, disable_sg_load>;
+        using KernelName = disabled_sg_loadstore_wrapper_krn<InnerKernelName>;
+        copy_ev = submit_c_contiguous_copy<T, IndexerT, vec_sz, n_vecs,
+                                           disable_sg_load, KernelName>(
+            exec_q, nelems, src_tp, dst_tp, src_indexer, depends);
+    }

     return copy_ev;
 }
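The new dispatch takes the sub-group load/store kernel only when `dst_p` meets `required_alignment`, and otherwise wraps the kernel name in `disabled_sg_loadstore_wrapper_krn` so the two variants stay distinct kernel names. The diff does not show how dpctl defines `is_aligned`; a plausible stand-in, assuming it reduces to a pointer-modulus test (both the helper definition and the `required_alignment` value below are assumptions for illustration):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for dpctl's alignment_utils::is_aligned; the real
// helper lives in dpctl's headers and may differ in detail.
template <std::size_t alignment>
bool is_aligned(const void *p)
{
    return reinterpret_cast<std::uintptr_t>(p) % alignment == 0;
}

int main()
{
    // Assumed value for illustration; dpctl defines required_alignment itself.
    constexpr std::size_t required_alignment = 64;

    alignas(64) char buf[128];
    assert(is_aligned<required_alignment>(&buf[0]));  // vectorized sg path
    assert(!is_aligned<required_alignment>(&buf[1])); // scalar fallback path
    return 0;
}
```

Vectorized `sycl::vec` stores assume the destination pointer is aligned for the whole vector, which is why a misaligned output buffer must fall back to the `enable_sg_loadstore == false` scalar path rather than risk undefined behavior.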