Skip to content

Comm optimization tests #1168

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ PythonCall = "0.9"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.9"
Reactant_jll = "0.0.135"
Reactant_jll = "0.0.137"
ScopedValues = "1.3.0"
Scratch = "1.2"
Sockets = "1.10"
Expand Down
7 changes: 6 additions & 1 deletion src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,8 @@ function optimization_passes(;
"concat_concat_axis_swap",
"concat_multipad",
"concat_concat_to_dus",
# TODO we want to enable but may cause an infinite compile time
# "concat_to_onedim_dusslice",
]

if DUS_SLICE_SIMPLIFY[]
Expand Down Expand Up @@ -1029,11 +1031,14 @@ end

const optimize_comms_passes = (
# rotate handler presently broken (and handled okay presently), disabling for now
"enzyme-hlo-generate-td{patterns=lower_rotate}",
"enzyme-hlo-generate-td{patterns=lower_rotate;concat_to_onedim_dus;concat_to_onedim_dusslice}",
"transform-interpreter",
"enzyme-hlo-remove-transform",
"optimize-communication",
"enzyme-hlo-generate-td{patterns=lower_rotate;lower_wrap;lower_extend}",
"transform-interpreter",
"enzyme-hlo-remove-transform",
"optimize-communication",
)

function compile_mlir!(
Expand Down
112 changes: 112 additions & 0 deletions test/optimize_comm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
using Reactant, Test

const addressable_devices = Reactant.addressable_devices()

function rotate(x)
y = x[1:100, :]
x[1:924, :] = x[101:1024, :]
x[925:1024, :] = y
return nothing
end
function pad(x)
return Reactant.Ops.pad(x, eltype(x)(0); low=[5, 0], high=[15, 0], interior=[0, 0])
end

function dus(x, y)
x[6:(size(x, 1) - 15), :] = y
return nothing
end

function dus2(x, y)
x[6:(size(x, 1) - 15), 2:11] = y
return nothing
end

if length(addressable_devices) ≥ 8
@testset "Rotate" begin
N = min((length(Reactant.devices()) ÷ 2) * 2, 8)

mesh = Sharding.Mesh(reshape(Reactant.devices()[1:N], 2, :), (:x, :y))
sharding = Sharding.NamedSharding(mesh, (:x, :y))

x = reshape(collect(Int, 1:(1024 * 12)), 1024, 12)
rx = Reactant.to_rarray(x; sharding)

hlo = repr(@code_xla shardy_passes = :to_mhlo_shardings rotate(rx))
@test !contains(hlo, "all-to-all")
@test !contains(hlo, "all-gather")
@test contains(hlo, "collective-permute")

rotate(x)
@jit shardy_passes = :to_mhlo_shardings rotate(rx)
@test all(x .== convert(Array, rx))
end

@testset "Pad" begin
N = min((length(Reactant.devices()) ÷ 2) * 2, 8)

mesh = Sharding.Mesh(reshape(Reactant.devices()[1:N], 2, :), (:x, :y))
sharding = Sharding.NamedSharding(mesh, (:x, :y))

x = reshape(collect(Int, 1:(1024 * 12)), 1024, 12)
rx = Reactant.to_rarray(x; sharding)

hlo = repr(@code_xla shardy_passes = :to_mhlo_shardings pad(rx))
@test !contains(hlo, "all-to-all")
@test !contains(hlo, "all-gather")
@test contains(hlo, "collective-permute")

# No non reactant version available res = pad(x)
r_res = @jit shardy_passes = :to_mhlo_shardings pad(rx)
# @test all(res .== convert(Array, r_res))
end

@testset "DUS" begin
N = min((length(Reactant.devices()) ÷ 2) * 2, 8)

mesh = Sharding.Mesh(reshape(Reactant.devices()[1:N], 2, :), (:x, :y))
sharding = Sharding.NamedSharding(mesh, (:x, :y))

M = 1024

x = reshape(collect(Int, 1:(M * 12)), M, 12)
y = reshape(collect(Int, 1000 * (1:((M - 20) * 12))), M - 20, 12)
rx = Reactant.to_rarray(x; sharding)
ry = Reactant.to_rarray(y; sharding)

hlo = repr(@code_xla shardy_passes = :to_mhlo_shardings dus(rx, ry))
println(hlo)
@test !contains(hlo, "all-to-all")
@test !contains(hlo, "all-gather")
@test contains(hlo, "collective-permute")

dus(x, y)
@jit shardy_passes = :to_mhlo_shardings dus(rx, ry)
@test all(x .== convert(Array, rx))
@test all(y .== convert(Array, ry))
end

@testset "DUS2" begin
N = min((length(Reactant.devices()) ÷ 2) * 2, 8)

mesh = Sharding.Mesh(reshape(Reactant.devices()[1:N], 2, :), (:x, :y))
sharding = Sharding.NamedSharding(mesh, (:x, :y))

M = 1024

x = reshape(collect(Int, 1:(M * 12)), M, 12)
y = reshape(collect(Int, 1000 * (1:((M - 20) * 10))), M - 20, 10)
rx = Reactant.to_rarray(x; sharding)
ry = Reactant.to_rarray(y; sharding)

hlo = repr(@code_xla shardy_passes = :to_mhlo_shardings dus2(rx, ry))
@test !contains(hlo, "all-to-all")
@test !contains(hlo, "all-gather")
@test contains(hlo, "collective-permute")

dus2(x, y)
@jit shardy_passes = :to_mhlo_shardings dus2(rx, ry)
@test all(x .== convert(Array, rx))
@test all(y .== convert(Array, ry))
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
@safetestset "Custom Number Types" include("custom_number_types.jl")
end
@safetestset "Sharding" include("sharding.jl")
@safetestset "Comm Optimization" include("optimize_comm.jl")
@safetestset "Cluster Detection" include("cluster_detector.jl")
@safetestset "Config" include("config.jl")
end
Expand Down
Loading