Skip to content

Commit 8a3d253

Browse files
committed
Ban non-array SIMD
1 parent a32d4a0 commit 8a3d253

File tree

91 files changed

+675
-818
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

91 files changed

+675
-818
lines changed

compiler/rustc_hir_analysis/src/check/check.rs

+22-20
Original file line numberDiff line numberDiff line change
@@ -1063,20 +1063,29 @@ pub fn check_simd(tcx: TyCtxt<'_>, sp: Span, def_id: LocalDefId) {
10631063
struct_span_code_err!(tcx.dcx(), sp, E0075, "SIMD vector cannot be empty").emit();
10641064
return;
10651065
}
1066-
let e = fields[FieldIdx::ZERO].ty(tcx, args);
1067-
if !fields.iter().all(|f| f.ty(tcx, args) == e) {
1068-
struct_span_code_err!(tcx.dcx(), sp, E0076, "SIMD vector should be homogeneous")
1069-
.with_span_label(sp, "SIMD elements must have the same type")
1066+
1067+
let array_field = &fields[FieldIdx::ZERO];
1068+
let array_ty = array_field.ty(tcx, args);
1069+
let ty::Array(element_ty, len_const) = array_ty.kind() else {
1070+
struct_span_code_err!(
1071+
tcx.dcx(),
1072+
sp,
1073+
E0076,
1074+
"SIMD vector's only field must be an array"
1075+
)
1076+
.with_span_label(tcx.def_span(array_field.did), "not an array")
1077+
.emit();
1078+
return;
1079+
};
1080+
1081+
if let Some(second_field) = fields.get(FieldIdx::from_u32(1)) {
1082+
struct_span_code_err!(tcx.dcx(), sp, E0075, "SIMD vector cannot have multiple fields")
1083+
.with_span_label(tcx.def_span(second_field.did), "excess field")
10701084
.emit();
10711085
return;
10721086
}
10731087

1074-
let len = if let ty::Array(_ty, c) = e.kind() {
1075-
c.try_eval_target_usize(tcx, tcx.param_env(def.did()))
1076-
} else {
1077-
Some(fields.len() as u64)
1078-
};
1079-
if let Some(len) = len {
1088+
if let Some(len) = len_const.try_eval_target_usize(tcx, tcx.param_env(def.did())) {
10801089
if len == 0 {
10811090
struct_span_code_err!(tcx.dcx(), sp, E0075, "SIMD vector cannot be empty").emit();
10821091
return;
@@ -1096,16 +1105,9 @@ pub fn check_simd(tcx: TyCtxt<'_>, sp: Span, def_id: LocalDefId) {
10961105
// These are scalar types which directly match a "machine" type
10971106
// Yes: Integers, floats, "thin" pointers
10981107
// No: char, "fat" pointers, compound types
1099-
match e.kind() {
1100-
ty::Param(_) => (), // pass struct<T>(T, T, T, T) through, let monomorphization catch errors
1101-
ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::RawPtr(_, _) => (), // struct(u8, u8, u8, u8) is ok
1102-
ty::Array(t, _) if matches!(t.kind(), ty::Param(_)) => (), // pass struct<T>([T; N]) through, let monomorphization catch errors
1103-
ty::Array(t, _clen)
1104-
if matches!(
1105-
t.kind(),
1106-
ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::RawPtr(_, _)
1107-
) =>
1108-
{ /* struct([f32; 4]) is ok */ }
1108+
match element_ty.kind() {
1109+
ty::Param(_) => (), // pass struct<T>([T; 4]) through, let monomorphization catch errors
1110+
ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::RawPtr(_, _) => (), // struct([u8; 4]) is ok
11091111
_ => {
11101112
struct_span_code_err!(
11111113
tcx.dcx(),

compiler/rustc_middle/src/ty/sty.rs

+15-23
Original file line numberDiff line numberDiff line change
@@ -1091,29 +1091,21 @@ impl<'tcx> Ty<'tcx> {
10911091
}
10921092

10931093
pub fn simd_size_and_type(self, tcx: TyCtxt<'tcx>) -> (u64, Ty<'tcx>) {
1094-
match self.kind() {
1095-
Adt(def, args) => {
1096-
assert!(def.repr().simd(), "`simd_size_and_type` called on non-SIMD type");
1097-
let variant = def.non_enum_variant();
1098-
let f0_ty = variant.fields[FieldIdx::ZERO].ty(tcx, args);
1099-
1100-
match f0_ty.kind() {
1101-
// If the first field is an array, we assume it is the only field and its
1102-
// elements are the SIMD components.
1103-
Array(f0_elem_ty, f0_len) => {
1104-
// FIXME(repr_simd): https://github.com/rust-lang/rust/pull/78863#discussion_r522784112
1105-
// The way we evaluate the `N` in `[T; N]` here only works since we use
1106-
// `simd_size_and_type` post-monomorphization. It will probably start to ICE
1107-
// if we use it in generic code. See the `simd-array-trait` ui test.
1108-
(f0_len.eval_target_usize(tcx, ParamEnv::empty()), *f0_elem_ty)
1109-
}
1110-
// Otherwise, the fields of this Adt are the SIMD components (and we assume they
1111-
// all have the same type).
1112-
_ => (variant.fields.len() as u64, f0_ty),
1113-
}
1114-
}
1115-
_ => bug!("`simd_size_and_type` called on invalid type"),
1116-
}
1094+
let Adt(def, args) = self.kind() else {
1095+
bug!("`simd_size_and_type` called on invalid type")
1096+
};
1097+
assert!(def.repr().simd(), "`simd_size_and_type` called on non-SIMD type");
1098+
let variant = def.non_enum_variant();
1099+
assert_eq!(variant.fields.len(), 1);
1100+
let field_ty = variant.fields[FieldIdx::ZERO].ty(tcx, args);
1101+
let Array(f0_elem_ty, f0_len) = field_ty.kind() else {
1102+
bug!("Simd type has non-array field type {field_ty:?}")
1103+
};
1104+
// FIXME(repr_simd): https://github.com/rust-lang/rust/pull/78863#discussion_r522784112
1105+
// The way we evaluate the `N` in `[T; N]` here only works since we use
1106+
// `simd_size_and_type` post-monomorphization. It will probably start to ICE
1107+
// if we use it in generic code. See the `simd-array-trait` ui test.
1108+
(f0_len.eval_target_usize(tcx, ParamEnv::empty()), *f0_elem_ty)
11171109
}
11181110

11191111
#[inline]

tests/codegen/align-byval-vector.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ trait Freeze {}
2121
trait Copy {}
2222

2323
#[repr(simd)]
24-
pub struct i32x4(i32, i32, i32, i32);
24+
pub struct i32x4([i32; 4]);
2525

2626
#[repr(C)]
2727
pub struct Foo {
@@ -47,12 +47,12 @@ extern "C" {
4747
}
4848

4949
pub fn main() {
50-
unsafe { f(Foo { a: i32x4(1, 2, 3, 4), b: 0 }) }
50+
unsafe { f(Foo { a: i32x4([1, 2, 3, 4]), b: 0 }) }
5151

5252
unsafe {
5353
g(DoubleFoo {
54-
one: Foo { a: i32x4(1, 2, 3, 4), b: 0 },
55-
two: Foo { a: i32x4(1, 2, 3, 4), b: 0 },
54+
one: Foo { a: i32x4([1, 2, 3, 4]), b: 0 },
55+
two: Foo { a: i32x4([1, 2, 3, 4]), b: 0 },
5656
})
5757
}
5858
}

tests/codegen/const-vector.rs

+10-26
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,11 @@
1313
// Setting up structs that can be used as const vectors
1414
#[repr(simd)]
1515
#[derive(Clone)]
16-
pub struct i8x2(i8, i8);
16+
pub struct i8x2([i8; 2]);
1717

1818
#[repr(simd)]
1919
#[derive(Clone)]
20-
pub struct i8x2_arr([i8; 2]);
21-
22-
#[repr(simd)]
23-
#[derive(Clone)]
24-
pub struct f32x2(f32, f32);
25-
26-
#[repr(simd)]
27-
#[derive(Clone)]
28-
pub struct f32x2_arr([f32; 2]);
20+
pub struct f32x2([f32; 2]);
2921

3022
#[repr(simd, packed)]
3123
#[derive(Copy, Clone)]
@@ -35,42 +27,34 @@ pub struct Simd<T, const N: usize>([T; N]);
3527
// that they are called with a const vector
3628

3729
extern "unadjusted" {
38-
#[no_mangle]
3930
fn test_i8x2(a: i8x2);
4031
}
4132

4233
extern "unadjusted" {
43-
#[no_mangle]
4434
fn test_i8x2_two_args(a: i8x2, b: i8x2);
4535
}
4636

4737
extern "unadjusted" {
48-
#[no_mangle]
4938
fn test_i8x2_mixed_args(a: i8x2, c: i32, b: i8x2);
5039
}
5140

5241
extern "unadjusted" {
53-
#[no_mangle]
54-
fn test_i8x2_arr(a: i8x2_arr);
42+
fn test_i8x2_arr(a: i8x2);
5543
}
5644

5745
extern "unadjusted" {
58-
#[no_mangle]
5946
fn test_f32x2(a: f32x2);
6047
}
6148

6249
extern "unadjusted" {
63-
#[no_mangle]
64-
fn test_f32x2_arr(a: f32x2_arr);
50+
fn test_f32x2_arr(a: f32x2);
6551
}
6652

6753
extern "unadjusted" {
68-
#[no_mangle]
6954
fn test_simd(a: Simd<i32, 4>);
7055
}
7156

7257
extern "unadjusted" {
73-
#[no_mangle]
7458
fn test_simd_unaligned(a: Simd<i32, 3>);
7559
}
7660

@@ -81,22 +65,22 @@ extern "unadjusted" {
8165
pub fn do_call() {
8266
unsafe {
8367
// CHECK: call void @test_i8x2(<2 x i8> <i8 32, i8 64>
84-
test_i8x2(const { i8x2(32, 64) });
68+
test_i8x2(const { i8x2([32, 64]) });
8569

8670
// CHECK: call void @test_i8x2_two_args(<2 x i8> <i8 32, i8 64>, <2 x i8> <i8 8, i8 16>
87-
test_i8x2_two_args(const { i8x2(32, 64) }, const { i8x2(8, 16) });
71+
test_i8x2_two_args(const { i8x2([32, 64]) }, const { i8x2([8, 16]) });
8872

8973
// CHECK: call void @test_i8x2_mixed_args(<2 x i8> <i8 32, i8 64>, i32 43, <2 x i8> <i8 8, i8 16>
90-
test_i8x2_mixed_args(const { i8x2(32, 64) }, 43, const { i8x2(8, 16) });
74+
test_i8x2_mixed_args(const { i8x2([32, 64]) }, 43, const { i8x2([8, 16]) });
9175

9276
// CHECK: call void @test_i8x2_arr(<2 x i8> <i8 32, i8 64>
93-
test_i8x2_arr(const { i8x2_arr([32, 64]) });
77+
test_i8x2_arr(const { i8x2([32, 64]) });
9478

9579
// CHECK: call void @test_f32x2(<2 x float> <float 0x3FD47AE140000000, float 0x3FE47AE140000000>
96-
test_f32x2(const { f32x2(0.32, 0.64) });
80+
test_f32x2(const { f32x2([0.32, 0.64]) });
9781

9882
// CHECK: void @test_f32x2_arr(<2 x float> <float 0x3FD47AE140000000, float 0x3FE47AE140000000>
99-
test_f32x2_arr(const { f32x2_arr([0.32, 0.64]) });
83+
test_f32x2_arr(const { f32x2([0.32, 0.64]) });
10084

10185
// CHECK: call void @test_simd(<4 x i32> <i32 2, i32 4, i32 6, i32 8>
10286
test_simd(const { Simd::<i32, 4>([2, 4, 6, 8]) });

tests/codegen/repr/transparent.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ pub extern "C" fn test_Nested2(_: Nested2) -> Nested2 {
132132
}
133133

134134
#[repr(simd)]
135-
struct f32x4(f32, f32, f32, f32);
135+
struct f32x4([f32; 4]);
136136

137137
#[repr(transparent)]
138138
pub struct Vector(f32x4);

tests/codegen/simd-intrinsic/simd-intrinsic-float-abs.rs

+7-12
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,19 @@
77

88
#[repr(simd)]
99
#[derive(Copy, Clone, PartialEq, Debug)]
10-
pub struct f32x2(pub f32, pub f32);
10+
pub struct f32x2(pub [f32; 2]);
1111

1212
#[repr(simd)]
1313
#[derive(Copy, Clone, PartialEq, Debug)]
14-
pub struct f32x4(pub f32, pub f32, pub f32, pub f32);
14+
pub struct f32x4(pub [f32; 4]);
1515

1616
#[repr(simd)]
1717
#[derive(Copy, Clone, PartialEq, Debug)]
18-
pub struct f32x8(pub f32, pub f32, pub f32, pub f32,
19-
pub f32, pub f32, pub f32, pub f32);
18+
pub struct f32x8(pub [f32; 8]);
2019

2120
#[repr(simd)]
2221
#[derive(Copy, Clone, PartialEq, Debug)]
23-
pub struct f32x16(pub f32, pub f32, pub f32, pub f32,
24-
pub f32, pub f32, pub f32, pub f32,
25-
pub f32, pub f32, pub f32, pub f32,
26-
pub f32, pub f32, pub f32, pub f32);
22+
pub struct f32x16(pub [f32; 16]);
2723

2824
extern "rust-intrinsic" {
2925
fn simd_fabs<T>(x: T) -> T;
@@ -59,16 +55,15 @@ pub unsafe fn fabs_32x16(a: f32x16) -> f32x16 {
5955

6056
#[repr(simd)]
6157
#[derive(Copy, Clone, PartialEq, Debug)]
62-
pub struct f64x2(pub f64, pub f64);
58+
pub struct f64x2(pub [f64; 2]);
6359

6460
#[repr(simd)]
6561
#[derive(Copy, Clone, PartialEq, Debug)]
66-
pub struct f64x4(pub f64, pub f64, pub f64, pub f64);
62+
pub struct f64x4(pub [f64; 4]);
6763

6864
#[repr(simd)]
6965
#[derive(Copy, Clone, PartialEq, Debug)]
70-
pub struct f64x8(pub f64, pub f64, pub f64, pub f64,
71-
pub f64, pub f64, pub f64, pub f64);
66+
pub struct f64x8(pub [f64; 8]);
7267

7368
// CHECK-LABEL: @fabs_64x4
7469
#[no_mangle]

tests/codegen/simd-intrinsic/simd-intrinsic-float-ceil.rs

+7-12
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,19 @@
77

88
#[repr(simd)]
99
#[derive(Copy, Clone, PartialEq, Debug)]
10-
pub struct f32x2(pub f32, pub f32);
10+
pub struct f32x2(pub [f32; 2]);
1111

1212
#[repr(simd)]
1313
#[derive(Copy, Clone, PartialEq, Debug)]
14-
pub struct f32x4(pub f32, pub f32, pub f32, pub f32);
14+
pub struct f32x4(pub [f32; 4]);
1515

1616
#[repr(simd)]
1717
#[derive(Copy, Clone, PartialEq, Debug)]
18-
pub struct f32x8(pub f32, pub f32, pub f32, pub f32,
19-
pub f32, pub f32, pub f32, pub f32);
18+
pub struct f32x8(pub [f32; 8]);
2019

2120
#[repr(simd)]
2221
#[derive(Copy, Clone, PartialEq, Debug)]
23-
pub struct f32x16(pub f32, pub f32, pub f32, pub f32,
24-
pub f32, pub f32, pub f32, pub f32,
25-
pub f32, pub f32, pub f32, pub f32,
26-
pub f32, pub f32, pub f32, pub f32);
22+
pub struct f32x16(pub [f32; 16]);
2723

2824
extern "rust-intrinsic" {
2925
fn simd_ceil<T>(x: T) -> T;
@@ -59,16 +55,15 @@ pub unsafe fn ceil_32x16(a: f32x16) -> f32x16 {
5955

6056
#[repr(simd)]
6157
#[derive(Copy, Clone, PartialEq, Debug)]
62-
pub struct f64x2(pub f64, pub f64);
58+
pub struct f64x2(pub [f64; 2]);
6359

6460
#[repr(simd)]
6561
#[derive(Copy, Clone, PartialEq, Debug)]
66-
pub struct f64x4(pub f64, pub f64, pub f64, pub f64);
62+
pub struct f64x4(pub [f64; 4]);
6763

6864
#[repr(simd)]
6965
#[derive(Copy, Clone, PartialEq, Debug)]
70-
pub struct f64x8(pub f64, pub f64, pub f64, pub f64,
71-
pub f64, pub f64, pub f64, pub f64);
66+
pub struct f64x8(pub [f64; 8]);
7267

7368
// CHECK-LABEL: @ceil_64x4
7469
#[no_mangle]

tests/codegen/simd-intrinsic/simd-intrinsic-float-cos.rs

+7-12
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,19 @@
77

88
#[repr(simd)]
99
#[derive(Copy, Clone, PartialEq, Debug)]
10-
pub struct f32x2(pub f32, pub f32);
10+
pub struct f32x2(pub [f32; 2]);
1111

1212
#[repr(simd)]
1313
#[derive(Copy, Clone, PartialEq, Debug)]
14-
pub struct f32x4(pub f32, pub f32, pub f32, pub f32);
14+
pub struct f32x4(pub [f32; 4]);
1515

1616
#[repr(simd)]
1717
#[derive(Copy, Clone, PartialEq, Debug)]
18-
pub struct f32x8(pub f32, pub f32, pub f32, pub f32,
19-
pub f32, pub f32, pub f32, pub f32);
18+
pub struct f32x8(pub [f32; 8]);
2019

2120
#[repr(simd)]
2221
#[derive(Copy, Clone, PartialEq, Debug)]
23-
pub struct f32x16(pub f32, pub f32, pub f32, pub f32,
24-
pub f32, pub f32, pub f32, pub f32,
25-
pub f32, pub f32, pub f32, pub f32,
26-
pub f32, pub f32, pub f32, pub f32);
22+
pub struct f32x16(pub [f32; 16]);
2723

2824
extern "rust-intrinsic" {
2925
fn simd_fcos<T>(x: T) -> T;
@@ -59,16 +55,15 @@ pub unsafe fn fcos_32x16(a: f32x16) -> f32x16 {
5955

6056
#[repr(simd)]
6157
#[derive(Copy, Clone, PartialEq, Debug)]
62-
pub struct f64x2(pub f64, pub f64);
58+
pub struct f64x2(pub [f64; 2]);
6359

6460
#[repr(simd)]
6561
#[derive(Copy, Clone, PartialEq, Debug)]
66-
pub struct f64x4(pub f64, pub f64, pub f64, pub f64);
62+
pub struct f64x4(pub [f64; 4]);
6763

6864
#[repr(simd)]
6965
#[derive(Copy, Clone, PartialEq, Debug)]
70-
pub struct f64x8(pub f64, pub f64, pub f64, pub f64,
71-
pub f64, pub f64, pub f64, pub f64);
66+
pub struct f64x8(pub [f64; 8]);
7267

7368
// CHECK-LABEL: @fcos_64x4
7469
#[no_mangle]

0 commit comments

Comments
 (0)