Skip to content

Commit 352abba

Browse files
committed
Fix a bug in the ptx-kernel calling convention where structs were passed indirectly
Structs being passed indirectly is surprising and has a high chance of not working, as the device and host usually do not share memory.
1 parent 297273c commit 352abba

File tree

5 files changed

+98
-9
lines changed

5 files changed

+98
-9
lines changed

compiler/rustc_middle/src/ty/layout.rs

+16
Original file line numberDiff line numberDiff line change
@@ -2568,6 +2568,22 @@ where
25682568

25692569
pointee_info
25702570
}
2571+
2572+
fn is_adt(this: TyAndLayout<'tcx>) -> bool {
2573+
matches!(this.ty.kind(), ty::Adt(..))
2574+
}
2575+
2576+
fn is_never(this: TyAndLayout<'tcx>) -> bool {
2577+
this.ty.kind() == &ty::Never
2578+
}
2579+
2580+
fn is_tuple(this: TyAndLayout<'tcx>) -> bool {
2581+
matches!(this.ty.kind(), ty::Tuple(..))
2582+
}
2583+
2584+
fn is_unit(this: TyAndLayout<'tcx>) -> bool {
2585+
matches!(this.ty.kind(), ty::Tuple(list) if list.len() == 0)
2586+
}
25712587
}
25722588

25732589
impl<'tcx> ty::Instance<'tcx> {

compiler/rustc_middle/src/ty/list.rs

+4
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ impl<T> List<T> {
6161
static EMPTY_SLICE: InOrder<usize, MaxAlign> = InOrder(0, MaxAlign);
6262
unsafe { &*(&EMPTY_SLICE as *const _ as *const List<T>) }
6363
}
64+
65+
pub fn len(&self) -> usize {
66+
self.len
67+
}
6468
}
6569

6670
impl<T: Copy> List<T> {

compiler/rustc_target/src/abi/call/mod.rs

+7-1
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,13 @@ impl<'a, Ty> FnAbi<'a, Ty> {
696696
"sparc" => sparc::compute_abi_info(cx, self),
697697
"sparc64" => sparc64::compute_abi_info(cx, self),
698698
"nvptx" => nvptx::compute_abi_info(self),
699-
"nvptx64" => nvptx64::compute_abi_info(self),
699+
"nvptx64" => {
700+
if cx.target_spec().adjust_abi(abi) == spec::abi::Abi::PtxKernel {
701+
nvptx64::compute_ptx_kernel_abi_info(cx, self)
702+
} else {
703+
nvptx64::compute_abi_info(self)
704+
}
705+
}
700706
"hexagon" => hexagon::compute_abi_info(self),
701707
"riscv32" | "riscv64" => riscv::compute_abi_info(cx, self),
702708
"wasm32" | "wasm64" => {
+39-8
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,35 @@
1-
// Reference: PTX Writer's Guide to Interoperability
2-
// https://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability
3-
4-
use crate::abi::call::{ArgAbi, FnAbi};
1+
use crate::abi::call::{ArgAbi, FnAbi, PassMode, Reg, Size, Uniform};
2+
use crate::abi::{HasDataLayout, TyAbiInterface};
53

64
fn classify_ret<Ty>(ret: &mut ArgAbi<'_, Ty>) {
75
if ret.layout.is_aggregate() && ret.layout.size.bits() > 64 {
86
ret.make_indirect();
9-
} else {
10-
ret.extend_integer_width_to(64);
117
}
128
}
139

1410
fn classify_arg<Ty>(arg: &mut ArgAbi<'_, Ty>) {
1511
if arg.layout.is_aggregate() && arg.layout.size.bits() > 64 {
1612
arg.make_indirect();
17-
} else {
18-
arg.extend_integer_width_to(64);
13+
}
14+
}
15+
16+
fn classify_arg_kernel<'a, Ty, C>(_cx: &C, arg: &mut ArgAbi<'a, Ty>)
17+
where
18+
Ty: TyAbiInterface<'a, C> + Copy,
19+
C: HasDataLayout,
20+
{
21+
if matches!(arg.mode, PassMode::Pair(..)) && (arg.layout.is_adt() || arg.layout.is_tuple()) {
22+
let align_bytes = arg.layout.align.abi.bytes();
23+
24+
let unit = match align_bytes {
25+
1 => Reg::i8(),
26+
2 => Reg::i16(),
27+
4 => Reg::i32(),
28+
8 => Reg::i64(),
29+
16 => Reg::i128(),
30+
_ => unreachable!("Align is given as power of 2 no larger than 16 bytes"),
31+
};
32+
arg.cast_to(Uniform { unit, total: Size::from_bytes(2 * align_bytes) });
1933
}
2034
}
2135

@@ -31,3 +45,20 @@ pub fn compute_abi_info<Ty>(fn_abi: &mut FnAbi<'_, Ty>) {
3145
classify_arg(arg);
3246
}
3347
}
48+
49+
pub fn compute_ptx_kernel_abi_info<'a, Ty, C>(cx: &C, fn_abi: &mut FnAbi<'a, Ty>)
50+
where
51+
Ty: TyAbiInterface<'a, C> + Copy,
52+
C: HasDataLayout,
53+
{
54+
if !fn_abi.ret.layout.is_unit() && !fn_abi.ret.layout.is_never() {
55+
panic!("Kernels should not return anything other than () or !");
56+
}
57+
58+
for arg in &mut fn_abi.args {
59+
if arg.is_ignore() {
60+
continue;
61+
}
62+
classify_arg_kernel(cx, arg);
63+
}
64+
}

compiler/rustc_target/src/abi/mod.rs

+32
Original file line numberDiff line numberDiff line change
@@ -1250,6 +1250,10 @@ pub trait TyAbiInterface<'a, C>: Sized {
12501250
cx: &C,
12511251
offset: Size,
12521252
) -> Option<PointeeInfo>;
1253+
fn is_adt(this: TyAndLayout<'a, Self>) -> bool;
1254+
fn is_never(this: TyAndLayout<'a, Self>) -> bool;
1255+
fn is_tuple(this: TyAndLayout<'a, Self>) -> bool;
1256+
fn is_unit(this: TyAndLayout<'a, Self>) -> bool;
12531257
}
12541258

12551259
impl<'a, Ty> TyAndLayout<'a, Ty> {
@@ -1291,6 +1295,34 @@ impl<'a, Ty> TyAndLayout<'a, Ty> {
12911295
_ => false,
12921296
}
12931297
}
1298+
1299+
pub fn is_adt<C>(self) -> bool
1300+
where
1301+
Ty: TyAbiInterface<'a, C>,
1302+
{
1303+
Ty::is_adt(self)
1304+
}
1305+
1306+
pub fn is_never<C>(self) -> bool
1307+
where
1308+
Ty: TyAbiInterface<'a, C>,
1309+
{
1310+
Ty::is_never(self)
1311+
}
1312+
1313+
pub fn is_tuple<C>(self) -> bool
1314+
where
1315+
Ty: TyAbiInterface<'a, C>,
1316+
{
1317+
Ty::is_tuple(self)
1318+
}
1319+
1320+
pub fn is_unit<C>(self) -> bool
1321+
where
1322+
Ty: TyAbiInterface<'a, C>,
1323+
{
1324+
Ty::is_unit(self)
1325+
}
12941326
}
12951327

12961328
impl<'a, Ty> TyAndLayout<'a, Ty> {

0 commit comments

Comments
 (0)