@@ -18,94 +18,44 @@ using MPI: MPI
1818# return mpi.finalize(; location)
1919# end
2020
21- # TODO change to this kind of MLIR
22- # module {
23- # llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
24- # func.func @$sym_name(%comm_ptr : !llvm.ptr, %rank_ptr : !llvm.ptr) -> () {
25- # %comm = llvm.load %comm_ptr : !llvm.ptr -> i32
26- # %world_ptr = arith.constant dense<0x0asdfa> : tensor<i32>
27- # memref.get_global # global variable MPI_COMM_GLOBAL
28- # %status = llvm.call @MPI_Comm_rank(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32)
29- # func.return
30- # }
31- # func.func @real_$sym_name() -> tensor<> {
32- # %rank_ptr = stablehlo.constant dense<-1> : tensor<i32> # this is a placeholder
33- # %rank = enzymexla.jit_call @$sym_name(%world_ptr, %rank_ptr) {
34- # output_operand_alias = [
35- # #stablehlo.output_operand_alias<output_tuple_indices = [],
36- # operand_index = 1,
37- # operand_tuple_indices = []>
38- # ]
39- # }
40- # }
41- # }
42-
4321function comm_rank(; location= mlir_stacktrace(" mpi.comm_rank" , @__FILE__, @__LINE__))
4422 sym_name = " enzymexla_wrapper_MPI_Comm_rank"
45- # sym_attr = IR.FlatSymbolRefAttribute(sym_name)
46- comm = MPI. COMM_WORLD
47-
48- @show IR. mmodule()
23+ sym_attr = IR. FlatSymbolRefAttribute(sym_name)
4924
50- # memref.global constant @MPI_COMM_WORLD : memref<i32>
51- # llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
25+ # dirty hack: since MPI constants are i32, we pass the info as the pointer and then bitcast
26+ # DONT LOAD FROM THEM!
27+ IR. inject!(" MPI_COMM_WORLD" , " llvm.mlir.global constant @MPI_COMM_WORLD() : i32" )
28+ IR. inject!(" MPI_Comm_rank" , " llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32" )
5229
5330 # ! format: off
54- # IR.tryinjectop!("MPI_COMM_WORLD", "memref.global @MPI_COMM_WORLD : memref<i32>")
55- # IR.tryinjectop!("MPI_Comm_rank", "module { llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32 }")
56- IR. inject!(" $(sym_name) _jit" , """
57- func.func @$(sym_name) _jit(%rank_ptr : !llvm.ptr) -> () {
58- %comm_ref = memref.get_global @MPI_COMM_WORLD : memref<i32>
59- %comm_ptr = "enzymexla.memref2pointer"(%comm_ref) : (memref<i32>) -> (!llvm.ptr)
31+ IR. inject!(sym_name, """
32+ func.func @$sym_name (%rank_ptr : !llvm.ptr) -> () {
33+ %comm_ptr = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
6034 %comm = llvm.ptrtoint %comm_ptr : !llvm.ptr to i32
6135 %status = llvm.call @MPI_Comm_rank(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32)
6236 func.return
6337 }
6438 """ )
65- @show res
66- # ! format: on
67-
68- # %comm_ref = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
69- # %comm = llvm.ptrtoint %comm_ref : !llvm.ptr to i32
70-
71- # ! format: off
72- # return Reactant.Ops.hlo_call("""module {
73- # memref.global constant @MPI_COMM_WORLD : memref<i32>
74- # llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
75- # func.func @$(sym_name)_jit(%rank_ptr : !llvm.ptr) -> () {
76- # %comm_ref = memref.get_global @MPI_COMM_WORLD : memref<i32>
77- # %comm_ptr = "enzymexla.memref2pointer"(%comm_ref) : (memref<i32>) -> (!llvm.ptr)
78- # %comm = llvm.ptrtoint %comm_ptr : !llvm.ptr to i32
79- # %status = llvm.call @MPI_Comm_rank(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32)
80- # func.return
81- # }
82- # func.func @$sym_name() -> tensor<i32> {
83- # %rank_placeholder = stablehlo.constant dense<-1> : tensor<i32>
84- # %rank = enzymexla.jit_call @$(sym_name)_jit(%rank_placeholder) {
85- # output_operand_aliases = [
86- # #stablehlo.output_operand_alias<output_tuple_indices = [],
87- # operand_index = 1,
88- # operand_tuple_indices = []>
89- # ]
90- # } : (tensor<i32>) -> (tensor<i32>)
91- # func.return %rank : tensor<i32>
92- # }
93- # }"""; func_name=sym_name)
9439 # ! format: on
40+ rank_placeholder = Reactant. Ops. constant(fill(Cint(- 1 )))
41+ output_operand_aliases = IR. Attribute([
42+ IR. Attribute(
43+ MLIR. API. stablehloOutputOperandAliasGet(
44+ MLIR. IR. context(), 0 , C_NULL , 0 , 0 , C_NULL
45+ ),
46+ ),
47+ ])
9548
96- # NOTE we assume here that `MPI_Comm` is of word-size
97- # comm = Reactant.Ops.constant(Base.unsafe_convert(Cint, comm))
98- # value_out = Reactant.Ops.constant(fill(Cint(-1)))
99- # inputs = IR.Value[comm.mlir_data, value_out.mlir_data]
100-
101- # tensor_int_type = IR.TensorType(Int[], IR.Type(Cint))
102- # signature = IR.Type[tensor_int_type, tensor_int_type]
103-
104- # # TODO output_operand_aliases
105- # res = IR.result(
106- # enzymexla.jit_call(inputs; fn=sym_attr, result_0=signature, location), 2
107- # )
108- # return TracedRNumber{Cint}((), res)
49+ res = IR. result(
50+ enzymexla. jit_call(
51+ IR. Value[rank_placeholder. mlir_data];
52+ fn= sym_attr,
53+ result_0= [IR. TensorType(Int[], IR. Type(Cint))],
54+ location,
55+ output_operand_aliases,
56+ ),
57+ )
58+ return TracedRNumber{Cint}((), res)
10959end
11060
11161function comm_size(comm; location= mlir_stacktrace(" mpi.comm_size" , @__FILE__, @__LINE__))
0 commit comments