Skip to content

Interpolation mismatch Oceananigans #2811

@dkytezab

Description

@dkytezab
#=
Multi-node interpolation correctness check.

Builds the actual AtmosphereModel (reusing moist_baroclinic_wave_model) and tests
that Reactant InterpolateArray nearest-neighbor matches the vanilla CPU kernel for
every prognostic field — including face fields with their different source sizes.

  Multi-node:  sbatch check_multinode.sh
  Single-node: julia --project=.. -O0 check.jl --grid-x 64 --grid-y 64 --grid-z 8
=#

using Dates
@info "Interpolation check starting" now(UTC)

# Enable `@debug` messages from the listed modules (Julia's JULIA_DEBUG convention).
# NOTE(review): set before `using GordonBell25`; confirm whether it must precede loading.
ENV["JULIA_DEBUG"] = "Reactant_jll,Reactant"

using GordonBell25
using GordonBell25: factors, is_distributed_env_present
using Breeze.AtmosphereModels: dynamics_density
using Oceananigans
using Oceananigans.Architectures: ReactantState

# Command-line options; grid-x / grid-y / grid-z default to 64 / 64 / 8.
const parsed_args = GordonBell25.parse_baroclinic_instability_args(;
    grid_x_default = 64,
    grid_y_default = 64,
    grid_z_default = 8,
)

# Oceananigans' default float type comes from the CLI; `FT` is a short alias used
# throughout this script (note: a non-const global).
Oceananigans.defaults.FloatType = GordonBell25.float_type_from_args(parsed_args)
FT = Oceananigans.defaults.FloatType

using Oceananigans.Units
using Printf
using Reactant
using Reactant: Sharding, InterpolateArray, InterpolationType

# Initialize MPI here only when `is_distributed_env_present()` returns false.
# NOTE(review): presumably an externally provided distributed environment handles
# initialization otherwise — confirm against GordonBell25.
if !is_distributed_env_present()
    using MPI
    MPI.Init()
end

GordonBell25.preamble()
GordonBell25.initialize(; single_gpu_per_process=false)

local_arch = ReactantState()
arch = local_arch

# Number of devices Reactant reports. NOTE(review): assumed to be the global device
# count across all processes, not per-process — confirm.
Ndev = length(Reactant.devices())
@show Ndev

# Split the devices into an Rx × Ry horizontal layout (presumably Rx * Ry == Ndev).
Rx, Ry = factors(Ndev)

# With more than one device, wrap the architecture in a distributed one partitioned
# in x and y only (no z split). `rank` stays 0 on a single device and tags every
# log message below.
rank = 0
if Ndev > 1
    arch = Oceananigans.Distributed(arch; partition = Partition(Rx, Ry, 1))
    rank = Reactant.Distributed.local_rank()
end

@info "[$rank] Distributed setup" Ndev Rx Ry

# ─── Grid parameters ─────────────────────────────────────────────────────
# Halo width in λ and φ (also passed as the model's horizontal halo below).
H_halo = 4
# Total horizontal extent: `grid-x` / `grid-y` points per partition times the number
# of partitions in that direction.
Tλ = parsed_args["grid-x"] * Rx
Tφ = parsed_args["grid-y"] * Ry
Nz = parsed_args["grid-z"]
# Interior size = total extent minus one halo on each side, so interior + halos spans
# exactly Tλ × Tφ. NOTE(review): presumably chosen so the halo-padded arrays divide
# evenly across the Rx × Ry partition — confirm.
Nλ = Tλ - 2H_halo
Nφ = Tφ - 2H_halo

# Reject sizes that leave no interior points before building the model.
Nλ > 0 && Nφ > 0 ||
    throw(ArgumentError("horizontal grid too small for halo $H_halo: " *
                        "Nλ=$Nλ, Nφ=$Nφ (grid-x*Rx=$Tλ, grid-y*Ry=$Tφ)"))

# Passed as the model's `H`. NOTE(review): presumably the domain top in meters.
column_height = 30e3
Δt = 0.5

@info "[$rank] Building model (Nλ=$Nλ, Nφ=$Nφ, Nz=$Nz)" now(UTC)

model = GordonBell25.moist_baroclinic_wave_model(arch; Nλ, Nφ, Nz,
    H = column_height, Δt,
    halo = (H_halo, H_halo, 4),
    cloud_formation_τ_relax = 10.0,
    initial_conditions_path = nothing)

@info "[$rank] Model built" now(UTC)

# ─── Deterministic source arrays mimicking the IC file ────────────────────
# The source grid is coarser than the model grid, so interpolation upscales —
# the same direction as the real initial-condition path.
Nλ_src = 32
Nφ_src = 16
Nz_src = Nz

# Spatially varying patterns: each field gets a distinct structure so that
# any cross-field mixup or index transposition is immediately visible.
"""
    make_src(nx, ny, nz; T=FT, scale=T(1), zscale=T(1), offset=T(0))

Return an `nx × ny × nz` `Array{T}` filled with the deterministic pattern

    offset + scale * (sin(λ) * cos(φ) + zscale * z)

where, for 1-based indices `(i, j, k)`:
- `λ = 2π (i - 1) / nx`       (zonal, `0 … 2π - 2π/nx`)
- `φ = π (j - 1) / ny - π/2`  (meridional, `-π/2 … π/2 - π/ny`)
- `z = (k - 1) / max(nz - 1, 1)` (vertical, `0 … 1`; always `0` when `nz == 1`)

`T` defaults to the script-wide float type `FT`. Passing `T` (and the other keywords)
explicitly lets the function run without that global.
"""
function make_src(nx, ny, nz; T=FT, scale=T(1), zscale=T(1), offset=T(0))
    arr = Array{T}(undef, nx, ny, nz)
    for k in 1:nz, j in 1:ny, i in 1:nx
        λ = T(2π * (i - 1) / nx)           # zonal wave
        φ = T(π * (j - 1) / ny - π/2)      # meridional: −π/2 … π/2 − π/ny
        z = T((k - 1) / max(nz - 1, 1))    # vertical: 0 … 1
        arr[i, j, k] = offset + scale * (sin(λ) * cos(φ) + zscale * z)
    end
    return arr
end

# Source array sizes follow the IC-file convention:
#   center fields:  (Nλ_src, Nφ_src,   Nz_src  )
#   x-face (ρu):    (Nλ_src, Nφ_src,   Nz_src  )  ← stored like a center field in the file
#   y-face (ρv):    (Nλ_src, Nφ_src+1, Nz_src  )
#   z-face (ρw):    (Nλ_src, Nφ_src,   Nz_src+1)
# Each field gets its own offset/scale/zscale so a cross-field mixup shows up.
src_ρ, src_ρu, src_ρv, src_ρw, src_ρθ, src_ρqv, src_ρqcl, src_ρqci = let
    center_dims = (Nλ_src, Nφ_src, Nz_src)
    yface_dims  = (Nλ_src, Nφ_src + 1, Nz_src)
    zface_dims  = (Nλ_src, Nφ_src, Nz_src + 1)
    (
        make_src(center_dims...; offset=FT(1.0),  scale=FT(1.2),  zscale=FT(-0.3)),
        make_src(center_dims...; offset=FT(0),    scale=FT(20),   zscale=FT(5)),
        make_src(yface_dims...;  offset=FT(0),    scale=FT(10),   zscale=FT(3)),
        make_src(zface_dims...;  offset=FT(0),    scale=FT(0.5),  zscale=FT(0.2)),
        make_src(center_dims...; offset=FT(330),  scale=FT(15),   zscale=FT(-10)),
        make_src(center_dims...; offset=FT(3e-4), scale=FT(1e-4), zscale=FT(-5e-5)),
        make_src(center_dims...; offset=FT(1e-6), scale=FT(1e-6), zscale=FT(2e-6)),
        make_src(center_dims...; offset=FT(5e-4), scale=FT(3e-4), zscale=FT(2e-4)),
    )
end

# (label, source array, target model field) triples driven through both paths.
field_pairs = let momentum = model.momentum, micro = model.microphysical_fields
    [
        ("ρ",    src_ρ,    dynamics_density(model.dynamics)),
        ("ρu",   src_ρu,   momentum.ρu),
        ("ρv",   src_ρv,   momentum.ρv),
        ("ρw",   src_ρw,   momentum.ρw),
        ("ρθ",   src_ρθ,   model.formulation.potential_temperature_density),
        ("ρqᵛ",  src_ρqv,  model.moisture_density),
        ("ρqᶜˡ", src_ρqcl, micro[:ρqᶜˡ]),
        ("ρqᶜⁱ", src_ρqci, micro[:ρqᶜⁱ]),
    ]
end

# ═════════════════════════════════════════════════════════════════════════
# Path 1: Reactant InterpolateArray (as in set_moist_baroclinic_wave_from_file!)
# ═════════════════════════════════════════════════════════════════════════

grid = model.grid
halo = Oceananigans.halo_size(grid)

@info "[$rank] Running Reactant InterpolateArray for all fields..." now(UTC)

reactant_interiors = Dict{String, Array{FT}}()

for (name, src_array, target_field) in field_pairs
    # Halo-inclusive backing array of the target field; its size and sharding
    # define the interpolation target.
    dst = Reactant.ancestor(target_field)
    dst_size = size(dst)

    @info "[$rank] InterpolateArray" name src=size(src_array) dst=dst_size halo

    gathered = Array(InterpolateArray(FT.(src_array), dst_size, dst.sharding,
                                      InterpolationType.Nearest, halo))

    # Keep only the interior: skip the leading halo in every dimension.
    Hx, Hy, Hz = halo
    Nx, Ny, Nzf = size(Oceananigans.interior(target_field))
    reactant_interiors[name] = gathered[Hx .+ (1:Nx), Hy .+ (1:Ny), Hz .+ (1:Nzf)]
end

# ═════════════════════════════════════════════════════════════════════════
# Path 2: CPU nearest-neighbor (as in set_moist_baroclinic_wave_from_file_vanilla!)
# ═════════════════════════════════════════════════════════════════════════

@info "[$rank] Running CPU nearest-neighbor for all fields..." now(UTC)

"""
    nearest_source_index(i, n_src, n_dst)

Return the 1-based source index sampled for destination index `i` when `n_src`
points are mapped onto `n_dst` points: `cld(i * n_src, n_dst)` clamped to
`1:n_src`. This equals `ceil(i * n_src / n_dst)` but uses exact integer arithmetic.
"""
nearest_source_index(i::Integer, n_src::Integer, n_dst::Integer) =
    clamp(cld(i * n_src, n_dst), 1, n_src)

"""
    cpu_nearest_interpolate(src, dst_size)

Nearest-neighbor resample the 3-D array `src` onto a new array of size `dst_size`.
Each dimension's source index is chosen independently with `nearest_source_index`.
The result has the same element type as `src`.
"""
function cpu_nearest_interpolate(src::AbstractArray{T,3},
                                 dst_size::NTuple{3,Integer}) where {T}
    nx_src, ny_src, nz_src = size(src)
    nx_dst, ny_dst, nz_dst = dst_size
    dst = zeros(T, nx_dst, ny_dst, nz_dst)
    # `i` innermost to walk memory in column-major order.
    for k in 1:nz_dst, j in 1:ny_dst, i in 1:nx_dst
        i′ = nearest_source_index(i, nx_src, nx_dst)
        j′ = nearest_source_index(j, ny_src, ny_dst)
        k′ = nearest_source_index(k, nz_src, nz_dst)
        dst[i, j, k] = src[i′, j′, k′]
    end
    return dst
end

cpu_interiors = Dict{String, Array{FT}}()

for (name, src_array, target_field) in field_pairs
    dst_size = size(Oceananigans.interior(target_field))
    cpu_interiors[name] = cpu_nearest_interpolate(src_array, dst_size)
end

# ═════════════════════════════════════════════════════════════════════════
# Compare
# ═════════════════════════════════════════════════════════════════════════

@info "[$rank] Comparing results..." now(UTC)

all_pass = true

for (name, _, _) in field_pairs
    r = reactant_interiors[name]
    c = cpu_interiors[name]

    # A shape disagreement is itself a failure: report it instead of letting the
    # element-wise comparison below throw a DimensionMismatch.
    if size(r) != size(c)
        global all_pass = false
        @warn "[$rank] FAIL $name" reactant_size=size(r) cpu_size=size(c)
        continue
    end

    diff = abs.(r .- c)
    # `!=` rather than `diff .> 0`: `abs(NaN - c) > 0` is false, so the old test
    # counted NaNs from the Reactant path as matches.
    mismatch = r .!= c
    n_mis = count(mismatch)
    max_diff = maximum(diff; init=zero(eltype(diff)))   # NaN if any NaN is present

    if n_mis > 0
        global all_pass = false
        @warn "[$rank] FAIL $name" size=size(r) max_abs_diff=max_diff n_mismatch=n_mis n_total=length(diff)
        # Report the first 10 mismatches in column-major (i fastest) order.
        for I in Iterators.take(findall(mismatch), 10)
            i, j, k = Tuple(I)
            @info "  mismatch" name idx=(i,j,k) reactant=r[i,j,k] cpu=c[i,j,k] delta=diff[i,j,k]
        end
    else
        @info "[$rank] PASS $name" size=size(r) range=extrema(r)
    end
end

if all_pass
    @info "[$rank] ALL FIELDS PASS"
else
    @error "[$rank] SOME FIELDS FAILED"
end
┌ Info:   mismatch
│   name = "ρqᶜⁱ"
│   idx = (9, 1, 1)
│   reactant = 0.0006500343f0
│   cpu = 0.0005f0
└   delta = 0.00015003426f0
┌ Info:   mismatch
│   name = "ρθ"
│   idx = (5, 1, 1)
│   reactant = 251.7857f0
│   cpu = 330.0f0
└   delta = 78.214294f0
┌ Info:   mismatch
│   name = "ρqᶜⁱ"
│   idx = (10, 1, 1)
│   reactant = 0.0006500343f0
│   cpu = 0.0005f0
└   delta = 0.00015003426f0
┌ Info:   mismatch
│   name = "ρθ"
│   idx = (6, 1, 1)
│   reactant = 251.7857f0
│   cpu = 330.0f0
└   delta = 78.214294f0
┌ Info:   mismatch
│   name = "ρθ"
│   idx = (7, 1, 1)
│   reactant = 251.7857f0
│   cpu = 330.0f0
└   delta = 78.214294f0
┌ Info:   mismatch
│   name = "ρθ"
│   idx = (8, 1, 1)
│   reactant = 251.7857f0
│   cpu = 330.0f0
└   delta = 78.214294f0
...

Metadata

Assignees

Labels

No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions