Skip to content

Commit c33d51b

Browse files
authored
Rollup merge of rust-lang#147355 - sayantn:masked-loads, r=RalfJung,bjorn3
Add alignment parameter to `simd_masked_{load,store}`

This PR adds an alignment parameter to `simd_masked_load` and `simd_masked_store`, in the form of a const-generic enum `core::intrinsics::simd::SimdAlign`. This represents the alignment of the `ptr` argument in these intrinsics as follows:

- `SimdAlign::Unaligned` - `ptr` is unaligned/1-byte aligned
- `SimdAlign::Element` - `ptr` is aligned to the element type of the SIMD vector (default behavior in the old signature)
- `SimdAlign::Vector` - `ptr` is aligned to the SIMD vector type

The main motivation for this is stdarch - most vector loads are either fully aligned (to the vector size) or unaligned (byte-aligned), so the previous signature doesn't cut it. Now, stdarch will mostly use `SimdAlign::Unaligned` and `SimdAlign::Vector`, whereas portable-simd will use `SimdAlign::Element`.

- [x] `cg_llvm`
- [x] `cg_clif`
- [x] `miri`/`const_eval`

## Alternatives

Using a const-generic/"const" `u32` parameter as the alignment (and erroring during codegen if this argument is not a power of two). This, although more flexible than the enum approach, has a few drawbacks:

- If we use a const-generic argument, then portable-simd somehow needs to pass `align_of::<T>()` as the alignment, which isn't possible without GCE
- "const" function parameters are just an ugly hack, and a pain to deal with in non-LLVM backends

We can remedy the problem with the const-generic `u32` parameter by adding a special rule for the element alignment case (e.g. `0` can mean "use the alignment of the element type"), but I feel like this is not as expressive as the enum approach, although I am open to suggestions.

cc `@workingjubilee` `@RalfJung` `@BoxyUwU`
2 parents f15a7f3 + 21fb801 commit c33d51b

27 files changed

+677
-115
lines changed

compiler/rustc_codegen_cranelift/src/intrinsics/simd.rs

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
use cranelift_codegen::ir::immediates::Offset32;
44
use rustc_abi::Endian;
5+
use rustc_middle::ty::SimdAlign;
56

67
use super::*;
78
use crate::prelude::*;
@@ -960,6 +961,15 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
960961
let lane_clif_ty = fx.clif_type(val_lane_ty).unwrap();
961962
let ptr_val = ptr.load_scalar(fx);
962963

964+
let alignment = generic_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
965+
.unwrap_leaf()
966+
.to_simd_alignment();
967+
968+
let memflags = match alignment {
969+
SimdAlign::Unaligned => MemFlags::new().with_notrap(),
970+
_ => MemFlags::trusted(),
971+
};
972+
963973
for lane_idx in 0..val_lane_count {
964974
let val_lane = val.value_lane(fx, lane_idx).load_scalar(fx);
965975
let mask_lane = mask.value_lane(fx, lane_idx).load_scalar(fx);
@@ -972,7 +982,7 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
972982

973983
fx.bcx.switch_to_block(if_enabled);
974984
let offset = lane_idx as i32 * lane_clif_ty.bytes() as i32;
975-
fx.bcx.ins().store(MemFlags::trusted(), val_lane, ptr_val, Offset32::new(offset));
985+
fx.bcx.ins().store(memflags, val_lane, ptr_val, Offset32::new(offset));
976986
fx.bcx.ins().jump(next, &[]);
977987

978988
fx.bcx.seal_block(next);
@@ -996,6 +1006,15 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
9961006
let lane_clif_ty = fx.clif_type(val_lane_ty).unwrap();
9971007
let ret_lane_layout = fx.layout_of(ret_lane_ty);
9981008

1009+
let alignment = generic_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
1010+
.unwrap_leaf()
1011+
.to_simd_alignment();
1012+
1013+
let memflags = match alignment {
1014+
SimdAlign::Unaligned => MemFlags::new().with_notrap(),
1015+
_ => MemFlags::trusted(),
1016+
};
1017+
9991018
for lane_idx in 0..ptr_lane_count {
10001019
let val_lane = val.value_lane(fx, lane_idx).load_scalar(fx);
10011020
let ptr_lane = ptr.value_lane(fx, lane_idx).load_scalar(fx);
@@ -1011,7 +1030,7 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
10111030
fx.bcx.seal_block(if_disabled);
10121031

10131032
fx.bcx.switch_to_block(if_enabled);
1014-
let res = fx.bcx.ins().load(lane_clif_ty, MemFlags::trusted(), ptr_lane, 0);
1033+
let res = fx.bcx.ins().load(lane_clif_ty, memflags, ptr_lane, 0);
10151034
fx.bcx.ins().jump(next, &[res.into()]);
10161035

10171036
fx.bcx.switch_to_block(if_disabled);

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use rustc_hir::def_id::LOCAL_CRATE;
1313
use rustc_hir::{self as hir};
1414
use rustc_middle::mir::BinOp;
1515
use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf};
16-
use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty, TyCtxt, TypingEnv};
16+
use rustc_middle::ty::{self, GenericArgsRef, Instance, SimdAlign, Ty, TyCtxt, TypingEnv};
1717
use rustc_middle::{bug, span_bug};
1818
use rustc_span::{Span, Symbol, sym};
1919
use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate};
@@ -1840,15 +1840,32 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
18401840
return Ok(call);
18411841
}
18421842

1843+
fn llvm_alignment<'ll, 'tcx>(
1844+
bx: &mut Builder<'_, 'll, 'tcx>,
1845+
alignment: SimdAlign,
1846+
vector_ty: Ty<'tcx>,
1847+
element_ty: Ty<'tcx>,
1848+
) -> u64 {
1849+
match alignment {
1850+
SimdAlign::Unaligned => 1,
1851+
SimdAlign::Element => bx.align_of(element_ty).bytes(),
1852+
SimdAlign::Vector => bx.align_of(vector_ty).bytes(),
1853+
}
1854+
}
1855+
18431856
if name == sym::simd_masked_load {
1844-
// simd_masked_load(mask: <N x i{M}>, pointer: *_ T, values: <N x T>) -> <N x T>
1857+
// simd_masked_load<_, _, _, const ALIGN: SimdAlign>(mask: <N x i{M}>, pointer: *_ T, values: <N x T>) -> <N x T>
18451858
// * N: number of elements in the input vectors
18461859
// * T: type of the element to load
18471860
// * M: any integer width is supported, will be truncated to i1
18481861
// Loads contiguous elements from memory behind `pointer`, but only for
18491862
// those lanes whose `mask` bit is enabled.
18501863
// The memory addresses corresponding to the “off” lanes are not accessed.
18511864

1865+
let alignment = fn_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
1866+
.unwrap_leaf()
1867+
.to_simd_alignment();
1868+
18521869
// The element type of the "mask" argument must be a signed integer type of any width
18531870
let mask_ty = in_ty;
18541871
let (mask_len, mask_elem) = (in_len, in_elem);
@@ -1905,7 +1922,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
19051922
let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
19061923

19071924
// Alignment of T, must be a constant integer value:
1908-
let alignment = bx.align_of(values_elem).bytes();
1925+
let alignment = llvm_alignment(bx, alignment, values_ty, values_elem);
19091926

19101927
let llvm_pointer = bx.type_ptr();
19111928

@@ -1932,14 +1949,18 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
19321949
}
19331950

19341951
if name == sym::simd_masked_store {
1935-
// simd_masked_store(mask: <N x i{M}>, pointer: *mut T, values: <N x T>) -> ()
1952+
// simd_masked_store<_, _, _, const ALIGN: SimdAlign>(mask: <N x i{M}>, pointer: *mut T, values: <N x T>) -> ()
19361953
// * N: number of elements in the input vectors
19371954
// * T: type of the element to store
19381955
// * M: any integer width is supported, will be truncated to i1
19391956
// Stores contiguous elements to memory behind `pointer`, but only for
19401957
// those lanes whose `mask` bit is enabled.
19411958
// The memory addresses corresponding to the “off” lanes are not accessed.
19421959

1960+
let alignment = fn_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
1961+
.unwrap_leaf()
1962+
.to_simd_alignment();
1963+
19431964
// The element type of the "mask" argument must be a signed integer type of any width
19441965
let mask_ty = in_ty;
19451966
let (mask_len, mask_elem) = (in_len, in_elem);
@@ -1990,7 +2011,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
19902011
let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
19912012

19922013
// Alignment of T, must be a constant integer value:
1993-
let alignment = bx.align_of(values_elem).bytes();
2014+
let alignment = llvm_alignment(bx, alignment, values_ty, values_elem);
19942015

19952016
let llvm_pointer = bx.type_ptr();
19962017

compiler/rustc_const_eval/src/interpret/intrinsics/simd.rs

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
use either::Either;
2-
use rustc_abi::Endian;
2+
use rustc_abi::{BackendRepr, Endian};
33
use rustc_apfloat::ieee::{Double, Half, Quad, Single};
44
use rustc_apfloat::{Float, Round};
5-
use rustc_middle::mir::interpret::{InterpErrorKind, UndefinedBehaviorInfo};
6-
use rustc_middle::ty::FloatTy;
5+
use rustc_middle::mir::interpret::{InterpErrorKind, Pointer, UndefinedBehaviorInfo};
6+
use rustc_middle::ty::{FloatTy, SimdAlign};
77
use rustc_middle::{bug, err_ub_format, mir, span_bug, throw_unsup_format, ty};
88
use rustc_span::{Symbol, sym};
99
use tracing::trace;
1010

1111
use super::{
1212
ImmTy, InterpCx, InterpResult, Machine, MinMax, MulAddType, OpTy, PlaceTy, Provenance, Scalar,
13-
Size, interp_ok, throw_ub_format,
13+
Size, TyAndLayout, assert_matches, interp_ok, throw_ub_format,
1414
};
1515
use crate::interpret::Writeable;
1616

@@ -644,6 +644,8 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
644644
}
645645
}
646646
sym::simd_masked_load => {
647+
let dest_layout = dest.layout;
648+
647649
let (mask, mask_len) = self.project_to_simd(&args[0])?;
648650
let ptr = self.read_pointer(&args[1])?;
649651
let (default, default_len) = self.project_to_simd(&args[2])?;
@@ -652,6 +654,14 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
652654
assert_eq!(dest_len, mask_len);
653655
assert_eq!(dest_len, default_len);
654656

657+
self.check_simd_ptr_alignment(
658+
ptr,
659+
dest_layout,
660+
generic_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
661+
.unwrap_leaf()
662+
.to_simd_alignment(),
663+
)?;
664+
655665
for i in 0..dest_len {
656666
let mask = self.read_immediate(&self.project_index(&mask, i)?)?;
657667
let default = self.read_immediate(&self.project_index(&default, i)?)?;
@@ -660,7 +670,8 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
660670
let val = if simd_element_to_bool(mask)? {
661671
// Size * u64 is implemented as always checked
662672
let ptr = ptr.wrapping_offset(dest.layout.size * i, self);
663-
let place = self.ptr_to_mplace(ptr, dest.layout);
673+
// we have already checked the alignment requirements
674+
let place = self.ptr_to_mplace_unaligned(ptr, dest.layout);
664675
self.read_immediate(&place)?
665676
} else {
666677
default
@@ -675,14 +686,23 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
675686

676687
assert_eq!(mask_len, vals_len);
677688

689+
self.check_simd_ptr_alignment(
690+
ptr,
691+
args[2].layout,
692+
generic_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
693+
.unwrap_leaf()
694+
.to_simd_alignment(),
695+
)?;
696+
678697
for i in 0..vals_len {
679698
let mask = self.read_immediate(&self.project_index(&mask, i)?)?;
680699
let val = self.read_immediate(&self.project_index(&vals, i)?)?;
681700

682701
if simd_element_to_bool(mask)? {
683702
// Size * u64 is implemented as always checked
684703
let ptr = ptr.wrapping_offset(val.layout.size * i, self);
685-
let place = self.ptr_to_mplace(ptr, val.layout);
704+
// we have already checked the alignment requirements
705+
let place = self.ptr_to_mplace_unaligned(ptr, val.layout);
686706
self.write_immediate(*val, &place)?
687707
};
688708
}
@@ -753,6 +773,30 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
753773
FloatTy::F128 => self.float_minmax::<Quad>(left, right, op)?,
754774
})
755775
}
776+
777+
fn check_simd_ptr_alignment(
778+
&self,
779+
ptr: Pointer<Option<M::Provenance>>,
780+
vector_layout: TyAndLayout<'tcx>,
781+
alignment: SimdAlign,
782+
) -> InterpResult<'tcx> {
783+
assert_matches!(vector_layout.backend_repr, BackendRepr::SimdVector { .. });
784+
785+
let align = match alignment {
786+
ty::SimdAlign::Unaligned => {
787+
// The pointer is supposed to be unaligned, so no check is required.
788+
return interp_ok(());
789+
}
790+
ty::SimdAlign::Element => {
791+
// Take the alignment of the only field, which is an array and therefore has the same
792+
// alignment as the element type.
793+
vector_layout.field(self, 0).align.abi
794+
}
795+
ty::SimdAlign::Vector => vector_layout.align.abi,
796+
};
797+
798+
self.check_ptr_align(ptr, align)
799+
}
756800
}
757801

758802
fn simd_bitmask_index(idx: u32, vec_len: u32, endianness: Endian) -> u32 {

compiler/rustc_hir_analysis/src/check/intrinsic.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -695,8 +695,8 @@ pub(crate) fn check_intrinsic_type(
695695
(1, 0, vec![param(0), param(0), param(0)], param(0))
696696
}
697697
sym::simd_gather => (3, 0, vec![param(0), param(1), param(2)], param(0)),
698-
sym::simd_masked_load => (3, 0, vec![param(0), param(1), param(2)], param(2)),
699-
sym::simd_masked_store => (3, 0, vec![param(0), param(1), param(2)], tcx.types.unit),
698+
sym::simd_masked_load => (3, 1, vec![param(0), param(1), param(2)], param(2)),
699+
sym::simd_masked_store => (3, 1, vec![param(0), param(1), param(2)], tcx.types.unit),
700700
sym::simd_scatter => (3, 0, vec![param(0), param(1), param(2)], tcx.types.unit),
701701
sym::simd_insert | sym::simd_insert_dyn => {
702702
(2, 0, vec![param(0), tcx.types.u32, param(1)], param(0))

compiler/rustc_middle/src/ty/consts/int.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,15 @@ pub enum AtomicOrdering {
3939
SeqCst = 4,
4040
}
4141

42+
/// An enum to represent the compiler-side view of `intrinsics::simd::SimdAlign`.
43+
#[derive(Debug, Copy, Clone)]
44+
pub enum SimdAlign {
45+
// These values must match `intrinsics::simd::SimdAlign`!
46+
Unaligned = 0,
47+
Element = 1,
48+
Vector = 2,
49+
}
50+
4251
impl std::fmt::Debug for ConstInt {
4352
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
4453
let Self { int, signed, is_ptr_sized_integral } = *self;
@@ -350,6 +359,21 @@ impl ScalarInt {
350359
}
351360
}
352361

362+
#[inline]
363+
pub fn to_simd_alignment(self) -> SimdAlign {
364+
use SimdAlign::*;
365+
let val = self.to_u32();
366+
if val == Unaligned as u32 {
367+
Unaligned
368+
} else if val == Element as u32 {
369+
Element
370+
} else if val == Vector as u32 {
371+
Vector
372+
} else {
373+
panic!("not a valid simd alignment")
374+
}
375+
}
376+
353377
/// Converts the `ScalarInt` to `bool`.
354378
/// Panics if the `size` of the `ScalarInt` is not equal to 1 byte.
355379
/// Errors if it is not a valid `bool`.

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ pub use self::closure::{
7474
};
7575
pub use self::consts::{
7676
AnonConstKind, AtomicOrdering, Const, ConstInt, ConstKind, ConstToValTreeResult, Expr,
77-
ExprKind, ScalarInt, UnevaluatedConst, ValTree, ValTreeKind, Value,
77+
ExprKind, ScalarInt, SimdAlign, UnevaluatedConst, ValTree, ValTreeKind, Value,
7878
};
7979
pub use self::context::{
8080
CtxtInterners, CurrentGcx, Feed, FreeRegionInfo, GlobalCtxt, Lift, TyCtxt, TyCtxtFeed, tls,

library/core/src/intrinsics/simd.rs

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
//!
33
//! In this module, a "vector" is any `repr(simd)` type.
44
5+
use crate::marker::ConstParamTy;
6+
57
/// Inserts an element into a vector, returning the updated vector.
68
///
79
/// `T` must be a vector with element type `U`, and `idx` must be `const`.
@@ -377,6 +379,19 @@ pub unsafe fn simd_gather<T, U, V>(val: T, ptr: U, mask: V) -> T;
377379
#[rustc_nounwind]
378380
pub unsafe fn simd_scatter<T, U, V>(val: T, ptr: U, mask: V);
379381

382+
/// A type for alignment options for SIMD masked load/store intrinsics.
383+
#[derive(Debug, ConstParamTy, PartialEq, Eq)]
384+
pub enum SimdAlign {
385+
// These values must match the compiler's `SimdAlign` defined in
386+
// `rustc_middle/src/ty/consts/int.rs`!
387+
/// No alignment requirements on the pointer
388+
Unaligned = 0,
389+
/// The pointer must be aligned to the element type of the SIMD vector
390+
Element = 1,
391+
/// The pointer must be aligned to the SIMD vector type
392+
Vector = 2,
393+
}
394+
380395
/// Reads a vector of pointers.
381396
///
382397
/// `T` must be a vector.
@@ -392,13 +407,12 @@ pub unsafe fn simd_scatter<T, U, V>(val: T, ptr: U, mask: V);
392407
/// `val`.
393408
///
394409
/// # Safety
395-
/// Unmasked values in `T` must be readable as if by `<ptr>::read` (e.g. aligned to the element
396-
/// type).
410+
/// `ptr` must be aligned according to the `ALIGN` parameter, see [`SimdAlign`] for details.
397411
///
398412
/// `mask` must only contain `0` or `!0` values.
399413
#[rustc_intrinsic]
400414
#[rustc_nounwind]
401-
pub unsafe fn simd_masked_load<V, U, T>(mask: V, ptr: U, val: T) -> T;
415+
pub unsafe fn simd_masked_load<V, U, T, const ALIGN: SimdAlign>(mask: V, ptr: U, val: T) -> T;
402416

403417
/// Writes to a vector of pointers.
404418
///
@@ -414,13 +428,12 @@ pub unsafe fn simd_masked_load<V, U, T>(mask: V, ptr: U, val: T) -> T;
414428
/// Otherwise if the corresponding value in `mask` is `0`, do nothing.
415429
///
416430
/// # Safety
417-
/// Unmasked values in `T` must be writeable as if by `<ptr>::write` (e.g. aligned to the element
418-
/// type).
431+
/// `ptr` must be aligned according to the `ALIGN` parameter, see [`SimdAlign`] for details.
419432
///
420433
/// `mask` must only contain `0` or `!0` values.
421434
#[rustc_intrinsic]
422435
#[rustc_nounwind]
423-
pub unsafe fn simd_masked_store<V, U, T>(mask: V, ptr: U, val: T);
436+
pub unsafe fn simd_masked_store<V, U, T, const ALIGN: SimdAlign>(mask: V, ptr: U, val: T);
424437

425438
/// Adds two simd vectors elementwise, with saturation.
426439
///

0 commit comments

Comments
 (0)