Skip to content

Commit f65bcc5

Browse files
committed
New feature: Added a deformation inverse prototype
1 parent 7230aac commit f65bcc5

File tree

2 files changed

+88
-17
lines changed

2 files changed

+88
-17
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ If you are brave enouth to try using it, then the following may work (in Julia)
1111

1212

1313
Note that the automated githib actions fail because CUDA drivers are missing, which leads on to several other problems.
14+
Multi-dimensionsional ffts on GPU (CUFFT) can also be [problematic](https://github.com/JuliaGPU/CUDA.jl/issues/119) with older Julia versions (fixed somewhere between Julia 1.7.2 and 1.9.3).
1415

1516
[![Build Status](https://github.com/spm/PushPull.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/spm/PushPull.jl/actions/workflows/CI.yml?query=branch%3Amain)
1617

src/multigrid.jl

+87-17
Original file line numberDiff line numberDiff line change
@@ -60,23 +60,66 @@ Multiplying `v` by the Hessian (`H`).
6060
function Hv!(v::VolType, H::VolType, u::VolType=zero(v))::VolType
6161
@assert(all(dim(H) .== dim(v)))
6262
@assert(all(size(v) .== size(u)))
63-
@assert(size(H,4) == 6)
64-
@assert(size(v,4) == 3)
65-
h11 = view(H,:,:,:,1)
66-
h22 = view(H,:,:,:,2)
67-
h33 = view(H,:,:,:,3)
68-
h12 = view(H,:,:,:,4)
69-
h13 = view(H,:,:,:,5)
70-
h23 = view(H,:,:,:,6)
71-
v1 = view(v,:,:,:,1)
72-
v2 = view(v,:,:,:,2)
73-
v3 = view(v,:,:,:,3)
74-
u1 = view(u,:,:,:,1)
75-
u2 = view(u,:,:,:,2)
76-
u3 = view(u,:,:,:,3)
77-
u1 .+= h11.*v1 .+ h12.*v2 .+ h13.*v3
78-
u2 .+= h12.*v1 .+ h22.*v2 .+ h23.*v3
79-
u3 .+= h13.*v1 .+ h23.*v2 .+ h33.*v3
63+
@assert(ndims(v)==4 && ndims(H)==4)
64+
dv = size(v,4)
65+
dh = size(H,4)
66+
@assert(dh==1 || dh==dv || dh == Int((dv+1)*dv/2))
67+
68+
if false #size(v,4) == 3 # Special case
69+
v1 = view(v,:,:,:,1)
70+
v2 = view(v,:,:,:,2)
71+
v3 = view(v,:,:,:,3)
72+
u1 = view(u,:,:,:,1)
73+
u2 = view(u,:,:,:,2)
74+
u3 = view(u,:,:,:,3)
75+
if size(H,4) >= 3
76+
h11 = view(H,:,:,:,1)
77+
h22 = view(H,:,:,:,2)
78+
h33 = view(H,:,:,:,3)
79+
if size(H,4) == 6
80+
h12 = view(H,:,:,:,4)
81+
h13 = view(H,:,:,:,5)
82+
h23 = view(H,:,:,:,6)
83+
u1 .+= h11.*v1 .+ h12.*v2 .+ h13.*v3
84+
u2 .+= h12.*v1 .+ h22.*v2 .+ h23.*v3
85+
u3 .+= h13.*v1 .+ h23.*v2 .+ h33.*v3
86+
elseif size(H,4) == 3
87+
u1 .+= h11.*v1
88+
u2 .+= h22.*v2
89+
u3 .+= h33.*v3
90+
else
91+
error()
92+
end
93+
elseif size(H,4) == 1
94+
h11 = view(H,:,:,:,1)
95+
u1 .+= h11.*v1
96+
u2 .+= h11.*v2
97+
u3 .+= h11.*v3
98+
else
99+
error()
100+
end
101+
return u
102+
else # General case
103+
if dh==1
104+
h = view(H,:,:,:,1)
105+
for i=1:dv
106+
view(u,:,:,:,i) .+= h.*view(v,:,:,:,i)
107+
end
108+
elseif dh==dv || dh==Int((dv+1)*dv/2)
109+
for i=1:dv
110+
view(u,:,:,:,i) .+= view(H,:,:,:,i).*view(v,:,:,:,i)
111+
end
112+
if dh==Int((dv+1)*dv/2)
113+
ii = dv
114+
for i=1:dv, j=i+1:dv
115+
ii += 1
116+
h = view(H,:,:,:,ii)
117+
view(u,:,:,:,i) .+= h.*view(v,:,:,:,j)
118+
view(u,:,:,:,j) .+= h.*view(v,:,:,:,i)
119+
end
120+
end
121+
end
122+
end
80123
return u
81124
end
82125

@@ -213,3 +256,30 @@ function fcycle!(v::VolType, g::VolType, HL::PyramidType; nit_pre::Integer=4, ni
213256
return v
214257
end
215258

259+
260+
function invert_def(phi::T)::T where T<:VolType
261+
d = size(phi)
262+
if length(d)>4
263+
iphi = zero(phi)
264+
for i in CartesianIndices(d[5:end])
265+
iphi[:,:,:,:,i] .= invert_def(phi[:,:,:,:,i])
266+
end
267+
return iphi
268+
else
269+
Id = id(d[1:3]; gpu=~isa(phi,Array))
270+
sett = Settings(1, [2 1 1;1 2 1;1 1 2], 0)
271+
g = push(Id, phi, d[1:3], sett)
272+
o = typeof(g)(undef,(d[1:3]...,1))
273+
o .= 1
274+
h = push(o, phi, d[1:3], Settings(1,1,0))
275+
g .-= h.*Id
276+
H = typeof(g)(undef,(d[1:3]...,3))
277+
H[:,:,:,1:3].=h
278+
HL = hessian_pyramid(H,[1f0,1f0,1f0],[0., 0.01, 0.1, 0.01])
279+
iphi = zero(g)
280+
vcycle!(iphi, g, HL; nit_pre=2, nit_post=2)
281+
iphi .+= Id
282+
return iphi
283+
end
284+
end
285+

0 commit comments

Comments
 (0)