Skip to content

Commit 75290e1

Browse files
authored
Merge pull request #4764 from usamoi/pshufb
add avx512 pshufb
2 parents 6d58cbd + fc16679 commit 75290e1

File tree

5 files changed

+81
-46
lines changed

5 files changed

+81
-46
lines changed

src/tools/miri/src/shims/x86/avx2.rs

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use rustc_target::callconv::FnAbi;
66

77
use super::{
88
ShiftOp, horizontal_bin_op, mpsadbw, packssdw, packsswb, packusdw, packuswb, permute, pmaddbw,
9-
pmulhrsw, psadbw, psign, shift_simd_by_scalar,
9+
pmulhrsw, psadbw, pshufb, psign, shift_simd_by_scalar,
1010
};
1111
use crate::*;
1212

@@ -189,28 +189,7 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
189189
let [left, right] =
190190
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
191191

192-
let (left, left_len) = this.project_to_simd(left)?;
193-
let (right, right_len) = this.project_to_simd(right)?;
194-
let (dest, dest_len) = this.project_to_simd(dest)?;
195-
196-
assert_eq!(dest_len, left_len);
197-
assert_eq!(dest_len, right_len);
198-
199-
for i in 0..dest_len {
200-
let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u8()?;
201-
let dest = this.project_index(&dest, i)?;
202-
203-
let res = if right & 0x80 == 0 {
204-
// Shuffle each 128-bit (16-byte) block independently.
205-
let j = u64::from(right % 16).strict_add(i & !15);
206-
this.read_scalar(&this.project_index(&left, j)?)?
207-
} else {
208-
// If the highest bit in `right` is 1, write zero.
209-
Scalar::from_u8(0)
210-
};
211-
212-
this.write_scalar(res, &dest)?;
213-
}
192+
pshufb(this, left, right, dest)?;
214193
}
215194
// Used to implement the _mm256_sign_epi{8,16,32} functions.
216195
// Negates elements from `left` when the corresponding element in

src/tools/miri/src/shims/x86/avx512.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use rustc_middle::ty::Ty;
33
use rustc_span::Symbol;
44
use rustc_target::callconv::FnAbi;
55

6-
use super::{permute, pmaddbw, psadbw};
6+
use super::{permute, pmaddbw, psadbw, pshufb};
77
use crate::*;
88

99
impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
@@ -102,6 +102,13 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
102102

103103
permute(this, left, right, dest)?;
104104
}
105+
// Used to implement the _mm512_shuffle_epi8 intrinsic.
106+
"pshuf.b.512" => {
107+
let [left, right] =
108+
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
109+
110+
pshufb(this, left, right, dest)?;
111+
}
105112
_ => return interp_ok(EmulateItemResult::NotSupported),
106113
}
107114
interp_ok(EmulateItemResult::NeedsReturn)

src/tools/miri/src/shims/x86/mod.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1155,6 +1155,51 @@ fn pclmulqdq<'tcx>(
11551155
interp_ok(())
11561156
}
11571157

1158+
/// Shuffles bytes from `left` using `right` as pattern. Each 16-byte block is shuffled independently.
1159+
///
1160+
/// `left`, `right` and `dest` are all vectors of type `len` x i8.
1161+
///
1162+
/// If the highest bit of a byte in `right` is not set, the corresponding byte in `dest` is taken
1163+
/// from the current 16-byte block of `left` at the position indicated by the lowest 4 bits of this
1164+
/// byte in `right`. If the highest bit of a byte in `right` is set, the corresponding byte in
1165+
/// `dest` is set to `0`.
1166+
///
1167+
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_shuffle_epi8>
1168+
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_shuffle_epi8>
1169+
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm512_shuffle_epi8>
1170+
fn pshufb<'tcx>(
1171+
ecx: &mut crate::MiriInterpCx<'tcx>,
1172+
left: &OpTy<'tcx>,
1173+
right: &OpTy<'tcx>,
1174+
dest: &MPlaceTy<'tcx>,
1175+
) -> InterpResult<'tcx, ()> {
1176+
let (left, left_len) = ecx.project_to_simd(left)?;
1177+
let (right, right_len) = ecx.project_to_simd(right)?;
1178+
let (dest, dest_len) = ecx.project_to_simd(dest)?;
1179+
1180+
assert_eq!(dest_len, left_len);
1181+
assert_eq!(dest_len, right_len);
1182+
1183+
for i in 0..dest_len {
1184+
let right = ecx.read_scalar(&ecx.project_index(&right, i)?)?.to_u8()?;
1185+
let dest = ecx.project_index(&dest, i)?;
1186+
1187+
let res = if right & 0x80 == 0 {
1188+
// Shuffle each 128-bit (16-byte) block independently.
1189+
let block_offset = i & !15; // round down to previous multiple of 16
1190+
let j = block_offset.strict_add((right % 16).into());
1191+
ecx.read_scalar(&ecx.project_index(&left, j)?)?
1192+
} else {
1193+
// If the highest bit in `right` is 1, write zero.
1194+
Scalar::from_u8(0)
1195+
};
1196+
1197+
ecx.write_scalar(res, &dest)?;
1198+
}
1199+
1200+
interp_ok(())
1201+
}
1202+
11581203
/// Packs two N-bit integer vectors into a single N/2-bit integer vector.
11591204
///
11601205
/// The conversion from N-bit to N/2-bit should be provided by `f`.

src/tools/miri/src/shims/x86/ssse3.rs

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use rustc_middle::ty::Ty;
44
use rustc_span::Symbol;
55
use rustc_target::callconv::FnAbi;
66

7-
use super::{horizontal_bin_op, pmaddbw, pmulhrsw, psign};
7+
use super::{horizontal_bin_op, pmaddbw, pmulhrsw, pshufb, psign};
88
use crate::*;
99

1010
impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
@@ -29,27 +29,7 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
2929
let [left, right] =
3030
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
3131

32-
let (left, left_len) = this.project_to_simd(left)?;
33-
let (right, right_len) = this.project_to_simd(right)?;
34-
let (dest, dest_len) = this.project_to_simd(dest)?;
35-
36-
assert_eq!(dest_len, left_len);
37-
assert_eq!(dest_len, right_len);
38-
39-
for i in 0..dest_len {
40-
let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u8()?;
41-
let dest = this.project_index(&dest, i)?;
42-
43-
let res = if right & 0x80 == 0 {
44-
let j = right % 16; // index wraps around
45-
this.read_scalar(&this.project_index(&left, j.into())?)?
46-
} else {
47-
// If the highest bit in `right` is 1, write zero.
48-
Scalar::from_u8(0)
49-
};
50-
51-
this.write_scalar(res, &dest)?;
52-
}
32+
pshufb(this, left, right, dest)?;
5333
}
5434
// Used to implement the _mm_h{adds,subs}_epi16 functions.
5535
// Horizontally add / subtract with saturation adjacent 16-bit

src/tools/miri/tests/pass/shims/x86/intrinsics-x86-avx512.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,30 @@ unsafe fn test_avx512() {
143143
assert_eq_m512i(r, e);
144144
}
145145
test_mm512_permutexvar_epi32();
146+
147+
#[target_feature(enable = "avx512bw")]
148+
unsafe fn test_mm512_shuffle_epi8() {
149+
#[rustfmt::skip]
150+
let a = _mm512_set_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
151+
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
152+
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
153+
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63);
154+
#[rustfmt::skip]
155+
let b = _mm512_set_epi8(-1, 127, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
156+
-1, 127, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
157+
-1, 127, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
158+
-1, 127, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
159+
let r = _mm512_shuffle_epi8(a, b);
160+
// `_mm512_set_epi8` takes its arguments from the highest byte down to the lowest, so the indices in `b` seem to
161+
// index from the *back* of the corresponding 16-byte block in `a`.
162+
#[rustfmt::skip]
163+
let e = _mm512_set_epi8(0, 0, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
164+
0, 16, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
165+
0, 32, 46, 46, 46, 46, 46, 46, 46, 46, 46, 46, 46, 46, 46, 46,
166+
0, 48, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62);
167+
assert_eq_m512i(r, e);
168+
}
169+
test_mm512_shuffle_epi8();
146170
}
147171

148172
// Some of the constants in the tests below are just bit patterns. They should not

0 commit comments

Comments
 (0)