@@ -28,58 +28,144 @@ class work {};
2828
2929template <>
3030class work <float > {
31- using psd_data_t = aligned::aligned_mem<float >;
31+ using data_type = aligned::aligned_mem<float >;
32+ static constexpr auto STRIDE_256 = std::size_t {256u / 8u / sizeof (float )};
33+ static constexpr auto STRIDE_512 = std::size_t {512u / 8u / sizeof (float )};
3234public:
3335 explicit work (float alpha) {
34- m_alpha_vec = _mm512_set1_ps (alpha);
35- m_one_minus_alpha_vec = _mm512_set1_ps (1 - alpha);
36+ // Initialize the function pointer based on CPU features
37+ if (__builtin_cpu_supports (" avx512f" )) {
38+ init_avx512 (alpha);
39+ } else if (__builtin_cpu_supports (" avx2" ) && __builtin_cpu_supports (" fma" )) {
40+ init_avx2 (alpha);
41+ }
42+ }
43+
44+ auto process (data_type* curr_psd, data_type* prev_psd) const -> void {
45+ (this ->*process_func)(curr_psd, prev_psd);
46+ }
47+
48+ private:
49+ [[gnu::target(" avx2,fma" )]]
50+ auto init_avx2 (double alpha) -> void {
51+ process_func = &work::process_avx2;
52+ m_alpha_vec_256 = _mm256_set1_ps (alpha);
53+ m_one_minus_alpha_vec_256 = _mm256_set1_ps (1 - alpha);
54+ }
55+
56+ [[gnu::target(" avx512f" )]]
57+ auto init_avx512 (double alpha) -> void {
58+ process_func = &work::process_avx512;
59+ m_alpha_vec_512 = _mm512_set1_ps (alpha);
60+ m_one_minus_alpha_vec_512 = _mm512_set1_ps (1 - alpha);
3661 }
3762
38- auto process (psd_data_t * curr_psd, psd_data_t * prev_psd) const -> void {
39- for (auto i=0u ; i < curr_psd->size (); i += 16 ) {
63+ [[gnu::target(" avx2,fma" )]]
64+ auto process_avx2 (data_type* curr_psd, data_type* prev_psd) const -> void {
65+ for (auto i=0u ; i < curr_psd->size (); i += STRIDE_256) {
66+ // Load data
67+ auto curr_data = _mm256_load_ps (curr_psd->data () + i);
68+ auto prev_data = _mm256_load_ps (prev_psd->data () + i);
69+ // Multiply current data by alpha
70+ curr_data = _mm256_mul_ps (curr_data, m_alpha_vec_256);
71+ // Multiply prev by (1-alpha) and add to current
72+ curr_data = _mm256_fmadd_ps (prev_data, m_one_minus_alpha_vec_256, curr_data);
73+ // Store result into psd
74+ _mm256_store_ps (curr_psd->data () + i, curr_data);
75+ }
76+ }
77+
78+ [[gnu::target(" avx512f" )]]
79+ auto process_avx512 (data_type* curr_psd, data_type* prev_psd) const -> void {
80+ for (auto i=0u ; i < curr_psd->size (); i += STRIDE_512) {
4081 // Load data
4182 auto curr_data = _mm512_load_ps (curr_psd->data () + i);
4283 auto prev_data = _mm512_load_ps (prev_psd->data () + i);
4384 // Multiply current data by alpha
44- curr_data = _mm512_mul_ps (curr_data, m_alpha_vec );
85+ curr_data = _mm512_mul_ps (curr_data, m_alpha_vec_512 );
4586 // Multiply prev by (1-alpha) and add to current
46- curr_data = _mm512_fmadd_ps (prev_data, m_one_minus_alpha_vec , curr_data);
87+ curr_data = _mm512_fmadd_ps (prev_data, m_one_minus_alpha_vec_512 , curr_data);
4788 // Store result into psd
4889 _mm512_store_ps (curr_psd->data () + i, curr_data);
4990 }
5091 }
5192
52- private:
53- __m512 m_alpha_vec;
54- __m512 m_one_minus_alpha_vec;
93+ auto (work::*process_func)(data_type*, data_type*) const -> void;
94+ __m256 m_alpha_vec_256;
95+ __m512 m_alpha_vec_512;
96+ __m256 m_one_minus_alpha_vec_256;
97+ __m512 m_one_minus_alpha_vec_512;
5598
5699}; // class work<float>
57100
58101template <>
59102class work <double > {
60- using psd_data_t = aligned::aligned_mem<double >;
103+ using data_type = aligned::aligned_mem<double >;
104+ static constexpr auto STRIDE_256 = std::size_t {256u / 8u / sizeof (double )};
105+ static constexpr auto STRIDE_512 = std::size_t {512u / 8u / sizeof (double )};
61106public:
62107 explicit work (double alpha) {
63- m_alpha_vec = _mm512_set1_pd (alpha);
64- m_one_minus_alpha_vec = _mm512_set1_pd (1 - alpha);
108+ // Initialize the function pointer based on CPU features
109+ if (__builtin_cpu_supports (" avx512f" )) {
110+ init_avx512 (alpha);
111+ } else if (__builtin_cpu_supports (" avx2" ) && __builtin_cpu_supports (" fma" )) {
112+ init_avx2 (alpha);
113+ }
114+ }
115+
116+ auto process (data_type* curr_psd, data_type* prev_psd) const -> void {
117+ (this ->*process_func)(curr_psd, prev_psd);
118+ }
119+
120+ private:
121+ [[gnu::target(" avx2,fma" )]]
122+ auto init_avx2 (double alpha) -> void {
123+ process_func = &work::process_avx2;
124+ m_alpha_vec_256 = _mm256_set1_pd (alpha);
125+ m_one_minus_alpha_vec_256 = _mm256_set1_pd (1 - alpha);
126+ }
127+
128+ [[gnu::target(" avx512f" )]]
129+ auto init_avx512 (double alpha) -> void {
130+ process_func = &work::process_avx512;
131+ m_alpha_vec_512 = _mm512_set1_pd (alpha);
132+ m_one_minus_alpha_vec_512 = _mm512_set1_pd (1 - alpha);
65133 }
66134
67- auto process (psd_data_t * curr_psd, psd_data_t * prev_psd) const -> void {
68- for (auto i=0u ; i < curr_psd->size (); i += 8 ) {
135+ [[gnu::target(" avx2,fma" )]]
136+ auto process_avx2 (data_type* curr_psd, data_type* prev_psd) const -> void {
137+ for (auto i=0u ; i < curr_psd->size (); i += STRIDE_256) {
138+ // Load data
139+ auto curr_data = _mm256_load_pd (curr_psd->data () + i);
140+ auto prev_data = _mm256_load_pd (prev_psd->data () + i);
141+ // Multiply current data by alpha
142+ curr_data = _mm256_mul_pd (curr_data, m_alpha_vec_256);
143+ // Multiply prev by (1-alpha) and add to current
144+ curr_data = _mm256_fmadd_pd (prev_data, m_one_minus_alpha_vec_256, curr_data);
145+ // Store result into psd
146+ _mm256_store_pd (curr_psd->data () + i, curr_data);
147+ }
148+ }
149+
150+ [[gnu::target(" avx512f" )]]
151+ auto process_avx512 (data_type* curr_psd, data_type* prev_psd) const -> void {
152+ for (auto i=0u ; i < curr_psd->size (); i += STRIDE_512) {
69153 // Load data
70154 auto curr_data = _mm512_load_pd (curr_psd->data () + i);
71155 auto prev_data = _mm512_load_pd (prev_psd->data () + i);
72156 // Multiply current data by alpha
73- curr_data = _mm512_mul_pd (curr_data, m_alpha_vec );
157+ curr_data = _mm512_mul_pd (curr_data, m_alpha_vec_512 );
74158 // Multiply prev by (1-alpha) and add to current
75- curr_data = _mm512_fmadd_pd (prev_data, m_one_minus_alpha_vec , curr_data);
159+ curr_data = _mm512_fmadd_pd (prev_data, m_one_minus_alpha_vec_512 , curr_data);
76160 // Store result into psd
77161 _mm512_store_pd (curr_psd->data () + i, curr_data);
78162 }
79163 }
80164
81- private:
82- __m512d m_alpha_vec;
83- __m512d m_one_minus_alpha_vec;
165+ auto (work::*process_func)(data_type*, data_type*) const -> void;
166+ __m256d m_alpha_vec_256;
167+ __m512d m_alpha_vec_512;
168+ __m256d m_one_minus_alpha_vec_256;
169+ __m512d m_one_minus_alpha_vec_512;
84170
85171}; // class work<double>
0 commit comments