@@ -150,6 +150,112 @@ static inline JBLAS_CODE dequant_kblock_s8_f32(int8_t* srcptr, float* dstptr, in
150150 kblock, NPad);
151151}
152152
153+ static inline JBLAS_CODE dequant_s32_fp32 (const int32_t * srcptr, const int srcstep, float * dstptr, const int dststep,
154+ const int row, const int col, const float * scaleA, const int ldsa,
155+ const float * scaleB) {
156+ int col8 = utils::padto_le (col, 8 );
157+ for (int irow = 0 ; irow < row; irow++) {
158+ auto scale = scaleA[irow * ldsa];
159+ auto valpha = _mm256_set1_ps (scale);
160+ int icol = 0 ;
161+ for (; icol < col8; icol += 8 ) {
162+ auto vwscale = _mm256_loadu_ps (scaleB + icol);
163+ auto vscale = _mm256_mul_ps (valpha, vwscale);
164+ auto vsrcd = _mm256_loadu_si256 ((__m256i*)(srcptr + irow * srcstep + icol));
165+ auto vsrc = _mm256_cvtepi32_ps (vsrcd);
166+ vsrc = _mm256_mul_ps (vsrc, vscale);
167+ _mm256_storeu_ps (dstptr + irow * dststep + icol, vsrc);
168+ }
169+ for (; icol < col; icol += 1 ) {
170+ dstptr[irow * dststep + icol] = scale * scaleB[icol] * srcptr[irow * srcstep + icol];
171+ }
172+ }
173+ return JblasSuccess;
174+ }
175+
176+ static inline JBLAS_CODE remove_act_zeropoint_bias (float * accptr, int ldacc, int row, int col, uint8_t * zps,
177+ float * scales, int lds, const float * reduce) {
178+ int constexpr VLen = 8 ;
179+ auto col8 = utils::padto_le (col, VLen);
180+ for (int i = 0 ; i < row; i++) {
181+ auto zpf = float (zps[i * lds]) * scales[i * lds];
182+ int j = 0 ;
183+ auto vzp = _mm256_set1_ps (-zpf);
184+ for (; j < col8; j += VLen) {
185+ auto vreduce = _mm256_loadu_ps (reduce + j);
186+ auto vacc = _mm256_loadu_ps (&accptr[i * ldacc + j]);
187+ vacc = _mm256_fmadd_ps (vzp, vreduce, vacc);
188+ _mm256_storeu_ps (&accptr[i * ldacc + j], vacc);
189+ }
190+ if (j < col) {
191+ for (; j < col; j++) {
192+ accptr[i * ldacc + j] -= zpf * reduce[j];
193+ }
194+ }
195+ }
196+ return JblasSuccess;
197+ }
198+
199+ static inline JBLAS_CODE remove_wei_zeropoint_bias (float * accptr, int ldacc, int row, int col, int8_t * zps,
200+ float * scales, int lds, const float * reduce) {
201+ int constexpr VLen = 8 ;
202+ auto col8 = utils::padto_le (col, VLen);
203+ const int32_t mask[] = {-1 , -1 , 0 , 0 };
204+ for (int i = 0 ; i < row; i++) {
205+ auto vreduce = _mm256_set1_ps (-reduce[i * lds]);
206+ int j = 0 ;
207+ for (; j < col8; j += VLen) {
208+ auto vzp_s32 = _mm256_cvtepi8_epi32 (_mm_maskload_epi32 ((const int *)(zps + j), _mm_loadu_si128 ((__m128i*)mask)));
209+ auto vzp_f32 = _mm256_cvtepi32_ps (vzp_s32);
210+ auto vzp = _mm256_mul_ps (vzp_f32, _mm256_loadu_ps (scales + j));
211+ auto vacc = _mm256_loadu_ps (&accptr[i * ldacc + j]);
212+ vacc = _mm256_fmadd_ps (vzp, vreduce, vacc);
213+ _mm256_storeu_ps (&accptr[i * ldacc + j], vacc);
214+ }
215+ if (j < col) {
216+ for (; j < col8; j++) {
217+ accptr[i * ldacc + j] -= float (zps[j]) * scales[j] * reduce[i * lds];
218+ }
219+ }
220+ }
221+ return JblasSuccess;
222+ }
223+
224+ static inline JBLAS_CODE remove_zeropoint_bias (float * accptr, int ldacc, int row, int col, uint8_t * zpa, int8_t * zpb,
225+ float * scalea, float * scaleb, int lds, int k, const float * reducea,
226+ const float * reduceb) {
227+ int constexpr VLen = 8 ;
228+ auto col8 = utils::padto_le (col, VLen);
229+ auto vk = _mm256_set1_ps ((float )(k));
230+ const int32_t mask[] = {-1 , -1 , 0 , 0 };
231+ for (int i = 0 ; i < row; i++) {
232+ auto vreducea = _mm256_set1_ps (-reducea[i * lds]);
233+ auto zpaf = float (zpa[i * lds]) * scalea[i * lds];
234+ auto vzpa = _mm256_set1_ps (-zpaf);
235+ int j = 0 ;
236+ for (; j < col8; j += VLen) {
237+ auto vzp_s32 = _mm256_cvtepi8_epi32 (_mm_maskload_epi32 ((const int *)(zpb + j), _mm_loadu_si128 ((__m128i*)mask)));
238+ auto vzp_f32 = _mm256_cvtepi32_ps (vzp_s32);
239+ auto vzpb = _mm256_mul_ps (vzp_f32, _mm256_loadu_ps (scaleb + j));
240+ auto vreduceb = _mm256_loadu_ps (reduceb + j);
241+ auto vacc = _mm256_loadu_ps (&accptr[i * ldacc + j]);
242+ vacc = _mm256_fmadd_ps (vzpa, vreduceb, vacc);
243+ vacc = _mm256_fmadd_ps (vzpb, vreducea, vacc);
244+ vzpb = _mm256_mul_ps (vzpb, vk);
245+ vacc = _mm256_fmadd_ps (vzpa, vzpb, vacc);
246+ _mm256_storeu_ps (&accptr[i * ldacc + j], vacc);
247+ }
248+ if (j < col) {
249+ for (; j < col8; j++) {
250+ accptr[i * ldacc + j] -= float (zpb[j]) * scaleb[j] * reducea[i * lds];
251+ accptr[i * ldacc + j] -= zpaf * reduceb[j];
252+ accptr[i * ldacc + j] -= zpaf * float (zpb[j]) * scaleb[j] * k;
253+ }
254+ }
255+ }
256+ return JblasSuccess;
257+ }
258+
153259template <JBLAS_SIGN_INT_TYPE S4_T>
154260static inline JBLAS_CODE decompress_s4_s8 (utils::int4x2* srcptr, int8_t * dstptr, int row, int col, int ld_src,
155261 int ld_dst) {
0 commit comments