Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid using make_direct_deprecated() in extern "ptx-kernel" #133932

Merged
merged 2 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 30 additions & 14 deletions compiler/rustc_target/src/callconv/nvptx64.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{ArgAttribute, ArgAttributes, ArgExtension, CastTarget};
use crate::abi::call::{ArgAbi, FnAbi, PassMode, Reg, Size, Uniform};
use crate::abi::call::{ArgAbi, FnAbi, Reg, Size, Uniform};
use crate::abi::{HasDataLayout, TyAbiInterface};

fn classify_ret<Ty>(ret: &mut ArgAbi<'_, Ty>) {
Expand Down Expand Up @@ -53,21 +53,37 @@ where
Ty: TyAbiInterface<'a, C> + Copy,
C: HasDataLayout,
{
if matches!(arg.mode, PassMode::Pair(..)) && (arg.layout.is_adt() || arg.layout.is_tuple()) {
let align_bytes = arg.layout.align.abi.bytes();
match arg.mode {
super::PassMode::Ignore | super::PassMode::Direct(_) => return,
super::PassMode::Pair(_, _) => {}
super::PassMode::Cast { .. } => unreachable!(),
super::PassMode::Indirect { .. } => {}
}

// FIXME only allow structs and wide pointers here
// panic!(
// "`extern \"ptx-kernel\"` doesn't allow passing types other than primitives and structs"
// );

let align_bytes = arg.layout.align.abi.bytes();

let unit = match align_bytes {
1 => Reg::i8(),
2 => Reg::i16(),
4 => Reg::i32(),
8 => Reg::i64(),
16 => Reg::i128(),
_ => unreachable!("Align is given as power of 2 no larger than 16 bytes"),
};
arg.cast_to(Uniform::new(unit, Size::from_bytes(2 * align_bytes)));
let unit = match align_bytes {
1 => Reg::i8(),
2 => Reg::i16(),
4 => Reg::i32(),
8 => Reg::i64(),
16 => Reg::i128(),
_ => unreachable!("Align is given as power of 2 no larger than 16 bytes"),
};
if arg.layout.size.bytes() / align_bytes == 1 {
// Make sure we pass the struct as array at the LLVM IR level and not as a single integer.
arg.cast_to(CastTarget {
prefix: [Some(unit), None, None, None, None, None, None, None],
rest: Uniform::new(unit, Size::ZERO),
attrs: ArgAttributes::new(),
});
} else {
// FIXME: find a better way to do this. See https://github.com/rust-lang/rust/issues/117271.
arg.make_direct_deprecated();
arg.cast_to(Uniform::new(unit, arg.layout.size));
}
}

Expand Down
11 changes: 3 additions & 8 deletions compiler/rustc_ty_utils/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -489,21 +489,16 @@ fn fn_abi_sanity_check<'tcx>(
// have to allow it -- but we absolutely shouldn't let any more targets do
// that. (Also see <https://github.com/rust-lang/rust/issues/115666>.)
//
// The unstable abi `PtxKernel` also uses Direct for now.
// It needs to switch to something else before stabilization can happen.
// (See issue: https://github.com/rust-lang/rust/issues/117271)
//
// And finally the unadjusted ABI is ill specified and uses Direct for all
// args, but unfortunately we need it for calling certain LLVM intrinsics.
// The unadjusted ABI also uses Direct for all args and is ill-specified,
// but unfortunately we need it for calling certain LLVM intrinsics.

match spec_abi {
ExternAbi::Unadjusted => {}
ExternAbi::PtxKernel => {}
ExternAbi::C { unwind: _ }
if matches!(&*tcx.sess.target.arch, "wasm32" | "wasm64") => {}
_ => {
panic!(
"`PassMode::Direct` for aggregates only allowed for \"unadjusted\" and \"ptx-kernel\" functions and on wasm\n\
"`PassMode::Direct` for aggregates only allowed for \"unadjusted\" functions and on wasm\n\
Problematic type: {:#?}",
arg.layout,
);
Expand Down
18 changes: 1 addition & 17 deletions tests/assembly/nvptx-kernel-abi/nvptx-kernel-args-abi-v7.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,22 +242,6 @@ pub unsafe extern "ptx-kernel" fn f_float_array_arg(_a: [f32; 5]) {}
//pub unsafe extern "ptx-kernel" fn f_u128_array_arg(_a: [u128; 5]) {}

// CHECK: .visible .entry f_u32_slice_arg(
// CHECK: .param .u64 f_u32_slice_arg_param_0
// CHECK: .param .u64 f_u32_slice_arg_param_1
// CHECK: .param .align 8 .b8 f_u32_slice_arg_param_0[16]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't guarantee any ABI for wide pointers anyway and if this is breaks the ABI all users can be updated to pass #[repr(C)] struct Slice<'a, T> { ptr: *const (), len: usize, _marker: PhantomData<&'a [T]> } on the host side anyway (and preferably also the gpu side to avoid depending on the unstable ABI).

#[no_mangle]
pub unsafe extern "ptx-kernel" fn f_u32_slice_arg(_a: &[u32]) {}

// CHECK: .visible .entry f_tuple_u8_u8_arg(
// CHECK: .param .align 1 .b8 f_tuple_u8_u8_arg_param_0[2]
#[no_mangle]
pub unsafe extern "ptx-kernel" fn f_tuple_u8_u8_arg(_a: (u8, u8)) {}

// CHECK: .visible .entry f_tuple_u32_u32_arg(
// CHECK: .param .align 4 .b8 f_tuple_u32_u32_arg_param_0[8]
#[no_mangle]
pub unsafe extern "ptx-kernel" fn f_tuple_u32_u32_arg(_a: (u32, u32)) {}

// CHECK: .visible .entry f_tuple_u8_u8_u32_arg(
// CHECK: .param .align 4 .b8 f_tuple_u8_u8_u32_arg_param_0[8]
#[no_mangle]
pub unsafe extern "ptx-kernel" fn f_tuple_u8_u8_u32_arg(_a: (u8, u8, u32)) {}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tuples don't have a stable ABI for any calling convention. And can always be replaced with structs anyway.

Loading