@@ -34,20 +34,29 @@ extension Sequence where Iterator.Element: Hashable {
34
34
}
35
35
}
36
36
37
- public struct MovieLens {
37
+ public struct MovieLens < Entropy : RandomNumberGenerator > {
38
38
public let trainUsers : [ Float ]
39
39
public let testUsers : [ Float ]
40
40
public let testData : [ [ Float ] ]
41
41
public let items : [ Float ]
42
42
public let numUsers : Int
43
- public let numItems : Int
44
- public let trainMatrix : [ TensorPair < Int32 , Float > ]
43
+ public let numItems : Int
45
44
public let user2id : [ Float : Int ]
46
45
public let id2user : [ Int : Float ]
47
46
public let item2id : [ Float : Int ]
48
47
public let id2item : [ Int : Float ]
49
48
public let trainNegSampling : Tensor < Float >
50
49
50
+ public typealias Samples = [ TensorPair < Int32 , Float > ]
51
+ public typealias Batches = Slices < Sampling < Samples , ArraySlice < Int > > >
52
+ public typealias BatchedTensorPair = TensorPair < Int32 , Float >
53
+ public typealias Training = LazyMapSequence <
54
+ TrainingEpochs < Samples , Entropy > ,
55
+ LazyMapSequence < Batches , BatchedTensorPair >
56
+ >
57
+ public let trainMatrix : Samples
58
+ public let training : Training
59
+
51
60
static func downloadMovieLensDatasetIfNotPresent( ) -> URL {
52
61
let localURL = DatasetUtilities . defaultDirectory. appendingPathComponent (
53
62
" MovieLens " , isDirectory: true )
@@ -60,7 +69,9 @@ public struct MovieLens {
60
69
return dataFolder
61
70
}
62
71
63
- public init ( ) {
72
+ public init (
73
+ trainBatchSize: Int = 1024 ,
74
+ entropy: Entropy ) {
64
75
let trainFiles = try ! String (
65
76
contentsOf: MovieLens . downloadMovieLensDatasetIfNotPresent ( ) . appendingPathComponent (
66
77
" u1.base " ) , encoding: . utf8)
@@ -127,7 +138,28 @@ public struct MovieLens {
127
138
self . id2user = id2user
128
139
self . item2id = item2id
129
140
self . id2item = id2item
130
- self . trainMatrix = dataset
131
141
self . trainNegSampling = trainNegSampling
142
+
143
+ self . trainMatrix = dataset
144
+ self . training = TrainingEpochs (
145
+ samples: trainMatrix,
146
+ batchSize: trainBatchSize,
147
+ entropy: entropy
148
+ ) . lazy. map { ( batches: Batches ) -> LazyMapSequence < Batches , BatchedTensorPair > in
149
+ batches. lazy. map {
150
+ TensorPair < Int32 , Float > (
151
+ first: Tensor < Int32 > ( $0. map ( \. first) ) ,
152
+ second: Tensor < Float > ( $0. map ( \. second) )
153
+ )
154
+ }
155
+ }
156
+ }
157
+ }
158
+
159
+ extension MovieLens where Entropy == SystemRandomNumberGenerator {
160
+ public init ( trainBatchSize: Int = 1024 ) {
161
+ self . init (
162
+ trainBatchSize: trainBatchSize,
163
+ entropy: SystemRandomNumberGenerator ( ) )
132
164
}
133
165
}
0 commit comments