Skip to content

Commit b36f525

Browse files
authored
feat: High-Level IFRT dispatches (EnzymeAD#828)
* feat: env var to control runtime
* feat: skeleton impl of IFRT Array & Number types
* feat: working single device execution for IFRT
* fix: ambiguity
* fix: missing runtime
* fix: avoid using args...
* ci: run tests with IFRT runtime
* fix: use preferences + alias ConcreteRArray and ConcreteRNumber
* feat: more concreteifrt* dispatches
* fix: update tests and more coverage
* fix: python number type
* ci: temporarily disable precompilation
* fix: check for no null client
* fix: more sharding + precompile fix
* fix: copy data over for sharding
* test: fix new tests
* fix: add check
* ci: reduce localjll runners
* fix: IFRT un-sharded inputs with partially sharded inputs
* feat: non-divisible dims for IFRT
* test: remove specialized test
* feat: distributed workflow working
* ci: fix test call
* fix: try disabling precompilation in aarch64
* ci: downgrade testing for ifrt
* fix: remove convert dispatch
* fix: KA backend
* chore: test precompilation on aarch64
* ci: properly use preferences
* chore: run formatter
* ci: don't overwrite preferences
1 parent d35aa7e commit b36f525

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+1599
-669
lines changed

.buildkite/pipeline.yml

+12-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
steps:
22
- group: ":test_tube: Tests"
33
steps:
4-
- label: ":julia: :linux: CUDA Julia v{{matrix.version}} -- {{matrix.group}}"
4+
- label: ":julia: :linux: CUDA Julia v{{matrix.version}} -- {{matrix.group}} -- {{matrix.runtime}}"
55
matrix:
66
setup:
77
version:
@@ -10,6 +10,9 @@ steps:
1010
- core
1111
- neural_networks
1212
- integration
13+
runtime:
14+
- "PJRT"
15+
- "IFRT"
1316
plugins:
1417
- JuliaCI/julia#v1:
1518
version: "{{matrix.version}}"
@@ -20,6 +23,13 @@ steps:
2023
- ext
2124
- lib/ReactantCore/src
2225
commands: |
26+
touch LocalPreferences.toml
27+
28+
echo "[Reactant]" >> LocalPreferences.toml
29+
echo "xla_runtime = \"{{matrix.runtime}}\"" >> LocalPreferences.toml
30+
31+
cat LocalPreferences.toml
32+
2333
julia --project=. -e 'println("--- :julia: Instantiating project")
2434
using Pkg
2535
Pkg.develop([PackageSpec(path="lib/ReactantCore")])'
@@ -33,6 +43,7 @@ steps:
3343
env:
3444
REACTANT_TEST_GROUP: "{{matrix.group}}"
3545
CUDA_VISIBLE_DEVICES: 0
46+
JULIA_DEBUG: "Reactant,Reactant_jll"
3647
if: build.message !~ /\[skip tests\]/
3748
timeout_in_minutes: 120
3849

.github/workflows/CI-localjll.yml

+32-2
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,16 @@ jobs:
6262
julia --color=yes --project=deps -e 'using Pkg; Pkg.instantiate()'
6363
julia --color=yes --project=deps deps/build_local.jl
6464
cp LocalPreferences.toml test/
65+
- name: "Setup Runtime Preferences"
66+
run: |
67+
import Pkg
68+
Pkg.Registry.update()
69+
Pkg.instantiate()
70+
using Preferences
71+
Preferences.set_preferences!("Reactant", "xla_runtime" => "PJRT"; force=true)
72+
shell: julia --color=yes --code-coverage=user --depwarn=yes --project=. {0}
73+
env:
74+
JULIA_PKG_PRECOMPILE_AUTO: 0
6575
- name: "Install Dependencies"
6676
run: |
6777
import Pkg
@@ -77,16 +87,36 @@ jobs:
7787
if: ${{ matrix.version == '1.10' }}
7888
env:
7989
JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager
80-
- name: "Run Tests"
90+
- name: "Run Tests: PJRT"
91+
run: |
92+
import Pkg
93+
Pkg.Registry.update()
94+
Pkg.test(; coverage="user")
95+
shell: julia --color=yes --code-coverage=user --depwarn=yes --project=. {0}
96+
env:
97+
JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager
98+
XLA_FLAGS: "--xla_force_host_platform_device_count=8"
99+
JULIA_DEBUG: "Reactant,Reactant_jll"
100+
- name: "Setup Runtime Preferences"
101+
run: |
102+
import Pkg
103+
Pkg.Registry.update()
104+
Pkg.instantiate()
105+
using Preferences
106+
Preferences.set_preferences!("Reactant", "xla_runtime" => "IFRT"; force=true)
107+
shell: julia --color=yes --code-coverage=user --depwarn=yes --project=. {0}
108+
env:
109+
JULIA_PKG_PRECOMPILE_AUTO: 0
110+
- name: "Run Tests: IFRT"
81111
run: |
82112
import Pkg
83113
Pkg.Registry.update()
84114
Pkg.test(; coverage="user")
85115
shell: julia --color=yes --code-coverage=user --depwarn=yes --project=. {0}
86-
id: run_tests
87116
env:
88117
JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager
89118
XLA_FLAGS: "--xla_force_host_platform_device_count=8"
119+
JULIA_DEBUG: "Reactant,Reactant_jll"
90120
- uses: julia-actions/julia-processcoverage@v1
91121
- uses: codecov/codecov-action@v5
92122
with:

.github/workflows/CI.yml

+16-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ concurrency:
3131
jobs:
3232
test:
3333
timeout-minutes: 90
34-
name: Julia ${{ matrix.version }} - ${{ matrix.test_group }} - ${{ matrix.os }} - assertions=${{ matrix.assertions }} - ${{ github.event_name }}
34+
name: Julia ${{ matrix.version }} - ${{ matrix.test_group }} - ${{ matrix.os }} - ${{ matrix.runtime }} - assertions=${{ matrix.assertions }} - ${{ github.event_name }}
3535
runs-on: ${{ matrix.os }}
3636
strategy:
3737
fail-fast: false
@@ -51,21 +51,27 @@ jobs:
5151
- core
5252
- neural_networks
5353
- integration
54+
runtime:
55+
- "PJRT"
56+
- "IFRT"
5457
assertions:
5558
- false
5659
include:
5760
- os: ubuntu-24.04
5861
version: '1.10'
5962
assertions: true
6063
test_group: core
64+
runtime: "PJRT"
6165
- os: ubuntu-24.04
6266
version: '1.10'
6367
assertions: true
6468
test_group: neural_networks
69+
runtime: "PJRT"
6570
- os: ubuntu-24.04
6671
version: '1.10'
6772
assertions: true
6873
test_group: integration
74+
runtime: "PJRT"
6975
# - os: ubuntu-24.04
7076
# libReactant: packaged
7177
# version: '1.10'
@@ -97,6 +103,14 @@ jobs:
97103
sed -i.bak 's/exit 2/exit 0/g' julia/deps/tools/jlchecksum
98104
make -C julia -j $(nproc) FORCE_ASSERTIONS=1 LLVM_ASSERTIONS=1 JULIA_PRECOMPILE=0
99105
echo $PWD/julia/usr/bin >> $GITHUB_PATH
106+
- name: "Setup Runtime Preferences"
107+
uses: "DamianReeves/write-file-action@master"
108+
with:
109+
path: "LocalPreferences.toml"
110+
write-mode: "overwrite"
111+
contents: |
112+
[Reactant]
113+
xla_runtime = "${{ matrix.runtime }}"
100114
- name: "Install Dependencies"
101115
run: |
102116
import Pkg
@@ -124,6 +138,7 @@ jobs:
124138
JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager
125139
REACTANT_TEST_GROUP: ${{ matrix.test_group }}
126140
XLA_FLAGS: "--xla_force_host_platform_device_count=8"
141+
JULIA_DEBUG: "Reactant,Reactant_jll"
127142
- uses: julia-actions/julia-processcoverage@v1
128143
- uses: codecov/codecov-action@v5
129144
with:

.github/workflows/downgrade.yml

+12
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ jobs:
3838
- core
3939
- neural_networks
4040
- integration
41+
runtime:
42+
- PJRT
43+
- IFRT
4144
steps:
4245
- uses: actions/checkout@v4
4346
- uses: julia-actions/setup-julia@v2
@@ -47,6 +50,14 @@ jobs:
4750
- uses: julia-actions/julia-downgrade-compat@v1
4851
with:
4952
skip: "ReactantCore"
53+
- name: "Setup Runtime Preferences"
54+
uses: "DamianReeves/write-file-action@master"
55+
with:
56+
path: "LocalPreferences.toml"
57+
write-mode: "overwrite"
58+
contents: |
59+
[Reactant]
60+
xla_runtime = "${{ matrix.runtime }}"
5061
- name: "Install Dependencies and Run Tests"
5162
run: |
5263
import Pkg
@@ -64,6 +75,7 @@ jobs:
6475
JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager
6576
REACTANT_TEST_GROUP: ${{ matrix.test_group }}
6677
XLA_FLAGS: "--xla_force_host_platform_device_count=8"
78+
JULIA_DEBUG: "Reactant,Reactant_jll"
6779
- uses: julia-actions/julia-processcoverage@v1
6880
- uses: codecov/codecov-action@v5
6981
with:

deps/ReactantExtra/API.cpp

+29-2
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@
3131
#include "mlir/Dialect/SCF/IR/SCF.h"
3232
#include "mlir/Dialect/Transform/Transforms/Passes.h"
3333
#include "mlir/InitAllPasses.h"
34+
#include "mlir/Parser/Parser.h"
3435
#include "mlir/Pass/PassRegistry.h"
3536
#include "mlir/Transforms/Passes.h"
36-
#include "mlir/Parser/Parser.h"
3737
#include "src/enzyme_ad/jax/Dialect/Dialect.h"
3838
#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h"
3939
#include "src/enzyme_ad/jax/Passes/Passes.h"
@@ -230,7 +230,7 @@ extern "C" MlirAttribute mlirComplexAttrDoubleGetChecked(MlirLocation loc,
230230
extern "C" MlirOperation mlirOperationParse(MlirContext ctx,
231231
MlirStringRef code) {
232232
ParserConfig config(unwrap(ctx));
233-
OwningOpRef<Operation*> owning_op = parseSourceString(unwrap(code), config);
233+
OwningOpRef<Operation *> owning_op = parseSourceString(unwrap(code), config);
234234
if (!owning_op)
235235
return MlirOperation{nullptr};
236236
return MlirOperation{owning_op.release()};
@@ -1964,6 +1964,33 @@ extern "C" void ifrt_sharding_to_device_list(
19641964
}
19651965
}
19661966

1967+
// Queries `sharding` for the index domains it assigns to an array of the given
// shape and writes each domain's origin and shape into caller-provided flat
// buffers.
//
// Parameters:
//   sharding             - held IFRT sharding to query.
//   array_size_list      - array dimensions; `array_size_len` entries are read.
//   array_size_len       - number of entries in `array_size_list`.
//   index_domain_origins - output; entry [i * rank + j] receives element j of
//                          the origin of index domain i.
//   index_domain_shapes  - output; same layout, receives dimension j of the
//                          shape of index domain i.
//
// Here `rank` is the number of origin elements of the index domain being
// written. Neither output buffer is bounds-checked in this function, so the
// caller must provide room for (number of index domains) * rank entries each.
//
// NOTE(review): MyValueOrThrow presumably throws when IndexDomains() returns
// an error status -- confirm against its definition.
extern "C" void ifrt_sharding_to_index_domains(
    HeldValue<std::shared_ptr<ifrt::Sharding>> *sharding,
    int64_t *array_size_list, int32_t array_size_len,
    int64_t *index_domain_origins, int64_t *index_domain_shapes) {
  std::vector<int64_t> array_size(array_size_len);
  for (int32_t i = 0; i < array_size_len; i++) {
    array_size[i] = array_size_list[i];
  }
  auto array_size_span = absl::MakeSpan(array_size);
  auto array_shape = xla::ifrt::Shape(array_size_span);

  std::vector<ifrt::IndexDomain> index_domains =
      MyValueOrThrow(sharding->obj()->IndexDomains(array_shape));

  // size_t counters match the unsigned container sizes they are compared to.
  for (size_t i = 0; i < index_domains.size(); i++) {
    // Bind by const reference: copying each IndexDomain is unnecessary.
    const ifrt::IndexDomain &index_domain = index_domains[i];
    absl::Span<const int64_t> origin = index_domain.origin().elements();
    absl::Span<const int64_t> shape = index_domain.shape().dims();

    const size_t rank = origin.size();
    for (size_t j = 0; j < rank; j++) {
      const size_t idx = i * rank + j;
      index_domain_origins[idx] = origin[j];
      index_domain_shapes[idx] = shape[j];
    }
  }
}
1993+
19671994
#pragma endregion
19681995

19691996
typedef ifrt::Future<> IfRtFutureType;

ext/ReactantArrayInterfaceExt.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@ using ArrayInterface: ArrayInterface
44
using Reactant:
55
Reactant,
66
RArray,
7-
ConcretePJRTArray,
8-
ConcretePJRTNumber,
7+
AbstractConcreteNumber,
98
TracedRNumber,
109
TracedRArray,
1110
AnyTracedRArray,
@@ -14,8 +13,9 @@ using Reactant:
1413
ArrayInterface.can_setindex(::Type{<:RArray}) = false
1514
ArrayInterface.fast_scalar_indexing(::Type{<:RArray}) = false
1615

17-
for aType in
18-
(AbstractArray{<:ConcretePJRTNumber}, AbstractArray{<:TracedRNumber}, AnyTracedRArray)
16+
for aType in (
17+
AbstractArray{<:AbstractConcreteNumber}, AbstractArray{<:TracedRNumber}, AnyTracedRArray
18+
)
1919
@eval ArrayInterface.aos_to_soa(x::$aType) = Reactant.aos_to_soa(x)
2020
end
2121

ext/ReactantCUDAExt.jl

+45-21
Original file line numberDiff line numberDiff line change
@@ -1093,8 +1093,10 @@ end
10931093
Base.@nospecializeinfer function Reactant.traced_type_inner(
10941094
@nospecialize(A::Type{<:CuTracedArray}),
10951095
seen,
1096-
mode::Reactant.TraceMode,
1097-
@nospecialize(track_numbers::Type)
1096+
@nospecialize(mode::Reactant.TraceMode),
1097+
@nospecialize(track_numbers::Type),
1098+
@nospecialize(sharding),
1099+
@nospecialize(runtime)
10981100
)
10991101
return A
11001102
end
@@ -1104,28 +1106,34 @@ Base.@nospecializeinfer function Reactant.traced_type_inner(
11041106
seen,
11051107
mode::Reactant.TraceMode,
11061108
@nospecialize(track_numbers::Type),
1107-
@nospecialize(sharding)
1109+
@nospecialize(sharding),
1110+
@nospecialize(runtime)
11081111
)
11091112
T = eltype(A)
11101113
N = ndims(A)
11111114
if mode == Reactant.ArrayToConcrete && T <: Reactant.ReactantPrimitive
1112-
if !Reactant.Sharding.is_sharded(sharding)
1113-
return Reactant.ConcretePJRTArray{T,N,1,Reactant.Sharding.NoShardInfo}
1114-
else
1115-
error("TODO: implement sharding")
1116-
end
1117-
else
1118-
TT = Reactant.traced_type_inner(T, seen, mode, track_numbers, sharding)
1119-
if TT === T
1120-
return A
1121-
else
1122-
return Array{
1123-
Reactant.traced_type_inner(
1124-
T, seen, mode, track_numbers, Base.getproperty(sharding, 1)
1125-
),
1115+
if runtime isa Val{:PJRT}
1116+
return Reactant.ConcretePJRTArray{
1117+
T,
11261118
N,
1119+
Reactant.Sharding.ndevices(sharding),
1120+
Reactant.Sharding.shard_type(typeof(sharding), N),
1121+
}
1122+
elseif runtime isa Val{:IFRT}
1123+
return Reactant.ConcreteIFRTArray{
1124+
T,N,Reactant.Sharding.shard_type(typeof(sharding), N)
11271125
}
11281126
end
1127+
error("Unsupported runtime $runtime")
1128+
else
1129+
TT = Reactant.traced_type_inner(T, seen, mode, track_numbers, sharding, runtime)
1130+
TT === T && return A
1131+
return Array{
1132+
Reactant.traced_type_inner(
1133+
T, seen, mode, track_numbers, Base.getproperty(sharding, 1), runtime
1134+
),
1135+
N,
1136+
}
11291137
end
11301138
end
11311139

@@ -1136,6 +1144,7 @@ function Reactant.make_tracer(
11361144
mode;
11371145
@nospecialize(track_numbers::Type = Union{}),
11381146
@nospecialize(sharding = Reactant.Sharding.NoSharding()),
1147+
@nospecialize(runtime),
11391148
kwargs...,
11401149
)
11411150
RT = Core.Typeof(prev)
@@ -1145,9 +1154,14 @@ function Reactant.make_tracer(
11451154
return seen[prev]
11461155
end
11471156
if mode == Reactant.ArrayToConcrete && eltype(RT) <: Reactant.ReactantPrimitive
1148-
return seen[prev] = Reactant.ConcretePJRTArray(Array(prev); sharding)
1157+
if runtime isa Val{:PJRT}
1158+
return seen[prev] = Reactant.ConcretePJRTArray(Array(prev); sharding)
1159+
elseif runtime isa Val{:IFRT}
1160+
return seen[prev] = Reactant.ConcreteIFRTArray(Array(prev); sharding)
1161+
end
1162+
error("Unsupported runtime $runtime")
11491163
end
1150-
TT = Reactant.traced_type(eltype(RT), Val(mode), track_numbers, sharding)
1164+
TT = Reactant.traced_type(eltype(RT), Val(mode), track_numbers, sharding, runtime)
11511165
if TT === eltype(RT)
11521166
return prev
11531167
end
@@ -1164,6 +1178,7 @@ function Reactant.make_tracer(
11641178
mode;
11651179
track_numbers,
11661180
sharding=Base.getproperty(sharding, I),
1181+
runtime,
11671182
kwargs...,
11681183
)
11691184
if pv !== nv
@@ -1192,7 +1207,15 @@ end
11921207
@static if !Sys.isapple()
11931208
Reactant.PrecompileTools.@setup_workload begin
11941209
Reactant.initialize_dialect()
1195-
client = Reactant.XLA.PJRT.CPUClient(; checkcount=false)
1210+
1211+
if Reactant.XLA.REACTANT_XLA_RUNTIME == "PJRT"
1212+
client = Reactant.XLA.PJRT.CPUClient(; checkcount=false)
1213+
elseif Reactant.XLA.REACTANT_XLA_RUNTIME == "IFRT"
1214+
client = Reactant.XLA.IFRT.CPUClient(; checkcount=false)
1215+
else
1216+
error("Unsupported runtime: $(Reactant.XLA.REACTANT_XLA_RUNTIME)")
1217+
end
1218+
11961219
Reactant.PrecompileTools.@compile_workload begin
11971220
@static if Reactant.precompilation_supported() && VERSION != v"1.11.3"
11981221
function square_kernel!(x)
@@ -1205,10 +1228,11 @@ end
12051228
CUDA.@cuda blocks = 1 threads = length(x) square_kernel!(x)
12061229
return nothing
12071230
end
1208-
y = Reactant.ConcretePJRTArray([2.0]; client)
1231+
y = Reactant.ConcreteRArray([2.0]; client)
12091232
Reactant.Compiler.compile_mlir(square!, (y,); optimize=false)
12101233
end
12111234
end
1235+
12121236
Reactant.XLA.free_client(client)
12131237
client.client = C_NULL
12141238
Reactant.deinitialize_dialect()

0 commit comments

Comments
 (0)