Skip to content

Commit 69f3311

Browse files
DiamonDinoiaserge-sans-paille
authored andcommitted
This improves performance of AVX512 swizzle.
Minor: It fixes swizzles with runtime mask that did not allow duplicates
1 parent 2b0366a commit 69f3311

File tree

3 files changed

+250
-165
lines changed

3 files changed

+250
-165
lines changed

include/xsimd/arch/xsimd_avx512dq.hpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,28 @@ namespace xsimd
188188
return reduce_add(batch<float, avx2>(res1), avx2 {});
189189
}
190190

191+
// swizzle constant mask
192+
template <class A, uint32_t V0, uint32_t V1, uint32_t V2, uint32_t V3, uint32_t V4, uint32_t V5, uint32_t V6, uint32_t V7,
193+
uint32_t V8, uint32_t V9, uint32_t V10, uint32_t V11, uint32_t V12, uint32_t V13, uint32_t V14, uint32_t V15>
194+
XSIMD_INLINE batch<float, A> swizzle(batch<float, A> const& self,
195+
batch_constant<uint32_t, A, V0, V1, V2, V3, V4, V5, V6, V7, V8, V9, V10, V11, V12, V13, V14, V15> mask,
196+
requires_arch<avx512dq>) noexcept
197+
{
198+
constexpr bool dup_lo = detail::is_dup_lo(mask);
199+
constexpr bool dup_hi = detail::is_dup_hi(mask);
200+
201+
XSIMD_IF_CONSTEXPR(dup_lo || dup_hi)
202+
{
203+
const batch<float, avx2> half = _mm512_extractf32x8_ps(self, dup_lo ? 0 : 1);
204+
constexpr typename std::conditional<dup_lo, batch_constant<uint32_t, avx2, V0 % 8, V1 % 8, V2 % 8, V3 % 8, V4 % 8, V5 % 8, V6 % 8, V7 % 8>,
205+
batch_constant<uint32_t, avx2, V8 % 8, V9 % 8, V10 % 8, V11 % 8, V12 % 8, V13 % 8, V14 % 8, V15 % 8>>::type half_mask {};
206+
auto permuted = swizzle(half, half_mask, avx2 {});
207+
// merge the two slices into an AVX512F register:
208+
return _mm512_broadcast_f32x8(permuted); // duplicates the 256-bit perm into both halves
209+
}
210+
return swizzle(self, mask, avx512f {});
211+
}
212+
191213
// convert
192214
namespace detail
193215
{

include/xsimd/arch/xsimd_avx512f.hpp

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2124,16 +2124,53 @@ namespace xsimd
21242124
return bitwise_cast<int32_t>(swizzle(bitwise_cast<uint32_t>(self), mask, avx512f {}));
21252125
}
21262126

2127-
// swizzle (constant version)
2128-
template <class A, uint32_t... Vs>
2129-
XSIMD_INLINE batch<float, A> swizzle(batch<float, A> const& self, batch_constant<uint32_t, A, Vs...> mask, requires_arch<avx512f>) noexcept
2127+
template <class A, uint32_t V0, uint32_t V1, uint32_t V2, uint32_t V3, uint32_t V4, uint32_t V5, uint32_t V6, uint32_t V7,
2128+
uint32_t V8, uint32_t V9, uint32_t V10, uint32_t V11, uint32_t V12, uint32_t V13, uint32_t V14, uint32_t V15>
2129+
XSIMD_INLINE batch<float, A> swizzle(batch<float, A> const& self,
2130+
batch_constant<uint32_t, A, V0, V1, V2, V3, V4, V5, V6, V7, V8, V9, V10, V11, V12, V13, V14, V15> mask,
2131+
requires_arch<avx512f>) noexcept
21302132
{
2133+
XSIMD_IF_CONSTEXPR(detail::is_identity(mask))
2134+
{
2135+
return self;
2136+
}
2137+
XSIMD_IF_CONSTEXPR(!detail::is_cross_lane(mask))
2138+
{
2139+
constexpr int imm0 = detail::mod_shuffle(V0, V1, V2, V3);
2140+
constexpr int imm1 = detail::mod_shuffle(V4, V5, V6, V7);
2141+
constexpr int imm2 = detail::mod_shuffle(V8, V9, V10, V11);
2142+
constexpr int imm3 = detail::mod_shuffle(V12, V13, V14, V15);
2143+
XSIMD_IF_CONSTEXPR(imm0 == imm1 && imm0 == imm2 && imm0 == imm3)
2144+
{
2145+
return _mm512_permute_ps(self, imm0);
2146+
}
2147+
}
21312148
return swizzle(self, mask.as_batch(), avx512f {});
21322149
}
2133-
2134-
template <class A, uint64_t... Vs>
2135-
XSIMD_INLINE batch<double, A> swizzle(batch<double, A> const& self, batch_constant<uint64_t, A, Vs...> mask, requires_arch<avx512f>) noexcept
2150+
template <class A, uint64_t V0, uint64_t V1, uint64_t V2, uint64_t V3, uint64_t V4, uint64_t V5, uint64_t V6, uint64_t V7>
2151+
XSIMD_INLINE batch<double, A> swizzle(batch<double, A> const& self,
2152+
batch_constant<uint64_t, A, V0, V1, V2, V3, V4, V5, V6, V7> mask,
2153+
requires_arch<avx512f>) noexcept
21362154
{
2155+
XSIMD_IF_CONSTEXPR(detail::is_identity(mask))
2156+
{
2157+
return self;
2158+
}
2159+
XSIMD_IF_CONSTEXPR(!detail::is_cross_lane(mask))
2160+
{
2161+
constexpr auto imm = ((V0 & 1) << 0) | ((V1 & 1) << 1) | ((V2 & 1) << 2) | ((V3 & 1) << 3) | ((V4 & 1) << 4) | ((V5 & 1) << 5) | ((V6 & 1) << 6) | ((V7 & 1) << 7);
2162+
return _mm512_permute_pd(self, imm);
2163+
}
2164+
constexpr bool dup_lo = detail::is_dup_lo(mask);
2165+
constexpr bool dup_hi = detail::is_dup_hi(mask);
2166+
XSIMD_IF_CONSTEXPR(dup_lo || dup_hi)
2167+
{
2168+
const batch<double, avx2> half = _mm512_extractf64x4_pd(self, dup_lo ? 0 : 1);
2169+
constexpr typename std::conditional<dup_lo, batch_constant<uint64_t, avx2, V0 % 4, V1 % 4, V2 % 4, V3 % 4>,
2170+
batch_constant<uint64_t, avx2, V4 % 4, V5 % 4, V6 % 4, V7 % 4>>::type half_mask {};
2171+
return _mm512_broadcast_f64x4(swizzle(half, half_mask, avx2 {}));
2172+
}
2173+
// General case
21372174
return swizzle(self, mask.as_batch(), avx512f {});
21382175
}
21392176

0 commit comments

Comments
 (0)