#=
Multi-node interpolation correctness check.
Builds the actual AtmosphereModel (reusing moist_baroclinic_wave_model) and tests
that Reactant InterpolateArray nearest-neighbor matches the vanilla CPU kernel for
every prognostic field — including face fields with their different source sizes.
Multi-node: sbatch check_multinode.sh
Single-node: julia --project=.. -O0 check.jl --grid-x 64 --grid-y 64 --grid-z 8
=#
# Dates provides `now(UTC)`, used to timestamp the progress log messages.
using Dates
@info "Interpolation check starting" now(UTC)
# Enable debug-level log output for the Reactant and Reactant_jll modules.
ENV["JULIA_DEBUG"] = "Reactant_jll,Reactant"
using GordonBell25
using GordonBell25: factors, is_distributed_env_present
using Breeze.AtmosphereModels: dynamics_density
using Oceananigans
using Oceananigans.Architectures: ReactantState
# Run options; the `grid_*_default` keywords supply the grid extents used when
# none are given. The result is later indexed with the "grid-x", "grid-y" and
# "grid-z" keys.
# NOTE(review): presumably parsed from the command line (see the usage line in
# the file header) — confirm against GordonBell25.
const parsed_args = GordonBell25.parse_baroclinic_instability_args(;
grid_x_default = 64,
grid_y_default = 64,
grid_z_default = 8,
)
# Set the Oceananigans default float type from the parsed options and keep a
# short alias; `FT` is the element type of every array built by this script.
Oceananigans.defaults.FloatType = GordonBell25.float_type_from_args(parsed_args)
FT = Oceananigans.defaults.FloatType
using Oceananigans.Units
using Printf
using Reactant
using Reactant: Sharding, InterpolateArray, InterpolationType
# MPI is loaded and initialised only when `is_distributed_env_present()`
# reports false.
if !is_distributed_env_present()
    using MPI
    MPI.Init()
end

GordonBell25.preamble()
GordonBell25.initialize(; single_gpu_per_process = false)

# Start from the plain Reactant architecture and count the visible devices.
local_arch = ReactantState()
Ndev = length(Reactant.devices())
@show Ndev

# Factor the device count into an Rx × Ry horizontal partition.
Rx, Ry = factors(Ndev)

# With more than one device the architecture is wrapped in a distributed
# partition and the rank is queried; otherwise the script runs as rank 0.
arch = Ndev > 1 ?
    Oceananigans.Distributed(local_arch; partition = Partition(Rx, Ry, 1)) :
    local_arch
rank = Ndev > 1 ? Reactant.Distributed.local_rank() : 0

@info "[$rank] Distributed setup" Ndev Rx Ry
# ─── Grid parameters ─────────────────────────────────────────────────────
# Halo width used for the two horizontal dimensions.
H_halo = 4
Nz = parsed_args["grid-z"]

# Per-device horizontal extents are multiplied by the partition sizes, then
# 2 * H_halo is subtracted to obtain the extents passed to the model.
Tλ, Tφ = parsed_args["grid-x"] * Rx, parsed_args["grid-y"] * Ry
Nλ, Nφ = Tλ - 2H_halo, Tφ - 2H_halo

column_height = 30e3
Δt = 0.5

@info "[$rank] Building model (Nλ=$Nλ, Nφ=$Nφ, Nz=$Nz)" now(UTC)
model = GordonBell25.moist_baroclinic_wave_model(
    arch;
    Nλ = Nλ,
    Nφ = Nφ,
    Nz = Nz,
    H = column_height,
    Δt = Δt,
    halo = (H_halo, H_halo, 4),
    cloud_formation_τ_relax = 10.0,
    initial_conditions_path = nothing,
)
@info "[$rank] Model built" now(UTC)
# ─── Deterministic source arrays mimicking the IC file ────────────────────
# Source grid is smaller than the target → upscaling, same as the real IC path
Nλ_src, Nφ_src, Nz_src = 32, 16, Nz
# Spatially varying patterns: each field gets a distinct structure so that
# any cross-field mixup or index transposition is immediately visible.
"""
    make_src([T,] nx, ny, nz; scale, zscale, offset)

Return an `nx × ny × nz` `Array{T}` whose entry `(i, j, k)` is

    offset + scale * (sin(λ) * cos(φ) + zscale * z)

with `λ = 2π (i - 1) / nx`, `φ = π (j - 1) / ny - π/2` and
`z = (k - 1) / max(nz - 1, 1)` (so `z` is `0` everywhere when `nz == 1`).

# Arguments
- `T`: floating-point element type. When omitted, the script-wide `FT` is used.
- `nx`, `ny`, `nz`: array extents.

# Keywords
- `scale`: amplitude multiplying the whole pattern (default `T(1)`).
- `zscale`: weight of the vertical ramp relative to the horizontal wave
  (default `T(1)`).
- `offset`: constant added to every entry (default `T(0)`).
"""
function make_src(::Type{T}, nx, ny, nz;
                  scale = T(1), zscale = T(1), offset = T(0)) where {T<:AbstractFloat}
    arr = Array{T}(undef, nx, ny, nz)
    # First index innermost: Julia arrays are column-major.
    for k in 1:nz, j in 1:ny, i in 1:nx
        λ = T(2π * (i - 1) / nx)           # zonal wave
        φ = T(π * (j - 1) / ny - π/2)      # meridional: −π/2 … +π/2
        z = T((k - 1) / max(nz - 1, 1))    # vertical: 0 … 1
        arr[i, j, k] = offset + scale * (sin(λ) * cos(φ) + zscale * z)
    end
    return arr
end

# Original entry point: forwards to the typed method with the global `FT`, so
# the element type is a compile-time parameter inside the fill loop instead of
# a non-const global lookup.
make_src(nx, ny, nz; scale = FT(1), zscale = FT(1), offset = FT(0)) =
    make_src(FT, nx, ny, nz; scale, zscale, offset)
# Source array shapes follow the IC file convention:
#   cell-centred fields : (Nλ_src, Nφ_src,     Nz_src)
#   ρu (x-face)         : (Nλ_src, Nφ_src,     Nz_src)      same as centres in the file
#   ρv (y-face)         : (Nλ_src, Nφ_src + 1, Nz_src)
#   ρw (z-face)         : (Nλ_src, Nφ_src,     Nz_src + 1)
center_dims = (Nλ_src, Nφ_src, Nz_src)
yface_dims = (Nλ_src, Nφ_src + 1, Nz_src)
zface_dims = (Nλ_src, Nφ_src, Nz_src + 1)

# Each field gets its own amplitude, vertical weight and offset.
src_ρ = make_src(center_dims...; scale = FT(1.2), zscale = FT(-0.3), offset = FT(1.0))
src_ρu = make_src(center_dims...; scale = FT(20), zscale = FT(5), offset = FT(0))
src_ρv = make_src(yface_dims...; scale = FT(10), zscale = FT(3), offset = FT(0))
src_ρw = make_src(zface_dims...; scale = FT(0.5), zscale = FT(0.2), offset = FT(0))
src_ρθ = make_src(center_dims...; scale = FT(15), zscale = FT(-10), offset = FT(330))
src_ρqv = make_src(center_dims...; scale = FT(1e-4), zscale = FT(-5e-5), offset = FT(3e-4))
src_ρqcl = make_src(center_dims...; scale = FT(1e-6), zscale = FT(2e-6), offset = FT(1e-6))
src_ρqci = make_src(center_dims...; scale = FT(3e-4), zscale = FT(2e-4), offset = FT(5e-4))

# (label, source array, target model field) for every prognostic field checked.
field_pairs = [
    ("ρ", src_ρ, dynamics_density(model.dynamics)),
    ("ρu", src_ρu, model.momentum.ρu),
    ("ρv", src_ρv, model.momentum.ρv),
    ("ρw", src_ρw, model.momentum.ρw),
    ("ρθ", src_ρθ, model.formulation.potential_temperature_density),
    ("ρqᵛ", src_ρqv, model.moisture_density),
    ("ρqᶜˡ", src_ρqcl, model.microphysical_fields[:ρqᶜˡ]),
    ("ρqᶜⁱ", src_ρqci, model.microphysical_fields[:ρqᶜⁱ]),
]
# ═════════════════════════════════════════════════════════════════════════
# Path 1: Reactant InterpolateArray (as in set_moist_baroclinic_wave_from_file!)
# ═════════════════════════════════════════════════════════════════════════
grid = model.grid
halo = Oceananigans.halo_size(grid)
@info "[$rank] Running Reactant InterpolateArray for all fields..." now(UTC)

# Interior (halo-stripped) InterpolateArray result for each field, keyed by label.
reactant_interiors = Dict{String, Array{FT}}()
for (label, source, field) in field_pairs
    # The interpolation target is the field's underlying array, halos included.
    parent_data = Reactant.ancestor(field)
    padded_size = size(parent_data)
    @info "[$rank] InterpolateArray" name=label src=size(source) dst=padded_size halo

    interpolated = InterpolateArray(
        FT.(source), padded_size, parent_data.sharding, InterpolationType.Nearest, halo
    )
    padded = Array(interpolated)

    # Drop the leading halo cells in each dimension and keep exactly the
    # interior extent of the target field.
    interior_size = size(Oceananigans.interior(field))
    interior_ranges = ntuple(d -> (halo[d] + 1):(halo[d] + interior_size[d]), 3)
    reactant_interiors[label] = padded[interior_ranges...]
end
# ═════════════════════════════════════════════════════════════════════════
# Path 2: CPU nearest-neighbor (as in set_moist_baroclinic_wave_from_file_vanilla!)
# ═════════════════════════════════════════════════════════════════════════
@info "[$rank] Running CPU nearest-neighbor for all fields..." now(UTC)

# Reference result for each field, keyed by label.
cpu_interiors = Dict{String, Array{FT}}()
for (label, source, field) in field_pairs
    src_dims = size(source)
    dst_dims = size(Oceananigans.interior(field))

    # For each axis, map every destination index n to the source index
    # ceil(n * N_src / N_dst), clamped into 1:N_src.
    lookup = ntuple(3) do d
        [clamp(ceil(Int, n * src_dims[d] / dst_dims[d]), 1, src_dims[d])
         for n in 1:dst_dims[d]]
    end

    # Indexing with the three lookup vectors gathers the full destination array.
    cpu_interiors[label] = source[lookup...]
end
# ═════════════════════════════════════════════════════════════════════════
# Compare
# ═════════════════════════════════════════════════════════════════════════
@info "[$rank] Comparing results..." now(UTC)

# Becomes false as soon as any field has at least one differing entry.
all_pass = true
for (name, _, _) in field_pairs
    r = reactant_interiors[name]
    c = cpu_interiors[name]
    diff = abs.(r .- c)

    # Count with `!iszero` rather than `> 0`: a NaN in the Reactant result gives
    # a NaN difference, and `NaN > 0` is false, so NaNs would otherwise be
    # counted as matches and the field reported as PASS.
    n_mis = count(!iszero, diff)

    if n_mis > 0
        global all_pass = false
        # NaN propagates through `maximum`, so a NaN mismatch shows up here too.
        max_diff = maximum(diff)
        @warn "[$rank] FAIL $name" size=size(r) max_abs_diff=max_diff n_mismatch=n_mis n_total=length(diff)

        # Print at most the first 10 mismatching entries. `break` inside a
        # multi-range `for` leaves the whole i/j/k nest, not just one level.
        reported = 0
        for k in axes(diff, 3), j in axes(diff, 2), i in axes(diff, 1)
            if !iszero(diff[i, j, k])
                reported += 1
                @info " mismatch" name idx=(i,j,k) reactant=r[i,j,k] cpu=c[i,j,k] delta=diff[i,j,k]
                reported >= 10 && break
            end
        end
    else
        @info "[$rank] PASS $name" size=size(r) range=extrema(r)
    end
end

if all_pass
    @info "[$rank] ALL FIELDS PASS"
else
    @error "[$rank] SOME FIELDS FAILED"
end