@@ -40,6 +40,17 @@ Base.propertynames(p::ProjectTo) = propertynames(backing(p))
40
40
backing (project:: ProjectTo ) = getfield (project, :info )
41
41
42
42
project_type (p:: ProjectTo{T} ) where {T} = T
43
+ project_eltype (p:: ProjectTo{T} ) where {T} = eltype (T)
44
+
45
+ function project_promote_type (projectors)
46
+ T = mapreduce (project_type, promote_type, projectors)
47
+ if T <: Number
48
+ # The point of this function is to make p.element for arrays. Not in use yet!
49
+ return ProjectTo (zero (T))
50
+ else
51
+ return ProjectTo {Any} ()
52
+ end
53
+ end
43
54
44
55
function Base. show (io:: IO , project:: ProjectTo{T} ) where {T}
45
56
print (io, " ProjectTo{" )
178
189
# no structure worth re-imposing. Then any array is acceptable as a gradient.
179
190
180
191
# For arrays of numbers, just store one projector:
181
- function ProjectTo (x:: AbstractArray{T} ) where {T<: Number }
182
- return ProjectTo {AbstractArray} (; element= _eltype_projectto (T), axes= axes (x))
192
+ function ProjectTo (x:: AbstractArray{T,N} ) where {T<: Number ,N}
193
+ element = _eltype_projectto (T)
194
+ S = project_type (element) # new idea -- for any number, S is enough.
195
+ # Store .element for now too, although it's redundant? Reconstruct from eltype?
196
+ if axes (x) isa NTuple{N,Base. OneTo{Int}}
197
+ return ProjectTo {AbstractArray{S,N}} (; element= element, axes= axes (x))
198
+ else
199
+ # Omitting N prohibits the fast path, and thus won't skip reshape for OffsetArrays, SArrays, etc.
200
+ return ProjectTo {AbstractArray{S}} (; element= element, axes= axes (x))
201
+ end
183
202
end
184
203
ProjectTo (x:: AbstractArray{Bool} ) = ProjectTo {NoTangent} ()
185
204
@@ -197,7 +216,7 @@ function ProjectTo(xs::AbstractArray)
197
216
end
198
217
end
199
218
200
- function (project:: ProjectTo{AbstractArray} )(dx:: AbstractArray{S,M} ) where {S,M}
219
+ function (project:: ProjectTo{<: AbstractArray} )(dx:: AbstractArray{S,M} ) where {S,M}
201
220
# First deal with shape. The rule is that we reshape to add or remove trivial dimensions
202
221
# like dx = ones(4,1), where x = ones(4), but throw an error on dx = ones(1,4) etc.
203
222
dy = if axes (dx) == project. axes
@@ -221,24 +240,33 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M}
221
240
return dz
222
241
end
223
242
243
+ # Fast path, for arrays of numbers:
244
+ # (::ProjectTo{AbstractArray{T,N}})(dx::AbstractArray{T,N}) where {T,N} = (@info "fast 1"; dx)
245
+ (:: ProjectTo{AbstractArray{T,N}} )(dx:: AbstractArray{S,N} ) where {S<: T } where {T,N} = dx # (@info "fast 2"; dx)
246
+ (project:: ProjectTo{AbstractArray{T,N}} )(dx:: AbstractArray{S,N} ) where {S,T,N} = (@info " fast 3" ; map (project. element, dx))
247
+
224
248
# Trivial case, this won't collapse Any[NoTangent(), NoTangent()] but that's OK.
225
- (project:: ProjectTo{AbstractArray} )(dx:: AbstractArray{<:AbstractZero} ) = NoTangent ()
249
+ (project:: ProjectTo{AbstractArray{T,N}} )(dx:: AbstractArray{<:AbstractZero} ) where {T,N} = NoTangent ()
226
250
227
251
# Row vectors aren't acceptable as gradients for 1-row matrices:
228
- function (project:: ProjectTo{AbstractArray} )(dx:: LinearAlgebra.AdjOrTransAbsVec )
252
+ # function (project::ProjectTo{<:AbstractArray})(dx::LinearAlgebra.AdjOrTransAbsVec)
253
+ # return project(reshape(vec(dx), 1, :))
254
+ # end
255
+ function (project:: ProjectTo{AbstractArray{T,N}} )(dx:: LinearAlgebra.AdjOrTransAbsVec ) where {T,N}
229
256
return project (reshape (vec (dx), 1 , :))
230
257
end
231
258
232
259
# Zero-dimensional arrays -- these have a habit of going missing,
233
260
# although really Ref() is probably a better structure.
234
- function (project:: ProjectTo{AbstractArray} )(dx:: Number ) # ... so we restore from numbers
235
- if ! (project. axes isa Tuple{})
236
- throw (DimensionMismatch (
237
- " array with ndims(x) == $(length (project. axes)) > 0 cannot have dx::Number" ,
238
- ))
239
- end
240
- return fill (project. element (dx))
241
- end
261
+ # function (project::ProjectTo{<:AbstractArray})(dx::Number) # ... so we restore from numbers
262
+ # if !(project.axes isa Tuple{})
263
+ # throw(DimensionMismatch(
264
+ # "array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number",
265
+ # ))
266
+ # end
267
+ # return fill(project.element(dx))
268
+ # end
269
+ (project:: ProjectTo{AbstractArray{<:Number,0}} )(dx:: Number ) = fill (project. element (dx))
242
270
243
271
# Ref -- works like a zero-array, also allows restoration from a number:
244
272
ProjectTo (x:: Ref ) = ProjectTo {Ref} (; x= ProjectTo (x[]))
0 commit comments