Skip to content

Commit 6c4e945

Browse files
committed
issue with ext loading
1 parent 7d9ee6a commit 6c4e945

File tree

3 files changed

+231
-53
lines changed

3 files changed

+231
-53
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@ UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b"
3434
[weakdeps]
3535
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3636
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
37+
ConservativeRegridding = "b4b54eb9-2b73-55e3-98c1-00e3baba4ed5"
3738

3839
[extensions]
3940
ClimaCoreCUDAExt = "CUDA"
4041
KrylovExt = "Krylov"
42+
ClimaCoreConservativeRegriddingExt = "ConservativeRegridding"
4143

4244
[compat]
4345
Adapt = "3.2.0, 4"
@@ -50,6 +52,7 @@ BlockArrays = "1"
5052
CUDA = "5.5"
5153
ClimaComms = "0.6.2"
5254
ClimaCoreMakie = "0.4.6"
55+
ConservativeRegridding = "0.1"
5356
CountFlops = "0.1"
5457
CubedSphere = "0.2, 0.3"
5558
DataStructures = "0.18.13, 0.19"
@@ -90,6 +93,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
9093
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
9194
AssociatedLegendrePolynomials = "2119f1ac-fb78-50f5-8cc0-dda848ebdb19"
9295
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
96+
ConservativeRegridding = "b4b54eb9-2b73-55e3-98c1-00e3baba4ed5"
9397
CountFlops = "1db9610d-79e1-487a-8d40-77f3295c7593"
9498
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
9599
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
@@ -105,4 +109,4 @@ TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
105109
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
106110

107111
[targets]
108-
test = ["Aqua", "ArgParse", "AssociatedLegendrePolynomials", "BenchmarkTools", "CountFlops", "Dates", "FastBroadcast", "Krylov", "JET", "LazyBroadcast", "Logging", "PrettyTables", "Random", "SafeTestsets", "StatsBase", "TerminalLoggers", "Test"]
112+
test = ["Aqua", "ArgParse", "AssociatedLegendrePolynomials", "BenchmarkTools", "ConservativeRegridding", "CountFlops", "Dates", "FastBroadcast", "Krylov", "JET", "LazyBroadcast", "Logging", "PrettyTables", "Random", "SafeTestsets", "StatsBase", "TerminalLoggers", "Test"]

ext/ClimaCoreConservativeRegriddingExt.jl

Lines changed: 111 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
module ClimaCoreConservativeRegriddingExt
22

3+
using ConservativeRegridding, ClimaCore
4+
import ConservativeRegridding: Regridder, regrid!
5+
import ClimaCore: Spaces, Meshes, Quadratures, Fields, RecursiveApply, Spaces
6+
7+
export get_element_vertices, integrate_each_element, get_value_per_element!,
8+
set_value_per_element!, Regridder, regrid!
9+
310
"""
411
get_element_vertices(space::SpectralElementSpace2D)
512
@@ -32,7 +39,7 @@ function get_element_vertices(space)
3239
# Put each polygon into a vector, with the first coordinate pair repeated at the end
3340
vertices = collect(Iterators.partition(vertex_coords, 5))
3441

35-
# Check for zero area polygons
42+
# Check for zero area polygons (all latitude or longitude values are the same)
3643
for polygon in vertices
3744
if allequal(first.(polygon)) || allequal(last.(polygon))
3845
@error "Zero area polygon found in vertices" polygon
@@ -82,9 +89,13 @@ to the number of elements in the space.
8289
Here `ones_field` is a field on the same space as `field` with all
8390
values set to 1.
8491
"""
85-
function get_value_per_element!(value_per_element, field, ones_field)
86-
integral_each_element = Remapping.integrate_each_element(field)
87-
area_each_element = Remapping.integrate_each_element(ones_field)
92+
function get_value_per_element!(
93+
value_per_element,
94+
field,
95+
ones_field,
96+
)
97+
integral_each_element = integrate_each_element(field)
98+
area_each_element = integrate_each_element(ones_field)
8899
value_per_element .= integral_each_element ./ area_each_element
89100
return nothing
90101
end
@@ -117,5 +128,101 @@ function set_value_per_element!(field, value_per_element)
117128
return field
118129
end
119130

131+
"""
132+
Regridder(dst_space, src_space; kwargs...)
133+
134+
Create a regridder between two ClimaCore Spaces.
135+
This works by finding the vertices of the elements of the source and
136+
destination spaces, and then computing the areas of their intersections.
137+
138+
This is currently only defined for 2D spaces, but could be extended to
139+
3D spaces.
140+
141+
This function is intended to be used with the finite volume approximation
142+
of the ClimaCore spectral element space.
143+
"""
144+
function Regridder(
145+
dst_space::Spaces.SpectralElementSpace2D,
146+
src_space::Spaces.SpectralElementSpace2D;
147+
kwargs...,
148+
)
149+
dst_vertices = get_element_vertices(dst_space)
150+
src_vertices = get_element_vertices(src_space)
151+
return ConservativeRegridding.Regridder(dst_vertices, src_vertices; kwargs...)
152+
end
153+
Regridder(
154+
dst_field::Fields.Field,
155+
src_field::Fields.Field;
156+
kwargs...,
157+
) = Regridder(axes(dst_field), axes(src_field); kwargs...)
158+
159+
"""
160+
regrid!(dst_field, regridder_tuple, src_field)
161+
162+
Perform conservative regridding from `src_field` to `dst_field` using a
163+
Regridder object and pre-allocated buffers.
164+
165+
The `regridder_tuple` should be a NamedTuple containing:
166+
- `regridder`: The ConservativeRegridding.Regridder object
167+
- `value_per_element_src`: Pre-allocated buffer for source values
168+
- `value_per_element_dst`: Pre-allocated buffer for destination values
169+
- `ones_src`: Pre-allocated field of ones on the source space
170+
```
171+
"""
172+
function regrid!(
173+
dst_field::Fields.Field,
174+
regridder_tuple::NamedTuple,
175+
src_field::Fields.Field,
176+
)
177+
@assert eltype(dst_field) isa Number && eltype(src_field) isa Number "Regridding is only supported for scalar fields"
178+
@assert eltype(dst_field) == eltype(src_field) "Source and destination fields must have the same element type"
179+
180+
# Use pre-allocated buffers from the regridder tuple
181+
(; value_per_element_src, value_per_element_dst, ones_src, regridder) = regridder_tuple
182+
183+
# Get one value per element in the source field, equal to the quadrature-weighted average of the
184+
# values at nodes of the element
185+
get_value_per_element!(value_per_element_src, src_field, ones_src)
186+
187+
# Perform the regridding
188+
ConservativeRegridding.regrid!(value_per_element_dst, regridder, value_per_element_src)
189+
190+
# Now that we have our regridded vector, put it onto a field on the second space
191+
set_value_per_element!(dst_field, value_per_element_dst)
192+
return nothing
193+
end
194+
195+
"""
196+
regrid!(dst_field, regridder, src_field)
197+
198+
Perform conservative regridding from `src_field` to `dst_field` using a
199+
Regridder object.
200+
201+
This is a convenience function that allocates the buffers for you.
202+
Note that this is not efficient for repeated regridding with the same regridder,
203+
but it may be helpful for one-off regriddings or testing/debugging.
204+
"""
205+
function regrid!(
206+
dst_field::Fields.Field,
207+
regridder::ConservativeRegridding.Regridder,
208+
src_field::Fields.Field,
209+
)
210+
# Allocate space for the buffers
211+
value_per_element_src =
212+
zeros(Float64, Meshes.nelements(axes(src_field).grid.topology.mesh))
213+
value_per_element_dst =
214+
zeros(Float64, Meshes.nelements(axes(dst_field).grid.topology.mesh))
215+
ones_src = ones(axes(src_field))
216+
regridder_tuple = (;
217+
regridder,
218+
value_per_element_src,
219+
value_per_element_dst,
220+
ones_src,
221+
)
222+
223+
# Perform the regridding
224+
regrid!(dst_field, regridder_tuple, src_field)
225+
return nothing
226+
end
120227

121228
end
Lines changed: 115 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,128 @@
1+
using Test
2+
using ClimaCore
13
using ClimaCore:
24
CommonSpaces, Remapping, Fields, Spaces, RecursiveApply, Meshes, Quadratures
35
using ConservativeRegridding
46

5-
space1 = CommonSpaces.CubedSphereSpace(;
7+
# Extensions load lazily in Julia - they need to be explicitly triggered
8+
# This ensures the extension is loaded and its methods are available
9+
const ClimaConservativeRegrid =
10+
Base.get_extension(ClimaCore, :ClimaCoreConservativeRegriddingExt)
11+
@assert !isnothing(ClimaConservativeRegrid) "ClimaCoreConservativeRegriddingExt extension not loaded"
12+
13+
const src_space = CommonSpaces.CubedSphereSpace(;
614
radius = 10,
715
n_quad_points = 3,
816
h_elem = 8,
917
)
10-
space2 = CommonSpaces.CubedSphereSpace(;
18+
const dst_space = CommonSpaces.CubedSphereSpace(;
1119
radius = 10,
1220
n_quad_points = 4,
1321
h_elem = 6,
1422
)
1523

16-
vertices1 = Remapping.get_element_vertices(space1)
17-
vertices2 = Remapping.get_element_vertices(space2)
18-
19-
# Pass in destination vertices first, source vertices second
20-
# TODO open issue in CR.jl about ordering of inputs source/dest
21-
regridder_1_to_2 = ConservativeRegridding.Regridder(vertices2, vertices1)
22-
regridder_2_to_1 = ConservativeRegridding.Regridder(vertices1, vertices2)
23-
24-
# Define a field on the first space, to use as our source field
25-
field1 = Fields.coordinate_field(space1).lat
26-
ones_field1 = Fields.ones(space1)
27-
28-
# Check that integrating over each element and summing gives the same result as integrating over the whole domain
29-
@assert isapprox(sum(Remapping.integrate_each_element(field1)), sum(field1), atol = 1e-12)
30-
# Check that integrating 1 over each element and summing gives the same result as integrating 1 over the whole domain
31-
@assert sum(Remapping.integrate_each_element(ones_field1)) sum(ones_field1)
32-
33-
# Get one value per element in the field, equal to the average of the values at nodes of the element
34-
value_per_element1 = zeros(Float64, Meshes.nelements(space1.grid.topology.mesh))
35-
Remapping.get_value_per_element!(value_per_element1, field1, ones_field1)
36-
37-
# Allocate a vector with length equal to the number of elements in the target space
38-
value_per_element2 = zeros(Float64, Meshes.nelements(space2.grid.topology.mesh))
39-
ConservativeRegridding.regrid!(value_per_element2, regridder_1_to_2, value_per_element1)
40-
41-
# Now that we have our regridded vector, put it onto a field on the second space
42-
field2 = Fields.zeros(space2)
43-
Remapping.set_value_per_element!(field2, value_per_element2)
44-
field1_one_value_per_element = Fields.zeros(space1)
45-
Remapping.set_value_per_element!(field1_one_value_per_element, value_per_element1)
46-
47-
# # Plot the fields
48-
# using ClimaCoreMakie
49-
# using GLMakie
50-
# fig = ClimaCoreMakie.fieldheatmap(field1)
51-
# save("field1.png", fig)
52-
# fig = ClimaCoreMakie.fieldheatmap(field1_one_value_per_element)
53-
# save("field1_one_value_per_element.png", fig)
54-
# fig = ClimaCoreMakie.fieldheatmap(field2)
55-
# save("field2.png", fig)
56-
57-
# Check the conservation error
58-
abs_error = abs(sum(field1) - sum(field2))
59-
@assert abs_error < 1e-12
60-
abs_error_one_value_per_element = abs(sum(field1_one_value_per_element) - sum(field2))
61-
@assert abs_error_one_value_per_element < 2e-12
24+
@testset "test get_element_vertices" begin
25+
vertices = ClimaConservativeRegrid.get_element_vertices(src_space)
26+
@test length(vertices) == Meshes.nelements(src_space.grid.topology.mesh)
27+
28+
# Check that there are 5 vertices per element (quadrilaterals with repeated first vertex)
29+
@test all(length(vertex) == 5 for vertex in vertices)
30+
@test all(vertex[1] == vertex[5] for vertex in vertices)
31+
end
32+
33+
@testset "test integrate_each_element" begin
34+
# Test integrating a field of ones
35+
ones_field = Fields.ones(src_space)
36+
integral_each_element =
37+
ClimaConservativeRegrid.integrate_each_element(ones_field)
38+
@test isapprox(sum(integral_each_element), sum(ones_field), atol = 1e-11)
39+
40+
# Test integrating a field of latitude
41+
field = Fields.coordinate_field(src_space).lat
42+
integral_each_element = ClimaConservativeRegrid.integrate_each_element(field)
43+
@test isapprox(sum(integral_each_element), sum(field), atol = 1e-12)
44+
end
45+
46+
@testset "test get_value_per_element!" begin
47+
field = Fields.coordinate_field(src_space).lat
48+
ones_field = Fields.ones(src_space)
49+
value_per_element = zeros(Float64, Meshes.nelements(src_space.grid.topology.mesh))
50+
ClimaConservativeRegrid.get_value_per_element!(
51+
value_per_element,
52+
field,
53+
ones_field,
54+
)
55+
56+
@test isapprox(sum(value_per_element), sum(field), atol = 1e-12)
57+
end
58+
59+
@testset "test set_value_per_element!" begin
60+
field = Fields.coordinate_field(src_space).lat
61+
value_per_element = ones(Float64, Meshes.nelements(src_space.grid.topology.mesh))
62+
ClimaConservativeRegrid.set_value_per_element!(field, value_per_element)
63+
64+
@test isapprox(sum(field), sum(value_per_element), atol = 1e-12)
65+
@test all(field .== 1.0)
66+
end
67+
68+
@testset "test Regridder constructor" begin
69+
regridder = ClimaConservativeRegrid.Regridder(dst_space, src_space)
70+
@test regridder isa ConservativeRegridding.Regridder
71+
end
72+
73+
@testset "test regrid!" begin
74+
src_field = Fields.coordinate_field(src_space).lat
75+
dst_field = Fields.zeros(dst_space)
76+
77+
# Test regrid! without pre-allocated buffers
78+
regridder = ClimaConservativeRegrid.Regridder(dst_space, src_space)
79+
ClimaConservativeRegrid.regrid!(dst_field, regridder, src_field)
80+
@test isapprox(sum(dst_field), sum(src_field), atol = 1e-12)
81+
82+
# Test regrid! with pre-allocated buffers
83+
value_per_element_src = zeros(Float64, Meshes.nelements(src_space.grid.topology.mesh))
84+
value_per_element_dst = zeros(Float64, Meshes.nelements(dst_space.grid.topology.mesh))
85+
ones_src = ones(src_space)
86+
regridder_tuple = (;
87+
regridder,
88+
value_per_element_src,
89+
value_per_element_dst,
90+
ones_src,
91+
)
92+
ClimaConservativeRegrid.regrid!(dst_field, regridder_tuple, src_field)
93+
@test isapprox(sum(dst_field), sum(src_field), atol = 1e-12)
94+
end
95+
96+
@testset "test regrid! onto the same space" begin
97+
src_field = Fields.coordinate_field(src_space).lat
98+
dst_field = Fields.zeros(src_space)
99+
100+
# Test regrid! without pre-allocated buffers
101+
regridder = ClimaConservativeRegrid.Regridder(src_space, src_space)
102+
ClimaConservativeRegrid.regrid!(dst_field, regridder, src_field)
103+
@test isapprox(sum(dst_field), sum(src_field), atol = 1e-12)
104+
end
105+
106+
@testset "test regrid! of a constant field" begin
107+
src_field = ones(src_space)
108+
dst_field = Fields.zeros(src_space)
109+
110+
# Test regrid! without pre-allocated buffers
111+
regridder = ClimaConservativeRegrid.Regridder(src_space, src_space)
112+
ClimaConservativeRegrid.regrid!(dst_field, regridder, src_field)
113+
@test isapprox(sum(dst_field), sum(src_field), atol = 1e-12)
114+
end
115+
116+
@testset "test regrid! from source to destination and back" begin
117+
src_field = Fields.coordinate_field(src_space).lat
118+
dst_field = Fields.zeros(dst_space)
119+
120+
# Regrid from source to destination
121+
regridder = ClimaConservativeRegrid.Regridder(dst_space, src_space)
122+
ClimaConservativeRegrid.regrid!(dst_field, regridder, src_field)
123+
@test isapprox(sum(dst_field), sum(src_field), atol = 1e-12)
124+
125+
# Regrid from destination to source using the transpose of the regridder
126+
ClimaConservativeRegrid.regrid!(src_field, transpose(regridder), dst_field)
127+
@test isapprox(sum(src_field), sum(dst_field), atol = 1e-12)
128+
end

0 commit comments

Comments
 (0)