Skip to content

Commit 721858c

Browse files
authored
Merge pull request #161 from JuliaParallel/an/dist
Fixes the broadcasting and mean with dimension
2 parents f8c26bd + 6b93668 commit 721858c

File tree

6 files changed

+31
-34
lines changed

6 files changed

+31
-34
lines changed

.travis.yml

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ os:
44
- osx
55
julia:
66
- 0.7
7+
- 1.0
78
- nightly
89
matrix:
910
# allow_failures:

src/DistributedArrays.jl

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ module DistributedArrays
55
using Distributed
66
using Serialization
77
using LinearAlgebra
8+
using Statistics
89

910
import Base: +, -, *, div, mod, rem, &, |, xor
1011
import Base.Callable

src/darray.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -568,9 +568,9 @@ function Base.reshape(A::DArray{T,1,S}, d::Dims) where {T,S<:Array}
568568
i2 = CartesianIndices(sztail)[i]
569569
globalidx = [ I[j][i2[j-1]] for j=2:nd ]
570570

571-
a = sub2ind(d, d1offs, globalidx...)
571+
a = LinearIndices(d)[d1offs, globalidx...]
572572

573-
B[:,i] = A[a:(a+nr-1)]
573+
B[:,i] = Array(A[a:(a+nr-1)])
574574
end
575575
B
576576
end
@@ -706,15 +706,15 @@ end
706706
Base.size(P::ProductIndices) = P.sz
707707
# This gets passed to map to avoid breaking propagation of inbounds
708708
Base.@propagate_inbounds propagate_getindex(A, I...) = A[I...]
709-
Base.@propagate_inbounds Base.getindex(P::ProductIndices{_,N}, I::Vararg{Int, N}) where {_,N} =
709+
Base.@propagate_inbounds Base.getindex(P::ProductIndices{J,N}, I::Vararg{Int, N}) where {J,N} =
710710
Bool((&)(map(propagate_getindex, P.indices, I)...))
711711

712712
struct MergedIndices{I,N} <: AbstractArray{CartesianIndex{N}, N}
713713
indices::I
714714
sz::NTuple{N,Int}
715715
end
716716
Base.size(M::MergedIndices) = M.sz
717-
Base.@propagate_inbounds Base.getindex(M::MergedIndices{_,N}, I::Vararg{Int, N}) where {_,N} =
717+
Base.@propagate_inbounds Base.getindex(M::MergedIndices{J,N}, I::Vararg{Int, N}) where {J,N} =
718718
CartesianIndex(map(propagate_getindex, M.indices, I))
719719
# Additionally, we optimize bounds checking when using MergedIndices as an
720720
# array index since checking, e.g., A[1:500, 1:500] is *way* faster than

src/mapreduce.jl

+19-28
Original file line numberDiff line numberDiff line change
@@ -20,37 +20,25 @@ end
2020
Base.BroadcastStyle(::Type{<:DArray}) = Broadcast.ArrayStyle{DArray}()
2121
Base.BroadcastStyle(::Type{<:DArray}, ::Any) = Broadcast.ArrayStyle{DArray}()
2222

23-
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{DArray}}, ::Type{ElType}) where {ElType}
24-
DA = find_darray(bc)
25-
DArray(I -> Array{ElType}(undef, map(length,I)), DA)
26-
end
27-
28-
"`DA = find_darray(As)` returns the first DArray among the arguments."
29-
find_darray(bc::Base.Broadcast.Broadcasted) = find_darray(bc.args)
30-
find_darray(args::Tuple) = find_darray(find_darray(args[1]), Base.tail(args))
31-
find_darray(x) = x
32-
find_darray(a::DArray, rest) = a
33-
find_darray(::Any, rest) = find_darray(rest)
34-
35-
function Base.copyto!(dest::DArray, bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{DArray}})
36-
@sync for p in procs(dest)
37-
@async remotecall_fetch(p) do
38-
copyto!(localpart(dest), rewrite_local(bc))
39-
end
23+
function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{DArray}})
24+
T = Base.Broadcast.combine_eltypes(bc.f, bc.args)
25+
shape = Base.Broadcast.combine_axes(bc.args...)
26+
iter = Base.CartesianIndices(shape)
27+
D = DArray(map(length, shape)) do I
28+
A = map(bc.args) do a
29+
if isa(a, Union{Number,Ref})
30+
return a
31+
else
32+
return localtype(a)(
33+
a[ntuple(i -> i > ndims(a) ? 1 : (size(a, i) == 1 ? (1:1) : I[i]), length(shape))...]
34+
)
35+
end
36+
end
37+
broadcast(bc.f, A...)
4038
end
41-
dest
39+
return D
4240
end
4341

44-
"""
45-
Transform a Broadcasted{Broadcast.ArrayStyle{DArray}} object into an equivalent
46-
Broadcasted{Broadcast.DefaultArrayStyle} object for the localparts.
47-
"""
48-
rewrite_local(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{DArray}}) = Broadcast.broadcasted(bc.f, rewrite_local(bc.args)...)
49-
rewrite_local(args::Tuple) = map(rewrite_local, args)
50-
rewrite_local(a::DArray) = localpart(a)
51-
rewrite_local(x) = x
52-
53-
5442
function Base.reduce(f, d::DArray)
5543
results = asyncmap(procs(d)) do p
5644
remotecall_fetch(p) do
@@ -128,6 +116,7 @@ function Base.mapreducedim!(f, op, R::DArray, A::DArray)
128116
return mapreducedim_between!(identity, op, R, B, region)
129117
end
130118

119+
## Some special cases
131120
function Base._all(f, A::DArray, ::Colon)
132121
B = asyncmap(procs(A)) do p
133122
remotecall_fetch(p) do
@@ -171,6 +160,8 @@ function Base.extrema(d::DArray)
171160
return reduce((t,s) -> (min(t[1], s[1]), max(t[2], s[2])), r)
172161
end
173162

163+
Statistics._mean(A::DArray, region) = sum(A, dims = region) ./ prod((size(A, i) for i in region))
164+
174165
# Unary vector functions
175166
(-)(D::DArray) = map(-, D)
176167

test/darray.jl

+5-1
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ check_leaks()
315315
@testset "test statistical functions on DArrays" begin
316316
dims = (20,20,20)
317317
DA = drandn(dims)
318-
A = convert(Array, DA)
318+
A = Array(DA)
319319

320320
@testset "test $f for dimension $dms" for f in (mean, ), dms in (1, 2, 3, (1,2), (1,3), (2,3), (1,2,3))
321321
# std is pending implementation
@@ -835,6 +835,10 @@ check_leaks()
835835
c = a .- m
836836
d = convert(Array, a) .- convert(Array, m)
837837
@test c == d
838+
e = @DArray [ones(10) for i=1:4]
839+
f = 2 .* e
840+
@test Array(f) == 2 .* Array(e)
841+
@test Array(map(x -> sum(x) .+ 2, e)) == map(x -> sum(x) .+ 2, e)
838842
d_closeall()
839843
end
840844

test/runtests.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ end
2020
@everywhere using Random
2121
@everywhere using LinearAlgebra
2222

23-
@everywhere srand(1234 + myid())
23+
@everywhere Random.seed!(1234 + myid())
2424

2525
const MYID = myid()
2626
const OTHERIDS = filter(id-> id != MYID, procs())[rand(1:(nprocs()-1))]

0 commit comments

Comments
 (0)