|
20 | 20 | Base.BroadcastStyle(::Type{<:DArray}) = Broadcast.ArrayStyle{DArray}()
|
21 | 21 | Base.BroadcastStyle(::Type{<:DArray}, ::Any) = Broadcast.ArrayStyle{DArray}()
|
22 | 22 |
|
23 |
| -function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{DArray}}, ::Type{ElType}) where {ElType} |
24 |
| - DA = find_darray(bc) |
25 |
| - DArray(I -> Array{ElType}(undef, map(length,I)), DA) |
26 |
| -end |
27 |
| - |
28 |
| -"`DA = find_darray(As)` returns the first DArray among the arguments." |
29 |
| -find_darray(bc::Base.Broadcast.Broadcasted) = find_darray(bc.args) |
30 |
| -find_darray(args::Tuple) = find_darray(find_darray(args[1]), Base.tail(args)) |
31 |
| -find_darray(x) = x |
32 |
| -find_darray(a::DArray, rest) = a |
33 |
| -find_darray(::Any, rest) = find_darray(rest) |
34 |
| - |
35 |
| -function Base.copyto!(dest::DArray, bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{DArray}}) |
36 |
| - @sync for p in procs(dest) |
37 |
| - @async remotecall_fetch(p) do |
38 |
| - copyto!(localpart(dest), rewrite_local(bc)) |
39 |
| - end |
| 23 | +function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{DArray}}) |
| 24 | + T = Base.Broadcast.combine_eltypes(bc.f, bc.args) |
| 25 | + shape = Base.Broadcast.combine_axes(bc.args...) |
| 26 | + iter = Base.CartesianIndices(shape) |
| 27 | + D = DArray(map(length, shape)) do I |
| 28 | + A = map(bc.args) do a |
| 29 | + if isa(a, Union{Number,Ref}) |
| 30 | + return a |
| 31 | + else |
| 32 | + return localtype(a)( |
| 33 | + a[ntuple(i -> i > ndims(a) ? 1 : (size(a, i) == 1 ? (1:1) : I[i]), length(shape))...] |
| 34 | + ) |
| 35 | + end |
| 36 | + end |
| 37 | + broadcast(bc.f, A...) |
40 | 38 | end
|
41 |
| - dest |
| 39 | + return D |
42 | 40 | end
|
43 | 41 |
|
44 |
| -""" |
45 |
| -Transform a Broadcasted{Broadcast.ArrayStyle{DArray}} object into an equivalent |
46 |
| -Broadcasted{Broadcast.DefaultArrayStyle} object for the localparts. |
47 |
| -""" |
48 |
| -rewrite_local(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{DArray}}) = Broadcast.broadcasted(bc.f, rewrite_local(bc.args)...) |
49 |
| -rewrite_local(args::Tuple) = map(rewrite_local, args) |
50 |
| -rewrite_local(a::DArray) = localpart(a) |
51 |
| -rewrite_local(x) = x |
52 |
| - |
53 |
| - |
54 | 42 | function Base.reduce(f, d::DArray)
|
55 | 43 | results = asyncmap(procs(d)) do p
|
56 | 44 | remotecall_fetch(p) do
|
@@ -128,6 +116,7 @@ function Base.mapreducedim!(f, op, R::DArray, A::DArray)
|
128 | 116 | return mapreducedim_between!(identity, op, R, B, region)
|
129 | 117 | end
|
130 | 118 |
|
| 119 | +## Some special cases |
131 | 120 | function Base._all(f, A::DArray, ::Colon)
|
132 | 121 | B = asyncmap(procs(A)) do p
|
133 | 122 | remotecall_fetch(p) do
|
@@ -171,6 +160,8 @@ function Base.extrema(d::DArray)
|
171 | 160 | return reduce((t,s) -> (min(t[1], s[1]), max(t[2], s[2])), r)
|
172 | 161 | end
|
173 | 162 |
|
| 163 | +Statistics._mean(A::DArray, region) = sum(A, dims = region) ./ prod((size(A, i) for i in region)) |
| 164 | + |
174 | 165 | # Unary vector functions
|
175 | 166 | (-)(D::DArray) = map(-, D)
|
176 | 167 |
|
|
0 commit comments