This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit e248a56

ayushinav and avik-pal authored
refactor: structs for NeuralOperators (#23)
* structs for Neural Operators
* bug fix
* bug fix
* fixing fno struct
* dispatch for TrainState
* removing structs for compact layers
* DeepONet : Compact => Container layers
* deeponet test bug fix
* deeponet fixes
* FNO : Compact => Container layers
* OperatorKernel : Compact => Container layers
* dispatch fixes from review
* fix: missing specializations
* fix: access AbstractExplicitContainerLayer from LuxCore
* fix: stop reexporting Lux
* test: AMDGPU tests are no longer broken

---------

Co-authored-by: Avik Pal <[email protected]>
1 parent 2f2de89 commit e248a56

10 files changed, +129 -104 lines changed
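The same pattern is applied throughout the commit: the Lux `@compact` closures are replaced by `@concrete` structs that subtype `AbstractExplicitContainerLayer` and define an explicit `(layer)(x, ps, st)` call which threads parameters and states by hand. A minimal sketch of that pattern follows; the `TwoBranch` layer is hypothetical and only illustrates the structure used below for `DeepONet`, `FourierNeuralOperator`, and `OperatorKernel`.

```julia
using ConcreteStructs: @concrete
using LuxCore: AbstractExplicitContainerLayer

# Hypothetical two-branch container layer illustrating the refactor pattern.
@concrete struct TwoBranch <: AbstractExplicitContainerLayer{(:left, :right)}
    left
    right
end

function (m::TwoBranch)(x, ps, st::NamedTuple)
    # Each sub-layer is called explicitly with its own parameters and state,
    # and the updated states are re-packed into a NamedTuple.
    y_left, st_left = m.left(x, ps.left, st.left)
    y_right, st_right = m.right(x, ps.right, st.right)
    return y_left .+ y_right, (left=st_left, right=st_right)
end
```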

Project.toml

+3 -5

@@ -13,19 +13,17 @@ LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
 LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
 
 [compat]
 ArgCheck = "2.3.0"
 ChainRulesCore = "1.24.0"
 ConcreteStructs = "0.2.3"
 FFTW = "1.8.0"
-Lux = "0.5.62"
-LuxCore = "0.1.21"
-LuxLib = "0.3.40"
+Lux = "0.5.64"
+LuxCore = "0.1.24"
+LuxLib = "0.3.42"
 NNlib = "0.9.21"
 Random = "1.10"
-Reexport = "1.2.2"
 WeightInitializers = "1"
 julia = "1.10"

src/NeuralOperators.jl

+1 -4

@@ -5,16 +5,13 @@ using ChainRulesCore: ChainRulesCore, NoTangent
 using ConcreteStructs: @concrete
 using FFTW: FFTW, irfft, rfft
 using Lux
-using LuxCore: LuxCore, AbstractExplicitLayer
+using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
 using LuxLib: batched_matmul
 using NNlib: NNlib, batched_adjoint
 using Random: Random, AbstractRNG
-using Reexport: @reexport
 
 const CRC = ChainRulesCore
 
-@reexport using Lux
-
 include("utils.jl")
 include("transform.jl")
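Since `Lux` is no longer reexported, downstream code has to bring it into scope explicitly. A minimal sketch of the resulting usage, with the layer sizes taken from the docstring examples in this commit:

```julia
# Before this commit `using NeuralOperators` also brought Lux's names
# (Chain, Dense, gelu, ...) into scope via @reexport; now Lux must be
# loaded explicitly alongside the package.
using Lux, NeuralOperators, Random

branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16))
trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16))
deeponet = DeepONet(branch_net, trunk_net)

ps, st = Lux.setup(Xoshiro(0), deeponet)
```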

src/deeponet.jl

+56 -49

@@ -1,16 +1,16 @@
 """
-    DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16),
-        branch_activation = identity, trunk_activation = identity)
+    DeepONet(branch, trunk, additional)
 
-Constructs a DeepONet composed of Dense layers. Make sure the last node of `branch` and
-`trunk` are same.
+Constructs a DeepONet from a `branch` and `trunk` architectures. Make sure that both the
+nets output should have the same first dimension.
 
-## Keyword arguments:
+## Arguments
+
+  - `branch`: `Lux` network to be used as branch net.
+  - `trunk`: `Lux` network to be used as trunk net.
+
+## Keyword Arguments
 
-  - `branch`: Tuple of integers containing the number of nodes in each layer for branch net
-  - `trunk`: Tuple of integers containing the number of nodes in each layer for trunk net
-  - `branch_activation`: activation function for branch net
-  - `trunk_activation`: activation function for trunk net
   - `additional`: `Lux` network to pass the output of DeepONet, to include additional operations
     for embeddings, defaults to `nothing`
 
@@ -23,7 +23,11 @@ operators", doi: https://arxiv.org/abs/1910.03193
 ## Example
 
 ```jldoctest
-julia> deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16));
+julia> branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16));
+
+julia> trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));
+
+julia> deeponet = DeepONet(branch_net, trunk_net);
 
 julia> ps, st = Lux.setup(Xoshiro(), deeponet);
 
@@ -35,37 +39,27 @@ julia> size(first(deeponet((u, y), ps, st)))
 (10, 5)
 ```
 """
-function DeepONet(;
-        branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), branch_activation=identity,
-        trunk_activation=identity, additional=nothing)
-
-    # checks for last dimension size
-    @argcheck branch[end]==trunk[end] "Branch and Trunk net must share the same amount of \
-                                       nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
-                                       work."
-
-    branch_net = Chain([Dense(branch[i] => branch[i + 1], branch_activation)
-                        for i in 1:(length(branch) - 1)]...)
-
-    trunk_net = Chain([Dense(trunk[i] => trunk[i + 1], trunk_activation)
-                       for i in 1:(length(trunk) - 1)]...)
-
-    return DeepONet(branch_net, trunk_net; additional)
+@concrete struct DeepONet <: AbstractExplicitContainerLayer{(:branch, :trunk, :additional)}
+    branch
+    trunk
+    additional
 end
 
-"""
-    DeepONet(branch, trunk)
+DeepONet(branch, trunk) = DeepONet(branch, trunk, NoOpLayer())
 
-Constructs a DeepONet from a `branch` and `trunk` architectures. Make sure that both the
-nets output should have the same first dimension.
-
-## Arguments
+"""
+    DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16),
+        branch_activation = identity, trunk_activation = identity)
 
-  - `branch`: `Lux` network to be used as branch net.
-  - `trunk`: `Lux` network to be used as trunk net.
+Constructs a DeepONet composed of Dense layers. Make sure the last node of `branch` and
+`trunk` are same.
 
-## Keyword Arguments
+## Keyword arguments:
 
+  - `branch`: Tuple of integers containing the number of nodes in each layer for branch net
+  - `trunk`: Tuple of integers containing the number of nodes in each layer for trunk net
+  - `branch_activation`: activation function for branch net
+  - `trunk_activation`: activation function for trunk net
   - `additional`: `Lux` network to pass the output of DeepONet, to include additional operations
     for embeddings, defaults to `nothing`
 
@@ -78,11 +72,7 @@ operators", doi: https://arxiv.org/abs/1910.03193
 ## Example
 
 ```jldoctest
-julia> branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16));
-
-julia> trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));
-
-julia> deeponet = DeepONet(branch_net, trunk_net);
+julia> deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16));
 
 julia> ps, st = Lux.setup(Xoshiro(), deeponet);
 
@@ -94,15 +84,32 @@ julia> size(first(deeponet((u, y), ps, st)))
 (10, 5)
 ```
 """
-function DeepONet(branch::L1, trunk::L2; additional=nothing) where {L1, L2}
-    return @compact(; branch, trunk, additional, dispatch=:DeepONet) do (u, y)
-        t = trunk(y)   # p x N x nb
-        b = branch(u)  # p x u_size... x nb
+function DeepONet(;
+        branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), branch_activation=identity,
+        trunk_activation=identity, additional=NoOpLayer())
+
+    # checks for last dimension size
+    @argcheck branch[end]==trunk[end] "Branch and Trunk net must share the same amount of \
+                                       nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
+                                       work."
+
+    branch_net = Chain([Dense(branch[i] => branch[i + 1], branch_activation)
+                        for i in 1:(length(branch) - 1)]...)
+
+    trunk_net = Chain([Dense(trunk[i] => trunk[i + 1], trunk_activation)
+                       for i in 1:(length(trunk) - 1)]...)
+
+    return DeepONet(branch_net, trunk_net, additional)
+end
+
+function (deeponet::DeepONet)(x, ps, st::NamedTuple)
+    b, st_b = deeponet.branch(x[1], ps.branch, st.branch)
+    t, st_t = deeponet.trunk(x[2], ps.trunk, st.trunk)
 
-        @argcheck size(t, 1)==size(b, 1) "Branch and Trunk net must share the same \
-                                          amount of nodes in the last layer. Otherwise \
-                                          Σᵢ bᵢⱼ tᵢₖ won't work."
+    @argcheck size(b, 1)==size(t, 1) "Branch and Trunk net must share the same amount of \
+                                      nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
+                                      work."
 
-        @return __project(b, t, additional)
-    end
+    out, st_a = __project(b, t, deeponet.additional, (; ps=ps.additional, st=st.additional))
+    return out, (branch=st_b, trunk=st_t, additional=st_a)
 end
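A short sketch of calling the refactored `DeepONet`: the constructor call and the `(10, 5)` output size come from the docstring above, while the exact shapes of `u` and `y` are assumptions for illustration.

```julia
using Lux, NeuralOperators, Random

deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16))
ps, st = Lux.setup(Xoshiro(0), deeponet)

u = rand(Float32, 64, 5)     # assumed: 64 sensor values per sample, batch of 5
y = rand(Float32, 1, 10, 5)  # assumed: 1 coordinate, 10 query points, batch of 5

# The new container-layer dispatch threads ps/st through branch, trunk, and
# the additional network (a NoOpLayer by default).
out, st_new = deeponet((u, y), ps, st)
size(out)  # (10, 5)
```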

src/fno.jl

+21 -8

@@ -27,7 +27,7 @@ kernels, and two `Dense` layers to project data back to the scalar field of inte
 ## Example
 
 ```jldoctest
-julia> fno = FourierNeuralOperator(gelu; chs=(2, 64, 64, 128, 1), modes=(16,));
+julia> fno = FourierNeuralOperator(; σ=gelu, chs=(2, 64, 64, 128, 1), modes=(16,));
 
 julia> ps, st = Lux.setup(Xoshiro(), fno);
 
@@ -37,8 +37,15 @@ julia> size(first(fno(u, ps, st)))
 (1, 1024, 5)
 ```
 """
-function FourierNeuralOperator(
-        σ=gelu; chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), modes::Dims{M}=(16,),
+@concrete struct FourierNeuralOperator <:
+                 AbstractExplicitContainerLayer{(:lifting, :mapping, :project)}
+    lifting
+    mapping
+    project
+end
+
+function FourierNeuralOperator(;
+        σ=gelu, chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), modes::Dims{M}=(16,),
         permuted::Val{perm}=Val(false), kwargs...) where {C, M, perm}
     @argcheck length(chs) ≥ 5
@@ -52,9 +59,15 @@ function FourierNeuralOperator(
     project = perm ? Chain(Conv(kernel_size, map₂, σ), Conv(kernel_size, map₃)) :
               Chain(Dense(map₂, σ), Dense(map₃))
 
-    return Chain(; lifting,
-        mapping=Chain([SpectralKernel(chs[i] => chs[i + 1], modes, σ; permuted, kwargs...)
-                       for i in 2:(C - 3)]...),
-        project,
-        name="FourierNeuralOperator")
+    mapping = Chain([SpectralKernel(chs[i] => chs[i + 1], modes, σ; permuted, kwargs...)
+                     for i in 2:(C - 3)]...)
+
+    return FourierNeuralOperator(lifting, mapping, project)
+end
+
+function (fno::FourierNeuralOperator)(x::AbstractArray, ps, st::NamedTuple)
+    lift, st_lift = fno.lifting(x, ps.lifting, st.lifting)
+    mapping, st_mapping = fno.mapping(lift, ps.mapping, st.mapping)
+    project, st_project = fno.project(mapping, ps.project, st.project)
+    return project, (lifting=st_lift, mapping=st_mapping, project=st_project)
 end
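A corresponding sketch for the struct-based `FourierNeuralOperator`: the keyword-only constructor and the `(1, 1024, 5)` output size are from the docstring, while the input layout `(channels, spatial points, batch)` is an assumption.

```julia
using Lux, NeuralOperators, Random

# σ is now a keyword argument of the keyword-only constructor.
fno = FourierNeuralOperator(; σ=gelu, chs=(2, 64, 64, 128, 1), modes=(16,))
ps, st = Lux.setup(Xoshiro(0), fno)

u = rand(Float32, 2, 1024, 5)  # assumed layout: 2 channels, 1024 grid points, batch of 5
v, st_new = fno(u, ps, st)
size(v)  # (1, 1024, 5)
```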

src/layers.jl

+20 -8

@@ -116,18 +116,30 @@ julia> OperatorKernel(2 => 5, (16,), FourierTransform{ComplexF64}; permuted=Val(
 
 ```
 """
+@concrete struct OperatorKernel <: AbstractExplicitContainerLayer{(:lin, :conv)}
+    lin
+    conv
+    activation <: Function
+end
+
+OperatorKernel(lin, conv) = OperatorKernel(lin, conv, identity)
+
 function OperatorKernel(ch::Pair{<:Integer, <:Integer}, modes::Dims{N}, transform::Type{TR},
         act::A=identity; allow_fast_activation::Bool=false, permuted::Val{perm}=Val(false),
         kwargs...) where {N, TR <: AbstractTransform{<:Number}, perm, A}
     act = allow_fast_activation ? NNlib.fast_act(act) : act
-    l₁ = perm ? Conv(map(_ -> 1, modes), ch) : Dense(ch)
-    l₂ = OperatorConv(ch, modes, transform; permuted, kwargs...)
-
-    return @compact(; l₁, l₂, activation=act, dispatch=:OperatorKernel) do x::AbstractArray
-        l₁x = l₁(x)
-        l₂x = l₂(x)
-        @return @. activation(l₁x + l₂x)
-    end
+    lin = perm ? Conv(map(_ -> 1, modes), ch) : Dense(ch)
+    conv = OperatorConv(ch, modes, transform; permuted, kwargs...)
+
+    return OperatorKernel(lin, conv, act)
+end
+
+function (op::OperatorKernel)(x::AbstractArray, ps, st::NamedTuple)
+    x_conv, st_conv = op.conv(x, ps.conv, st.conv)
+    x_lin, st_lin = op.lin(x, ps.lin, st.lin)
+
+    out = fast_activation!!(op.activation, x_conv .+ x_lin)
+    return out, (lin=st_lin, conv=st_conv)
 end
 
 """

src/utils.jl

+19 -19

@@ -1,24 +1,24 @@
-@inline function __project(b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3},
-        additional::Nothing) where {T1, T2}
+@inline function __project(
+        b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3}, ::NoOpLayer, _) where {T1, T2}
     # b : p x nb
     # t : p x N x nb
     b_ = reshape(b, size(b, 1), 1, size(b, 2))  # p x 1 x nb
-    return dropdims(sum(b_ .* t; dims=1); dims=1)  # N x nb
+    return dropdims(sum(b_ .* t; dims=1); dims=1), ()  # N x nb
 end
 
-@inline function __project(b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3},
-        additional::Nothing) where {T1, T2}
+@inline function __project(
+        b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3}, ::NoOpLayer, _) where {T1, T2}
     # b : p x u x nb
     # t : p x N x nb
     if size(b, 2) == 1 || size(t, 2) == 1
-        return sum(b .* t; dims=1)  # 1 x N x nb
+        return sum(b .* t; dims=1), ()  # 1 x N x nb
     else
-        return batched_matmul(batched_adjoint(b), t)  # u x N x b
+        return batched_matmul(batched_adjoint(b), t), ()  # u x N x b
     end
 end
 
-@inline function __project(b::AbstractArray{T1, N}, t::AbstractArray{T2, 3},
-        additional::Nothing) where {T1, T2, N}
+@inline function __project(
+        b::AbstractArray{T1, N}, t::AbstractArray{T2, 3}, ::NoOpLayer, _) where {T1, T2, N}
     # b : p x u_size x nb
     # t : p x N x nb
     u_size = size(b)[2:(end - 1)]
@@ -29,34 +29,34 @@ end
     t_ = reshape(t, size(t, 1), ones(eltype(u_size), length(u_size))..., size(t)[2:end]...)
     # p x (1,1,1...) x N x nb
 
-    return dropdims(sum(b_ .* t_; dims=1); dims=1)  # u_size x N x nb
+    return dropdims(sum(b_ .* t_; dims=1); dims=1), ()  # u_size x N x nb
 end
 
-@inline function __project(
-        b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3}, additional::T) where {T1, T2, T}
+@inline function __project(b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3},
+        additional::T, params) where {T1, T2, T}
     # b : p x nb
     # t : p x N x nb
     b_ = reshape(b, size(b, 1), 1, size(b, 2))  # p x 1 x nb
-    return additional(b_ .* t)  # p x N x nb => out_dims x N x nb
+    return additional(b_ .* t, params.ps, params.st)  # p x N x nb => out_dims x N x nb
 end
 
-@inline function __project(
-        b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3}, additional::T) where {T1, T2, T}
+@inline function __project(b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3},
+        additional::T, params) where {T1, T2, T}
     # b : p x u x nb
     # t : p x N x nb
 
     if size(b, 2) == 1 || size(t, 2) == 1
-        return additional(b .* t)  # p x N x nb => out_dims x N x nb
+        return additional(b .* t, params.ps, params.st)  # p x N x nb => out_dims x N x nb
     else
        b_ = reshape(b, size(b)[1:2]..., 1, size(b, 3))    # p x u x 1 x nb
        t_ = reshape(t, size(t, 1), 1, size(t)[2:end]...)  # p x 1 x N x nb
 
-        return additional(b_ .* t_)  # p x u x N x nb => out_size x N x nb
+        return additional(b_ .* t_, params.ps, params.st)  # p x u x N x nb => out_size x N x nb
     end
 end
 
 @inline function __project(b::AbstractArray{T1, N}, t::AbstractArray{T2, 3},
-        additional::T) where {T1, T2, N, T}
+        additional::T, params) where {T1, T2, N, T}
     # b : p x u_size x nb
     # t : p x N x nb
     u_size = size(b)[2:(end - 1)]
@@ -67,5 +67,5 @@ end
     t_ = reshape(t, size(t, 1), ones(eltype(u_size), length(u_size))..., size(t)[2:end]...)
     # p x (1,1,1...) x N x nb
 
-    return additional(b_ .* t_)  # p x u_size x N x nb => out_size x N x nb
+    return additional(b_ .* t_, params.ps, params.st)  # p x u_size x N x nb => out_size x N x nb
 end
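The `NoOpLayer` methods now return an empty state `()` along with the projection so that the caller can store it in the DeepONet state. A standalone sketch of the simplest projection (2-D `b`, 3-D `t`), in plain Julia with assumed sizes:

```julia
# Σᵢ bᵢⱼ tᵢₖ for b :: (p, nb) and t :: (p, N, nb), mirroring the first method above.
b = rand(Float32, 16, 5)      # p = 16 latent outputs, nb = 5 batch
t = rand(Float32, 16, 10, 5)  # N = 10 query points

b_ = reshape(b, size(b, 1), 1, size(b, 2))    # p x 1 x nb
out = dropdims(sum(b_ .* t; dims=1); dims=1)  # N x nb
size(out)  # (10, 5)
```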

test/Project.toml

+3 -3

@@ -25,9 +25,9 @@ Documenter = "1.5.0"
 ExplicitImports = "1.9.0"
 Hwloc = "3.2.0"
 InteractiveUtils = "<0.0.1, 1"
-Lux = "0.5.62"
-LuxCore = "0.1.21"
-LuxLib = "0.3.40"
+Lux = "0.5.64"
+LuxCore = "0.1.24"
+LuxLib = "0.3.42"
 LuxTestUtils = "1.1.2"
 MLDataDevices = "1.0.0"
 Optimisers = "0.3.3"

test/fno_tests.jl

+1 -2

@@ -22,11 +22,10 @@
         @test size(first(fno(x, ps, st))) == setup.y_size
 
         data = [(x, y)]
-        broken = mode == "AMDGPU"
         @test begin
             l2, l1 = train!(fno, ps, st, data; epochs=10)
            l2 < l1
-        end broken=broken
+        end
 
         __f = (x, ps) -> sum(abs2, first(fno(x, ps, st)))
         test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3,
