Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 6a9b642

Browse files
authored
Convert NeuMF-MovieLens to Epoches (#605)
1 parent 47a4eaa commit 6a9b642

File tree

2 files changed

+43
-13
lines changed

2 files changed

+43
-13
lines changed

Datasets/MovieLens/MovieLens.swift

+37-5
Original file line numberDiff line numberDiff line change
@@ -34,20 +34,29 @@ extension Sequence where Iterator.Element: Hashable {
3434
}
3535
}
3636

37-
public struct MovieLens {
37+
public struct MovieLens<Entropy: RandomNumberGenerator> {
3838
public let trainUsers: [Float]
3939
public let testUsers: [Float]
4040
public let testData: [[Float]]
4141
public let items: [Float]
4242
public let numUsers: Int
43-
public let numItems: Int
44-
public let trainMatrix: [TensorPair<Int32, Float>]
43+
public let numItems: Int
4544
public let user2id: [Float: Int]
4645
public let id2user: [Int: Float]
4746
public let item2id: [Float: Int]
4847
public let id2item: [Int: Float]
4948
public let trainNegSampling: Tensor<Float>
5049

50+
public typealias Samples = [TensorPair<Int32, Float>]
51+
public typealias Batches = Slices<Sampling<Samples, ArraySlice<Int>>>
52+
public typealias BatchedTensorPair = TensorPair<Int32, Float>
53+
public typealias Training = LazyMapSequence<
54+
TrainingEpochs<Samples, Entropy>,
55+
LazyMapSequence<Batches, BatchedTensorPair>
56+
>
57+
public let trainMatrix: Samples
58+
public let training: Training
59+
5160
static func downloadMovieLensDatasetIfNotPresent() -> URL {
5261
let localURL = DatasetUtilities.defaultDirectory.appendingPathComponent(
5362
"MovieLens", isDirectory: true)
@@ -60,7 +69,9 @@ public struct MovieLens {
6069
return dataFolder
6170
}
6271

63-
public init() {
72+
public init(
73+
trainBatchSize: Int = 1024,
74+
entropy: Entropy) {
6475
let trainFiles = try! String(
6576
contentsOf: MovieLens.downloadMovieLensDatasetIfNotPresent().appendingPathComponent(
6677
"u1.base"), encoding: .utf8)
@@ -127,7 +138,28 @@ public struct MovieLens {
127138
self.id2user = id2user
128139
self.item2id = item2id
129140
self.id2item = id2item
130-
self.trainMatrix = dataset
131141
self.trainNegSampling = trainNegSampling
142+
143+
self.trainMatrix = dataset
144+
self.training = TrainingEpochs(
145+
samples: trainMatrix,
146+
batchSize: trainBatchSize,
147+
entropy: entropy
148+
).lazy.map { (batches: Batches) -> LazyMapSequence<Batches, BatchedTensorPair> in
149+
batches.lazy.map {
150+
TensorPair<Int32, Float> (
151+
first: Tensor<Int32>($0.map(\.first)),
152+
second: Tensor<Float>($0.map(\.second))
153+
)
154+
}
155+
}
156+
}
157+
}
158+
159+
extension MovieLens where Entropy == SystemRandomNumberGenerator {
160+
public init(trainBatchSize: Int = 1024) {
161+
self.init(
162+
trainBatchSize: trainBatchSize,
163+
entropy: SystemRandomNumberGenerator())
132164
}
133165
}

Examples/NeuMF-MovieLens/main.swift

+6-8
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,15 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
import Batcher
1615
import Datasets
1716
import Foundation
1817
import RecommendationModels
1918
import TensorFlow
2019

21-
let dataset = MovieLens()
20+
let dataset = MovieLens(trainBatchSize: 1024)
2221
let numUsers = dataset.numUsers
2322
let numItems = dataset.numItems
2423

25-
let batcher = Batcher(on: dataset.trainMatrix, batchSize: 1024, shuffle: true)
26-
2724
let size: [Int] = [16, 32, 16, 8]
2825
let regs: [Float] = [0.0, 0.0, 0.0, 0.0]
2926

@@ -48,12 +45,13 @@ for element in dataset.testData {
4845
print("Dataset acquired.")
4946

5047
print("Starting training...")
51-
for epoch in 1...20 {
48+
let epochCount = 20
49+
for (epoch, epochBatches) in dataset.training.prefix(epochCount).enumerated() {
5250
var avgLoss: Float = 0.0
5351
Context.local.learningPhase = .training
54-
for data in batcher.sequenced() {
55-
let userId = data.first
56-
let rating = data.second
52+
for batch in epochBatches {
53+
let userId = batch.first
54+
let rating = batch.second
5755
let (loss, grad) = valueWithGradient(at: model) { model -> Tensor<Float> in
5856
let logits = model(userId)
5957
return sigmoidCrossEntropy(logits: logits, labels: rating)

0 commit comments

Comments
 (0)