 """
-    DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16),
-        branch_activation = identity, trunk_activation = identity)
+    DeepONet(branch, trunk, additional)
 
-Constructs a DeepONet composed of Dense layers. Make sure the last node of `branch` and
-`trunk` are same.
+Constructs a DeepONet from `branch` and `trunk` architectures. Make sure that both nets
+output the same first dimension.
 
-## Keyword arguments:
+## Arguments
+
+  - `branch`: `Lux` network to be used as branch net.
+  - `trunk`: `Lux` network to be used as trunk net.
+
+## Keyword Arguments
 
-  - `branch`: Tuple of integers containing the number of nodes in each layer for branch net
-  - `trunk`: Tuple of integers containing the number of nodes in each layer for trunk net
-  - `branch_activation`: activation function for branch net
-  - `trunk_activation`: activation function for trunk net
   - `additional`: `Lux` network to pass the output of DeepONet, to include additional operations
     for embeddings, defaults to `nothing`
 
@@ -23,7 +23,11 @@ operators", doi: https://arxiv.org/abs/1910.03193
 ## Example
 
 ```jldoctest
-julia> deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16));
+julia> branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16));
+
+julia> trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));
+
+julia> deeponet = DeepONet(branch_net, trunk_net);
 
 julia> ps, st = Lux.setup(Xoshiro(), deeponet);
 
@@ -35,37 +39,27 @@ julia> size(first(deeponet((u, y), ps, st)))
 (10, 5)
 ```
 """
-function DeepONet(;
-        branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), branch_activation=identity,
-        trunk_activation=identity, additional=nothing)
-
-    # checks for last dimension size
-    @argcheck branch[end]==trunk[end] "Branch and Trunk net must share the same amount of \
-                                       nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
-                                       work."
-
-    branch_net = Chain([Dense(branch[i] => branch[i + 1], branch_activation)
-                        for i in 1:(length(branch) - 1)]...)
-
-    trunk_net = Chain([Dense(trunk[i] => trunk[i + 1], trunk_activation)
-                       for i in 1:(length(trunk) - 1)]...)
-
-    return DeepONet(branch_net, trunk_net; additional)
+@concrete struct DeepONet <: AbstractExplicitContainerLayer{(:branch, :trunk, :additional)}
+    branch
+    trunk
+    additional
 end
 
-"""
-    DeepONet(branch, trunk)
+DeepONet(branch, trunk) = DeepONet(branch, trunk, NoOpLayer())
 
-Constructs a DeepONet from a `branch` and `trunk` architectures. Make sure that both the
-nets output should have the same first dimension.
-
-## Arguments
+"""
+    DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16),
+        branch_activation = identity, trunk_activation = identity)
 
-  - `branch`: `Lux` network to be used as branch net.
-  - `trunk`: `Lux` network to be used as trunk net.
+Constructs a DeepONet composed of Dense layers. Make sure the last nodes of `branch` and
+`trunk` are the same.
 
-## Keyword Arguments
+## Keyword arguments:
 
+  - `branch`: Tuple of integers containing the number of nodes in each layer for branch net
+  - `trunk`: Tuple of integers containing the number of nodes in each layer for trunk net
+  - `branch_activation`: activation function for branch net
+  - `trunk_activation`: activation function for trunk net
   - `additional`: `Lux` network to pass the output of DeepONet, to include additional operations
     for embeddings, defaults to `nothing`
 
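The hunks above replace the keyword-only constructor with a plain container layer: `DeepONet` now holds `branch`, `trunk`, and `additional` as fields, and the two-argument convenience constructor fills `additional` with `NoOpLayer()`. A minimal construction sketch, offered as reviewer commentary rather than part of the committed code, and assuming only the `Lux` layers already used in the doctests:

```julia
using Lux, Random

# Branch and trunk nets matching the doctest sizes; both end in 16 nodes so the
# Σᵢ bᵢⱼ tᵢₖ contraction in the forward pass is well defined.
branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16))
trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16))

# Equivalent ways to build the container after this refactor:
deeponet = DeepONet(branch_net, trunk_net)                        # `additional` defaults to NoOpLayer()
deeponet_explicit = DeepONet(branch_net, trunk_net, NoOpLayer())  # same thing, spelled out

ps, st = Lux.setup(Xoshiro(), deeponet)
```

Per the `additional` docstring entry, any `Lux` layer can be passed in place of `NoOpLayer()` to post-process the projected output, e.g. for embeddings.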
@@ -78,11 +72,7 @@ operators", doi: https://arxiv.org/abs/1910.03193
 ## Example
 
 ```jldoctest
-julia> branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16));
-
-julia> trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));
-
-julia> deeponet = DeepONet(branch_net, trunk_net);
+julia> deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16));
 
 julia> ps, st = Lux.setup(Xoshiro(), deeponet);
 
@@ -94,15 +84,32 @@ julia> size(first(deeponet((u, y), ps, st)))
 (10, 5)
 ```
 """
-function DeepONet(branch::L1, trunk::L2; additional=nothing) where {L1, L2}
-    return @compact(; branch, trunk, additional, dispatch=:DeepONet) do (u, y)
-        t = trunk(y)   # p x N x nb
-        b = branch(u)  # p x u_size... x nb
+function DeepONet(;
+        branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), branch_activation=identity,
+        trunk_activation=identity, additional=NoOpLayer())
+
+    # checks for last dimension size
+    @argcheck branch[end]==trunk[end] "Branch and Trunk net must share the same amount of \
+                                       nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
+                                       work."
+
+    branch_net = Chain([Dense(branch[i] => branch[i + 1], branch_activation)
+                        for i in 1:(length(branch) - 1)]...)
+
+    trunk_net = Chain([Dense(trunk[i] => trunk[i + 1], trunk_activation)
+                       for i in 1:(length(trunk) - 1)]...)
+
+    return DeepONet(branch_net, trunk_net, additional)
+end
+
+function (deeponet::DeepONet)(x, ps, st::NamedTuple)
+    b, st_b = deeponet.branch(x[1], ps.branch, st.branch)
+    t, st_t = deeponet.trunk(x[2], ps.trunk, st.trunk)
 
-        @argcheck size(t, 1)==size(b, 1) "Branch and Trunk net must share the same \
-                                          amount of nodes in the last layer. Otherwise \
-                                          Σᵢ bᵢⱼ tᵢₖ won't work."
+    @argcheck size(b, 1)==size(t, 1) "Branch and Trunk net must share the same amount of \
+                                      nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
+                                      work."
 
-        @return __project(b, t, additional)
-    end
+    out, st_a = __project(b, t, deeponet.additional, (; ps=ps.additional, st=st.additional))
+    return out, (branch=st_b, trunk=st_t, additional=st_a)
 end
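For reference, a hedged sketch of the contraction the new forward pass delegates to `__project` (that helper is defined elsewhere in the package and does not appear in this diff). Assuming, per the `p x N x nb` comments in the removed `@compact` version, that `b` has size `p × nb` and `t` has size `p × N × nb`, the Σᵢ bᵢⱼ tᵢₖ formula from the `@argcheck` messages amounts to something like:

```julia
# Illustrative only: a plain-Julia contraction matching the Σᵢ bᵢⱼ tᵢₖ formula in the
# @argcheck messages. The real __project also threads the `additional` layer through.
function naive_project(b::AbstractMatrix, t::AbstractArray{T, 3}) where {T}
    p, nb = size(b)
    @assert size(t, 1) == p && size(t, 3) == nb  # same shared-first-dimension check as the layer
    N = size(t, 2)
    out = zeros(T, N, nb)
    for k in 1:nb, j in 1:N, i in 1:p
        out[j, k] += b[i, k] * t[i, j, k]  # sum over the shared first (p) dimension
    end
    return out  # N × nb
end
```

With the doctest layer sizes this would yield a 10 × 5 array, consistent with the `(10, 5)` output shown in the examples, which is why both `@argcheck`s insist on branch and trunk sharing the same first dimension.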