Skip to content

Commit 450e861

Browse files
committed
Adjust chunksize in gwas_linear_regression to reduce data transfer between workers.
See discussion in https://github.com/pystatgen/sgkit/issues/390#issuecomment-768332568.
1 parent 41827f3 commit 450e861

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

sgkit/stats/association.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,9 @@ def linear_regression(
7171
# what are effectively OLS residuals rather than matrix inverse
7272
# to avoid need for MxM array; additionally, dask.lstsq fails
7373
# with numpy arrays
74-
XLP = XL - XC @ da.linalg.lstsq(XC, XL)[0]
74+
LS = XC @ da.linalg.lstsq(XC, XL)[0]
75+
assert XL.chunksize == LS.chunksize
76+
XLP = XL - LS
7577
assert XLP.shape == (n_obs, n_loop_covar)
7678
YP = Y - XC @ da.linalg.lstsq(XC, Y)[0]
7779
assert YP.shape == (n_obs, n_outcome)
@@ -213,8 +215,10 @@ def gwas_linear_regression(
213215
# Note: dask qr decomp (used by lstsq) requires no chunking in one
214216
# dimension, and because dim 0 will be far greater than the number
215217
# of covariates for the large majority of use cases, chunking
216-
# should be removed from dim 1
217-
X = X.rechunk((None, -1))
218+
# should be removed from dim 1. Also, dim 0 should have the same chunking
219+
# as G dim 1, so that when XLP is computed in linear_regression() the
220+
# two arrays (XL and LS) have the same chunking.
221+
X = X.rechunk((G.chunksize[1], -1))
218222

219223
Y = da.asarray(concat_2d(ds[list(traits)], dims=("samples", "traits")))
220224
# Like covariates, traits must also be tall-skinny arrays

0 commit comments

Comments
 (0)