Skip to content

Commit e7fa778

Browse files
committed
[SPARK-30699][ML][PYSPARK] GMM blockify input vectors
### What changes were proposed in this pull request? 1. Add a new param `blockSize`; 2. if blockSize == 1, keep the original behavior (code path trainOnRows); 3. if blockSize > 1, standardize and stack the input vectors into blocks (like ALS/MLP) (code path trainOnBlocks). ### Why are the changes needed? Performance gain on the dense dataset HIGGS: 1. saves about 45% RAM; 2. 3X faster with OpenBLAS. ### Does this PR introduce any user-facing change? Yes, it adds a new expert param `blockSize`. ### How was this patch tested? Added test suites. Closes apache#27473 from zhengruifeng/blockify_gmm. Authored-by: zhengruifeng <[email protected]> Signed-off-by: zhengruifeng <[email protected]>
1 parent a89006a commit e7fa778

File tree

6 files changed

+325
-76
lines changed

6 files changed

+325
-76
lines changed

mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ private[spark] object BLAS extends Serializable {
271271
}
272272

273273
/**
274-
* Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's ?SPR.
274+
* Adds alpha * v * v.t to a matrix in-place. This is the same as BLAS's ?SPR.
275275
*
276276
* @param U the upper triangular part of the matrix packed in an array (column major)
277277
*/

mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala

+31-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class MultivariateGaussian @Since("2.0.0") (
5555
*/
5656
@transient private lazy val tuple = {
5757
val (rootSigmaInv, u) = calculateCovarianceConstants
58-
val rootSigmaInvMat = Matrices.fromBreeze(rootSigmaInv)
58+
val rootSigmaInvMat = Matrices.fromBreeze(rootSigmaInv).toDense
5959
val rootSigmaInvMulMu = rootSigmaInvMat.multiply(mean)
6060
(rootSigmaInvMat, u, rootSigmaInvMulMu)
6161
}
@@ -81,6 +81,36 @@ class MultivariateGaussian @Since("2.0.0") (
8181
u - 0.5 * BLAS.dot(v, v)
8282
}
8383

84+
/**
 * Evaluates the density for a block of points stacked in a matrix.
 *
 * Convenience overload: allocates a zero-filled scratch matrix with the same
 * dimensions as `X` and delegates to `pdf(X, mat)`.
 *
 * @param X matrix of input points (the accompanying tests build it with
 *          `Matrices.fromVectors`, one vector per point)
 * @return one density value per point
 */
private[ml] def pdf(X: Matrix): DenseVector = {
85+
// Scratch buffer for the two-argument overload; sized like X.
val mat = DenseMatrix.zeros(X.numRows, X.numCols)
86+
pdf(X, mat)
87+
}
88+
89+
/**
 * Evaluates the density for a block of points, using a caller-supplied
 * scratch matrix so the buffer can be reused across calls.
 *
 * For each row i the exponent is computed via the expansion
 *   |A*x_i - A*mu|^2 = |A*mu|^2 + |A*x_i|^2 - 2 * (A*x_i) dot (A*mu)
 * where A is `rootSigmaInvMat` and A*mu is `rootSigmaInvMulMu`.
 *
 * @param X   matrix of input points
 * @param mat scratch matrix that is overwritten by this call; must not be
 *            transposed (the strided access below relies on its `values` layout)
 * @return vector with `mat.numRows` entries, entry i being
 *         `exp(u - 0.5 * squaredSum_i)`
 */
private[ml] def pdf(X: Matrix, mat: DenseMatrix): DenseVector = {
90+
require(!mat.isTransposed)
91+
92+
// Fills `mat` with X * rootSigmaInvMat^T (alpha = 1.0, beta = 0.0), so the
// previous contents of `mat` are presumably discarded -- confirm against BLAS.gemm.
BLAS.gemm(1.0, X, rootSigmaInvMat.transpose, 0.0, mat)
93+
val m = mat.numRows
94+
val n = mat.numCols
95+
96+
// pdfVec(i) initially holds (row i of mat) dot rootSigmaInvMulMu; it is
// overwritten in place with the final density inside the loop.
val pdfVec = mat.multiply(rootSigmaInvMulMu)
97+
98+
val blas = BLAS.getBLAS(n)
99+
// Squared norm of rootSigmaInvMulMu; identical for every row, so computed once.
val squared1 = blas.ddot(n, rootSigmaInvMulMu.values, 1, rootSigmaInvMulMu.values, 1)
100+
101+
// Local copy of `u` read once before the loop.
val localU = u
102+
var i = 0
103+
while (i < m) {
104+
// Squared norm of row i: start at offset i and step by m through mat.values
// (n elements), i.e. walk one row of the non-transposed value array.
val squared2 = blas.ddot(n, mat.values, i, m, mat.values, i, m)
105+
val dot = pdfVec(i)
106+
// squared1 + squared2 - 2 * dot, written as two subtractions.
val squaredSum = squared1 + squared2 - dot - dot
107+
pdfVec.values(i) = math.exp(localU - 0.5 * squaredSum)
108+
i += 1
109+
}
110+
111+
pdfVec
112+
}
113+
84114
/**
85115
* Calculate distribution dependent components used for the density function:
86116
* pdf(x) = (2*pi)^(-k/2)^ * det(sigma)^(-1/2)^ * exp((-1/2) * (x-mu).t * inv(sigma) * (x-mu))

mllib-local/src/test/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussianSuite.scala

+10
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class MultivariateGaussianSuite extends SparkMLFunSuite {
2727
test("univariate") {
2828
val x1 = Vectors.dense(0.0)
2929
val x2 = Vectors.dense(1.5)
30+
val mat = Matrices.fromVectors(Seq(x1, x2))
3031

3132
val mu = Vectors.dense(0.0)
3233
val sigma1 = Matrices.dense(1, 1, Array(1.0))
@@ -35,18 +36,21 @@ class MultivariateGaussianSuite extends SparkMLFunSuite {
3536
assert(dist1.logpdf(x2) ~== -2.0439385332046727 absTol 1E-5)
3637
assert(dist1.pdf(x1) ~== 0.39894 absTol 1E-5)
3738
assert(dist1.pdf(x2) ~== 0.12952 absTol 1E-5)
39+
assert(dist1.pdf(mat) ~== Vectors.dense(0.39894, 0.12952) absTol 1E-5)
3840

3941
val sigma2 = Matrices.dense(1, 1, Array(4.0))
4042
val dist2 = new MultivariateGaussian(mu, sigma2)
4143
assert(dist2.logpdf(x1) ~== -1.612085713764618 absTol 1E-5)
4244
assert(dist2.logpdf(x2) ~== -1.893335713764618 absTol 1E-5)
4345
assert(dist2.pdf(x1) ~== 0.19947 absTol 1E-5)
4446
assert(dist2.pdf(x2) ~== 0.15057 absTol 1E-5)
47+
assert(dist2.pdf(mat) ~== Vectors.dense(0.19947, 0.15057) absTol 1E-5)
4548
}
4649

4750
test("multivariate") {
4851
val x1 = Vectors.dense(0.0, 0.0)
4952
val x2 = Vectors.dense(1.0, 1.0)
53+
val mat = Matrices.fromVectors(Seq(x1, x2))
5054

5155
val mu = Vectors.dense(0.0, 0.0)
5256
val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0))
@@ -55,28 +59,33 @@ class MultivariateGaussianSuite extends SparkMLFunSuite {
5559
assert(dist1.logpdf(x2) ~== -2.8378770664093453 absTol 1E-5)
5660
assert(dist1.pdf(x1) ~== 0.15915 absTol 1E-5)
5761
assert(dist1.pdf(x2) ~== 0.05855 absTol 1E-5)
62+
assert(dist1.pdf(mat) ~== Vectors.dense(0.15915, 0.05855) absTol 1E-5)
5863

5964
val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0))
6065
val dist2 = new MultivariateGaussian(mu, sigma2)
6166
assert(dist2.logpdf(x1) ~== -2.810832140937002 absTol 1E-5)
6267
assert(dist2.logpdf(x2) ~== -3.3822607123655732 absTol 1E-5)
6368
assert(dist2.pdf(x1) ~== 0.060155 absTol 1E-5)
6469
assert(dist2.pdf(x2) ~== 0.033971 absTol 1E-5)
70+
assert(dist2.pdf(mat) ~== Vectors.dense(0.060155, 0.033971) absTol 1E-5)
6571
}
6672

6773
// Checks pdf evaluation with a singular covariance matrix, both per vector
// and for the same points stacked into a matrix.
test("multivariate degenerate") {
6874
val x1 = Vectors.dense(0.0, 0.0)
6975
val x2 = Vectors.dense(1.0, 1.0)
76+
// Both points stacked into one matrix for the block-wise pdf overload.
val mat = Matrices.fromVectors(Seq(x1, x2))
7077

7178
val mu = Vectors.dense(0.0, 0.0)
7279
// All-ones 2x2 covariance: determinant is 0, hence "degenerate".
val sigma = Matrices.dense(2, 2, Array(1.0, 1.0, 1.0, 1.0))
7380
val dist = new MultivariateGaussian(mu, sigma)
7481
assert(dist.pdf(x1) ~== 0.11254 absTol 1E-5)
7582
assert(dist.pdf(x2) ~== 0.068259 absTol 1E-5)
83+
// The matrix overload must reproduce the per-vector results above.
assert(dist.pdf(mat) ~== Vectors.dense(0.11254, 0.068259) absTol 1E-5)
7684
}
7785

7886
test("SPARK-11302") {
7987
val x = Vectors.dense(629, 640, 1.7188, 618.19)
88+
val mat = Matrices.fromVectors(Seq(x))
8089
val mu = Vectors.dense(
8190
1055.3910505836575, 1070.489299610895, 1.39020554474708, 1040.5907503867697)
8291
val sigma = Matrices.dense(4, 4, Array(
@@ -87,5 +96,6 @@ class MultivariateGaussianSuite extends SparkMLFunSuite {
8796
val dist = new MultivariateGaussian(mu, sigma)
8897
// Agrees with R's dmvnorm: 7.154782e-05
8998
assert(dist.pdf(x) ~== 7.154782224045512E-5 absTol 1E-9)
99+
assert(dist.pdf(mat) ~== Vectors.dense(7.154782224045512E-5) absTol 1E-5)
90100
}
91101
}

0 commit comments

Comments
 (0)