|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one or more |
| 3 | + * contributor license agreements. See the NOTICE file distributed with |
| 4 | + * this work for additional information regarding copyright ownership. |
| 5 | + * The ASF licenses this file to You under the Apache License, Version 2.0 |
| 6 | + * (the "License"); you may not use this file except in compliance with |
| 7 | + * the License. You may obtain a copy of the License at |
| 8 | + * |
| 9 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | + * |
| 11 | + * Unless required by applicable law or agreed to in writing, software |
| 12 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | + * See the License for the specific language governing permissions and |
| 15 | + * limitations under the License. |
| 16 | + */ |
| 17 | +package com.github.cloudml.zen.examples.ml |
| 18 | + |
| 19 | +import breeze.linalg.{SparseVector => BSV} |
| 20 | +import com.github.cloudml.zen.ml.recommendation.FM |
| 21 | +import org.apache.log4j.{Level, Logger} |
| 22 | +import org.apache.spark.graphx.GraphXUtils |
| 23 | +import org.apache.spark.mllib.linalg.{SparseVector => SSV} |
| 24 | +import org.apache.spark.mllib.regression.LabeledPoint |
| 25 | +import org.apache.spark.storage.StorageLevel |
| 26 | +import org.apache.spark.{SparkConf, SparkContext} |
| 27 | +import scopt.OptionParser |
| 28 | + |
| 29 | +object MovieLensFM { |
| 30 | + |
| 31 | + case class Params( |
| 32 | + input: String = null, |
| 33 | + out: String = null, |
| 34 | + numIterations: Int = 40, |
| 35 | + stepSize: Double = 0.1, |
| 36 | + regular: String = "0.01,0.01,0.01", |
| 37 | + rank: Int = 20, |
| 38 | + useAdaGrad: Boolean = false, |
| 39 | + kryo: Boolean = false) extends AbstractParams[Params] |
| 40 | + |
| 41 | + def main(args: Array[String]) { |
| 42 | + val defaultParams = Params() |
| 43 | + val parser = new OptionParser[Params]("FM") { |
| 44 | + head("MovieLensFM: an example app for FM.") |
| 45 | + opt[Int]("numIterations") |
| 46 | + .text(s"number of iterations, default: ${defaultParams.numIterations}") |
| 47 | + .action((x, c) => c.copy(numIterations = x)) |
| 48 | + opt[Int]("rank") |
| 49 | + .text(s"dim of 2-way interactions, default: ${defaultParams.rank}") |
| 50 | + .action((x, c) => c.copy(rank = x)) |
| 51 | + opt[Unit]("kryo") |
| 52 | + .text("use Kryo serialization") |
| 53 | + .action((_, c) => c.copy(kryo = true)) |
| 54 | + opt[Double]("stepSize") |
| 55 | + .text(s"stepSize, default: ${defaultParams.stepSize}") |
| 56 | + .action((x, c) => c.copy(stepSize = x)) |
| 57 | + opt[String]("regular") |
| 58 | + .text( |
| 59 | + s""" |
| 60 | + |’r0,r1,r2’ for SGD and ALS: r0=bias regularization, |
| 61 | + |r1=1-way regularization, r2=2-way regularization, default: ${defaultParams.regular} (auto) |
| 62 | + """.stripMargin) |
| 63 | + .action((x, c) => c.copy(regular = x)) |
| 64 | + opt[Unit]("adagrad") |
| 65 | + .text("use AdaGrad") |
| 66 | + .action((_, c) => c.copy(useAdaGrad = true)) |
| 67 | + arg[String]("<input>") |
| 68 | + .required() |
| 69 | + .text("input paths") |
| 70 | + .action((x, c) => c.copy(input = x)) |
| 71 | + arg[String]("<out>") |
| 72 | + .required() |
| 73 | + .text("out paths (model)") |
| 74 | + .action((x, c) => c.copy(out = x)) |
| 75 | + note( |
| 76 | + """ |
| 77 | + |For example, the following command runs this app on a synthetic dataset: |
| 78 | + | |
| 79 | + | bin/spark-submit --class com.github.cloudml.zen.examples.ml.MovieLensFM \ |
| 80 | + | examples/target/scala-*/zen-examples-*.jar \ |
| 81 | + | --rank 10 --numIterations 50 --regular 0.01,0.01,0.01 --kryo \ |
| 82 | + | data/mllib/sample_movielens_data.txt |
| 83 | + | data/mllib/fm_model |
| 84 | + """.stripMargin) |
| 85 | + } |
| 86 | + |
| 87 | + parser.parse(args, defaultParams).map { params => |
| 88 | + run(params) |
| 89 | + } getOrElse { |
| 90 | + System.exit(1) |
| 91 | + } |
| 92 | + } |
| 93 | + |
| 94 | + def run(params: Params): Unit = { |
| 95 | + val Params( |
| 96 | + input, |
| 97 | + out, |
| 98 | + numIterations, |
| 99 | + stepSize, |
| 100 | + regular, |
| 101 | + rank, |
| 102 | + useAdaGrad, |
| 103 | + kryo) = params |
| 104 | + val conf = new SparkConf().setAppName(s"FM with $params") |
| 105 | + if (kryo) { |
| 106 | + GraphXUtils.registerKryoClasses(conf) |
| 107 | + // conf.set("spark.kryoserializer.buffer.mb", "8") |
| 108 | + } |
| 109 | + Logger.getRootLogger.setLevel(Level.WARN) |
| 110 | + val sc = new SparkContext(conf) |
| 111 | + val movieLens = sc.textFile(input).mapPartitions { iter => |
| 112 | + iter.filter(t => !t.startsWith("userId") && !t.isEmpty).map { line => |
| 113 | + val Array(userId, movieId, rating, timestamp) = line.split("::") |
| 114 | + (userId.toInt, (movieId.toInt, rating.toDouble)) |
| 115 | + } |
| 116 | + }.persist(StorageLevel.MEMORY_AND_DISK) |
| 117 | + val maxMovieId = movieLens.map(_._2._1).max + 1 |
| 118 | + val maxUserId = movieLens.map(_._1).max + 1 |
| 119 | + val numFeatures = maxUserId + 2 * maxMovieId |
| 120 | + val dataSet = movieLens.map { case (userId, (movieId, rating)) => |
| 121 | + val sv = BSV.zeros[Double](maxMovieId) |
| 122 | + sv(movieId) = rating |
| 123 | + (userId, sv) |
| 124 | + }.reduceByKey(_ :+= _).flatMap { case (userId, ratings) => |
| 125 | + val activeSize = ratings.activeSize |
| 126 | + ratings.activeIterator.map { case (movieId, rating) => |
| 127 | + val sv = BSV.zeros[Double](numFeatures) |
| 128 | + sv(userId) = 1.0 |
| 129 | + sv(movieId + maxUserId) = 1.0 |
| 130 | + ratings.activeKeysIterator.foreach { mId => |
| 131 | + sv(maxMovieId + maxUserId + mId) = 1.0 / math.sqrt(activeSize) |
| 132 | + } |
| 133 | + new LabeledPoint(rating, new SSV(sv.length, sv.index.slice(0, sv.used), sv.data.slice(0, sv.used))) |
| 134 | + } |
| 135 | + }.zipWithIndex().map(_.swap).persist(StorageLevel.MEMORY_AND_DISK) |
| 136 | + dataSet.count() |
| 137 | + movieLens.unpersist() |
| 138 | + |
| 139 | + val regs = regular.split(",").map(_.toDouble) |
| 140 | + val l2 = (regs(0), regs(1), regs(2)) |
| 141 | + val model = FM.trainRegression(dataSet, numIterations, stepSize, l2, rank, useAdaGrad, 1.0) |
| 142 | + model.save(sc, out) |
| 143 | + } |
| 144 | +} |
0 commit comments