Skip to content

Commit 5836d91

Browse files
authored
[flang] add ABI argument attributes in indirect calls (#126896)
Last piece that implements the TODO for sret and byval setting on indirect calls. This includes a fix to the codegen last patch. I thought types in in type attributes were automatically converted in dialect conversion passes, but that is not the case. The sret and byval type needs to be converted to llvm types in codegen (mlir FuncOp conversion is doing a similar conversion).
1 parent eff3c34 commit 5836d91

File tree

5 files changed

+115
-11
lines changed

5 files changed

+115
-11
lines changed

flang/lib/Optimizer/CodeGen/CodeGen.cpp

+30-2
Original file line numberDiff line numberDiff line change
@@ -593,8 +593,36 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
593593
call, resultTys, adaptor.getOperands(),
594594
addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
595595
adaptor.getOperands().size()));
596-
if (mlir::ArrayAttr argAttrs = call.getArgAttrsAttr())
597-
llvmCall.setArgAttrsAttr(argAttrs);
596+
if (mlir::ArrayAttr argAttrsArray = call.getArgAttrsAttr()) {
597+
// sret and byval type needs to be converted.
598+
auto convertTypeAttr = [&](const mlir::NamedAttribute &attr) {
599+
return mlir::TypeAttr::get(convertType(
600+
llvm::cast<mlir::TypeAttr>(attr.getValue()).getValue()));
601+
};
602+
llvm::SmallVector<mlir::Attribute> newArgAttrsArray;
603+
for (auto argAttrs : argAttrsArray) {
604+
llvm::SmallVector<mlir::NamedAttribute> convertedAttrs;
605+
for (const mlir::NamedAttribute &attr :
606+
llvm::cast<mlir::DictionaryAttr>(argAttrs)) {
607+
if (attr.getName().getValue() ==
608+
mlir::LLVM::LLVMDialect::getByValAttrName()) {
609+
convertedAttrs.push_back(rewriter.getNamedAttr(
610+
mlir::LLVM::LLVMDialect::getByValAttrName(),
611+
convertTypeAttr(attr)));
612+
} else if (attr.getName().getValue() ==
613+
mlir::LLVM::LLVMDialect::getStructRetAttrName()) {
614+
convertedAttrs.push_back(rewriter.getNamedAttr(
615+
mlir::LLVM::LLVMDialect::getStructRetAttrName(),
616+
convertTypeAttr(attr)));
617+
} else {
618+
convertedAttrs.push_back(attr);
619+
}
620+
}
621+
newArgAttrsArray.emplace_back(
622+
mlir::DictionaryAttr::get(rewriter.getContext(), convertedAttrs));
623+
}
624+
llvmCall.setArgAttrsAttr(rewriter.getArrayAttr(newArgAttrsArray));
625+
}
598626
if (mlir::ArrayAttr resAttrs = call.getResAttrsAttr())
599627
llvmCall.setResAttrsAttr(resAttrs);
600628
return mlir::success();

flang/lib/Optimizer/CodeGen/TargetRewrite.cpp

+34-9
Original file line numberDiff line numberDiff line change
@@ -534,19 +534,44 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
534534
} else if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
535535
fir::CallOp newCall;
536536
if (callOp.getCallee()) {
537-
newCall =
538-
rewriter->create<A>(loc, *callOp.getCallee(), newResTys, newOpers);
537+
newCall = rewriter->create<fir::CallOp>(loc, *callOp.getCallee(),
538+
newResTys, newOpers);
539539
} else {
540-
// TODO: llvm dialect must be updated to propagate argument on
541-
// attributes for indirect calls. See:
542-
// https://discourse.llvm.org/t/should-llvm-callop-be-able-to-carry-argument-attributes-for-indirect-calls/75431
543-
if (hasByValOrSRetArgs(newInTyAndAttrs))
544-
TODO(loc,
545-
"passing argument or result on the stack in indirect calls");
546540
newOpers[0].setType(mlir::FunctionType::get(
547541
callOp.getContext(),
548542
mlir::TypeRange{newInTypes}.drop_front(dropFront), newResTys));
549-
newCall = rewriter->create<A>(loc, newResTys, newOpers);
543+
newCall = rewriter->create<fir::CallOp>(loc, newResTys, newOpers);
544+
// Set ABI argument attributes on call operation since they are not
545+
// accessible via a FuncOp in indirect calls.
546+
if (hasByValOrSRetArgs(newInTyAndAttrs)) {
547+
llvm::SmallVector<mlir::Attribute> argAttrsArray;
548+
for (const auto &arg :
549+
llvm::ArrayRef<fir::CodeGenSpecifics::TypeAndAttr>(
550+
newInTyAndAttrs)
551+
.drop_front(dropFront)) {
552+
mlir::NamedAttrList argAttrs;
553+
const auto &attr = std::get<fir::CodeGenSpecifics::Attributes>(arg);
554+
if (attr.isByVal()) {
555+
mlir::Type elemType =
556+
fir::dyn_cast_ptrOrBoxEleTy(std::get<mlir::Type>(arg));
557+
argAttrs.set(mlir::LLVM::LLVMDialect::getByValAttrName(),
558+
mlir::TypeAttr::get(elemType));
559+
} else if (attr.isSRet()) {
560+
mlir::Type elemType =
561+
fir::dyn_cast_ptrOrBoxEleTy(std::get<mlir::Type>(arg));
562+
argAttrs.set(mlir::LLVM::LLVMDialect::getStructRetAttrName(),
563+
mlir::TypeAttr::get(elemType));
564+
if (auto align = attr.getAlignment()) {
565+
argAttrs.set(mlir::LLVM::LLVMDialect::getAlignAttrName(),
566+
rewriter->getIntegerAttr(
567+
rewriter->getIntegerType(32), align));
568+
}
569+
}
570+
argAttrsArray.emplace_back(
571+
argAttrs.getDictionary(rewriter->getContext()));
572+
}
573+
newCall.setArgAttrsAttr(rewriter->getArrayAttr(argAttrsArray));
574+
}
550575
}
551576
LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n');
552577
if (wrap)

flang/test/Fir/convert-to-llvm.fir

+14
Original file line numberDiff line numberDiff line change
@@ -2871,3 +2871,17 @@ func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
28712871
%0 = fir.call %arg1(%arg0) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
28722872
return %0 : i16
28732873
}
2874+
2875+
// CHECK-LABEL: @test_byval
2876+
func.func @test_byval(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
2877+
// llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.byval = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
2878+
fir.call %arg0(%arg1, %arg2) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.byval = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
2879+
return
2880+
}
2881+
2882+
// CHECK-LABEL: @test_sret
2883+
func.func @test_sret(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
2884+
// llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.sret = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
2885+
fir.call %arg0(%arg1, %arg2) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.sret = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
2886+
return
2887+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Test that ABI attributes are set in indirect calls to BIND(C) functions.
2+
// RUN: fir-opt --target-rewrite="target=x86_64-unknown-linux-gnu" %s | FileCheck %s
3+
4+
func.func @test(%arg0: () -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
5+
%0 = fir.load %arg1 : !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>
6+
%1 = fir.convert %arg0 : (() -> ()) -> ((!fir.type<t{a:!fir.array<5xf64>}>, f64) -> ())
7+
fir.call %1(%0, %arg2) proc_attrs<bind_c> : (!fir.type<t{a:!fir.array<5xf64>}>, f64) -> ()
8+
return
9+
}
10+
// CHECK-LABEL: func.func @test(
11+
// CHECK-SAME: %[[VAL_0:.*]]: () -> (),
12+
// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>,
13+
// CHECK-SAME: %[[VAL_2:.*]]: f64) {
14+
// CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>
15+
// CHECK: %[[VAL_4:.*]] = fir.convert %[[VAL_0]] : (() -> ()) -> ((!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> ())
16+
// CHECK: %[[VAL_5:.*]] = llvm.intr.stacksave : !llvm.ptr
17+
// CHECK: %[[VAL_6:.*]] = fir.alloca !fir.type<t{a:!fir.array<5xf64>}>
18+
// CHECK: fir.store %[[VAL_3]] to %[[VAL_6]] : !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>
19+
// CHECK: fir.call %[[VAL_4]](%[[VAL_6]], %[[VAL_2]]) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.byval = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
20+
// CHECK: llvm.intr.stackrestore %[[VAL_5]] : !llvm.ptr
21+
// CHECK: return
22+
// CHECK: }
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
!REQUIRES: x86-registered-target
2+
!REQUIRES: flang-supports-f128-math
3+
!RUN: %flang_fc1 -emit-llvm -triple x86_64-unknown-linux-gnu %s -o - | FileCheck %s
4+
5+
! Test ABI of indirect calls is properly implemented in the LLVM IR.
6+
7+
subroutine foo(func_ptr, z)
8+
interface
9+
complex(16) function func_ptr()
10+
end function
11+
end interface
12+
complex(16) :: z
13+
! CHECK: call void %{{.*}}(ptr sret({ fp128, fp128 }) align 16 %{{.*}})
14+
z = func_ptr()
15+
end subroutine

0 commit comments

Comments
 (0)