Skip to content

Commit bfd99c2

Browse files
committed
Port map special handling
use fieldcount taylor bundle not taylor tangent
1 parent 50383b6 commit bfd99c2

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

src/stage1/forward.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -194,12 +194,9 @@ struct FwdMap{N, T<:AbstractTangentBundle{N}}
194194
end
195195
(f::FwdMap{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆{N}()(f.f, args...)
196196

197-
#==
198-
# TODO port this to TaylorBundle over composite structure
199-
function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, tup::CompositeBundle{N, <:Tuple}) where {N}
200-
∂vararg{N}()(map(FwdMap(f), tup.tup)...)
197+
function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, tup::TaylorBundle{N, <:Tuple}) where {N}
198+
∂vararg{N}()(map(FwdMap(f), destructure(tup))...)
201199
end
202-
==#
203200

204201
function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N, <:AbstractArray}...) where {N}
205202
# TODO: This could do an inplace map! to avoid the extra rebundling

src/tangent.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,18 @@ function Base.getindex(tb::TaylorBundle, tti::CanonicalTangentIndex)
228228
tb.tangent.coeffs[count_ones(tti.i)]
229229
end
230230

231+
"for a TaylorTangent{N, <:Tuple} this breaks it up unto 1 TaylorTangent{N} for each element of the primal tuple"
232+
function destructure(r::TaylorBundle{N, B}) where {N, B<:Tuple}
233+
return ntuple(fieldcount(B)) do field_ii
234+
the_primal = primal(r)[field_ii]
235+
the_partials = ntuple(N) do order_ii
236+
partial(r, order_ii)[field_ii]
237+
end
238+
return TaylorBundle{N}(the_primal, the_partials)
239+
end
240+
end
241+
242+
231243
function truncate(tt::TaylorTangent, order::Val{N}) where {N}
232244
TaylorTangent(tt.coeffs[1:N])
233245
end

0 commit comments

Comments
 (0)