Skip to content

Commit db2709f

Browse files
committed
Fix mean with dimension argument
1 parent ab4904c commit db2709f

File tree

3 files changed

+5
-1
lines changed

3 files changed

+5
-1
lines changed

src/DistributedArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ module DistributedArrays
55
using Distributed
66
using Serialization
77
using LinearAlgebra
8+
using Statistics
89

910
import Base: +, -, *, div, mod, rem, &, |, xor
1011
import Base.Callable

src/mapreduce.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ function Base.mapreducedim!(f, op, R::DArray, A::DArray)
116116
return mapreducedim_between!(identity, op, R, B, region)
117117
end
118118

119+
## Some special cases
119120
function Base._all(f, A::DArray, ::Colon)
120121
B = asyncmap(procs(A)) do p
121122
remotecall_fetch(p) do
@@ -159,6 +160,8 @@ function Base.extrema(d::DArray)
159160
return reduce((t,s) -> (min(t[1], s[1]), max(t[2], s[2])), r)
160161
end
161162

163+
Statistics._mean(A::DArray, region) = sum(A, dims = region) ./ prod((size(A, i) for i in region))
164+
162165
# Unary vector functions
163166
(-)(D::DArray) = map(-, D)
164167

test/darray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ check_leaks()
315315
@testset "test statistical functions on DArrays" begin
316316
dims = (20,20,20)
317317
DA = drandn(dims)
318-
A = convert(Array, DA)
318+
A = Array(DA)
319319

320320
@testset "test $f for dimension $dms" for f in (mean, ), dms in (1, 2, 3, (1,2), (1,3), (2,3), (1,2,3))
321321
# std is pending implementation

0 commit comments

Comments
 (0)