Change irim block - add invertible UNET #57
base: master
Conversation
Codecov Report — Base: 88.11% // Head: 87.94% // Decreases project coverage by -0.17%.
Additional details and impacted files
@@ Coverage Diff @@
## master #57 +/- ##
==========================================
- Coverage 88.11% 87.94% -0.17%
==========================================
Files 31 32 +1
Lines 2330 2390 +60
==========================================
+ Hits 2053 2102 +49
- Misses 277 288 +11
☔ View full report at Codecov.
Would be good to include all the necessary modifications due to this change in one commit.
-    C::Conv1x1
-    RB::Union{ResidualBlock, FluxBlock}
+    C::AbstractArray{Conv1x1, 1}
+    RB::AbstractArray{ResidualBlock, 1}
 end
Why are we removing `FluxBlock` from the allowed types?
@@ -75,10 +75,18 @@ end
 # Constructors

 # Constructor
-function ResidualBlock(n_in, n_hidden; k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, fan=false, ndims=2)
+function ResidualBlock(n_in, n_hidden; d=nothing, k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, fan=false, ndims=2)
Not clear what `d` is. Maybe add a docstring.
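A minimal sketch of such a docstring, with the meaning of `d` inferred only from the inline comment and the `k1 = s1 = d` assignment in the diff below (not confirmed by the author):

```julia
"""
    ResidualBlock(n_in, n_hidden; d=nothing, k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, fan=false, ndims=2)

Construct a residual block with `n_in` input and `n_hidden` hidden channels.
If the downsampling factor `d` is given, it overrides the kernel size and
stride of the first convolution (`k1 = s1 = d`); otherwise `k1`/`s1` apply.
"""
```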
+    # Check if downsampling factor d is defined
+    if !isnothing(d)
+        k1 = d
Why `k1 = s1 = d`? Please add a reference for this choice.
@@ -66,7 +66,7 @@ end
 @Flux.functor NetworkLoop

 # 2D Constructor
-function NetworkLoop(n_in, n_hidden, maxiter, Ψ; k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, type="additive", ndims=2)
+function NetworkLoop(n_in, n_hidden, maxiter, Ψ; n_hiddens=nothing, ds=nothing, k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, type="additive", ndims=2)
Similar variable names `n_hidden` and `n_hiddens`. Maybe add a docstring explaining these different inputs, and think of a clearer variable name.
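A sketch of the kind of docstring that would help, with the semantics of the two inputs inferred from this diff alone (the author should correct as needed):

```julia
"""
    NetworkLoop(n_in, n_hidden, maxiter, Ψ; n_hiddens=nothing, ds=nothing, ...)

- `n_hidden`:  hidden-channel count for a single residual block (legacy behavior).
- `n_hiddens`: vector of hidden-channel counts, one per residual block; must
  have the same length as `ds`, the vector of per-block downsampling factors.
"""
```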
@@ -0,0 +1,133 @@
+# Invertible network layer from Putzky and Welling (2019): https://arxiv.org/abs/1911.10914
+# Author: Philipp Witte, [email protected]
Is this information correct?
@@ -64,6 +64,7 @@ include("layers/invertible_layer_hint.jl")
 # Invertible network architectures
 include("networks/invertible_network_hint_multiscale.jl")
 include("networks/invertible_network_irim.jl") # i-RIM: Putzky and Welling (2019)
+include("networks/invertible_network_unet.jl") # single loop i-RIM: Putzky and Welling (2019)
Is this implementing https://github.com/pputzky/invertible_rim/blob/master/irim/core/invertible_unet.py?
-    C = Conv1x1(n_in)
-    RB = ResidualBlock(n_in÷2, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, ndims=ndims)
+    if length(n_hiddens) != length(ds)
+        throw("Number of downsampling factors in ds must be the same defined hidden channels in n_hidden")
`n_hiddens` and `ds` must have equal length, but the provided lengths are .... and ....
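One way to implement that suggestion (a sketch; wrapping it in an `ArgumentError` rather than throwing a plain string is my choice, not the PR's current code):

```julia
if length(n_hiddens) != length(ds)
    throw(ArgumentError("n_hiddens and ds must have equal length, but the " *
                        "provided lengths are $(length(n_hiddens)) and $(length(ds))"))
end
```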
     else
         ΔX, Δθ_C2 = L.C.inverse((ΔX_, X_); set_grad=set_grad)[1:2]
     # Initialize layer parameters
     !set_grad && (p1 = Array{Parameter, 1}(undef, 0))
Why do we need to initialize here?
     end

-    set_grad ? (return ΔX, X) : (return ΔX, cat(Δθ_C1+Δθ_C2, Δθ_RB; dims=1), X)
+    set_grad ? (return ΔY, Y) : (ΔY, cat(p1, p2; dims=1), Y)
Would be good to keep the naming convention the same, i.e., this should return `ΔX`.
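A sketch of the rename, assuming `ΔY`/`Y` here play the role `ΔX`/`X` play elsewhere in the package. Note the new branch also drops the explicit `return` used in the old line; if this ternary is not the last expression of the function, the tuple would be silently discarded:

```julia
set_grad ? (return ΔX, X) : (return ΔX, cat(p1, p2; dims=1), X)
```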
     end

+    @test isapprox(err3[end] / (err3[1]/2^(maxiter-1)), 1f0; atol=1f0)
+    @test isapprox(err4[end] / (err4[1]/4^(maxiter-1)), 1f0; atol=1f0)
Missing `jacobian` tests.
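For reference, a minimal sketch of such a test, mirroring the Jacobian convergence tests used for other layers in this package; the constructor name `NetworkUNET`, its signature, and the exact return arities of `jacobian` and `forward` are assumptions to adapt to the new layer:

```julia
using InvertibleNetworks, LinearAlgebra, Test

nx, ny, n_in, n_hidden, batchsize = 16, 16, 2, 4, 2
N  = NetworkUNET(n_in, n_hidden)    # hypothetical constructor of the new network
N0 = NetworkUNET(n_in, n_hidden)    # second random init, used as a parameter perturbation
θ, θ0 = deepcopy(get_params(N)), deepcopy(get_params(N0))

X  = randn(Float32, nx, ny, n_in, batchsize)
dX = randn(Float32, size(X)...); dX *= norm(X)/norm(dX)   # normalized input perturbation
dθ = θ - θ0                                               # parameter perturbation

dY, Y = N.jacobian(dX, dθ, X)       # J(X, θ) * (dX, dθ)

# First-order Taylor check: err1 should halve and err2 quarter at each step
h, maxiter = 0.1f0, 5
err1, err2 = zeros(Float32, maxiter), zeros(Float32, maxiter)
for j = 1:maxiter
    set_params!(N, θ + h*dθ)
    Y_ = N.forward(X + h*dX)
    err1[j] = norm(Y_ - Y)
    err2[j] = norm(Y_ - Y - h*dY)
    global h = h/2f0
end
@test isapprox(err1[end] / (err1[1]/2^(maxiter-1)), 1f0; atol=1f0)
@test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f0)
```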
Add a very simple invertible UNet. The gradient is not computed internally as in i-RIM; it is just an input, which allows a clear comparison with traditional UNets.
Remove indexing.
- Change the i-RIM block to generate multiple RBs with different dilations and different hidden channels. This is the proper Welling implementation.
- This should break some examples. If we are okay with this new block, I will go and change all examples to run properly. It is as easy as changing NetworkIRIM(n_in, n_hidden ....) -> NetworkIRIM(n_in, [n_hidden], [4];), thus defining a single UNet layer with conv dilation 4, which is the current i-RIM implementation (see the sketch after this list).
- Add a new network, the invertible UNet; this is basically a single unrolled loop iteration of i-RIM. The name comes from the Welling code.
- It directly takes in a precomputed gradient.
- It is meant for inverse problems where your operator is too expensive to use online.
- It doesn't have aggressive memory savings such as in-place Conv1x1 yet, but should work well with moderately sized 3D. Will be testing this.
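A hedged sketch of that migration, written with the keyword form visible in the NetworkLoop diff above (the exact final API may differ from both this and the description):

```julia
# Current i-RIM behavior: one residual block of hidden width n_hidden
N_old = NetworkLoop(n_in, n_hidden, maxiter, Ψ)

# Equivalent call under the new block: one vector entry per RB, where a
# downsampling factor of 4 reproduces the existing single-block i-RIM setup
N_new = NetworkLoop(n_in, n_hidden, maxiter, Ψ; n_hiddens=[n_hidden], ds=[4])
```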