82
82
83
83
84
84
"""
85
- trainGigaSOM(som::Som, dInfo::LoadedDataInfo;
86
- kernelFun::Function = gaussianKernel,
87
- metric = Euclidean(),
88
- somDistFun = distMatrix(Chebyshev()),
89
- knnTreeFun = BruteTree,
90
- rStart = 0.0, rFinal=0.1, radiusFun=expRadius(-5.0),
91
- epochs = 20)
85
+ trainGigaSOM(
86
+ som::Som,
87
+ dInfo::LoadedDataInfo;
88
+ kernelFun::Function = gaussianKernel,
89
+ metric = Euclidean(),
90
+ somDistFun = distMatrix(Chebyshev()),
91
+ knnTreeFun = BruteTree,
92
+ rStart = 0.0,
93
+ rFinal = 0.1,
94
+ radiusFun = expRadius(-5.0),
95
+ epochs = 20,
96
+ eachEpoch = (e, r, som) -> nothing,
97
+ )
92
98
93
99
# Arguments:
94
100
- `som`: object of type Som with an initialised som
102
108
- `rFinal`: target radius at the last epoch, defaults to 0.1
103
109
- `radiusFun`: Function that generates radius decay, e.g. `linearRadius` or `expRadius(10.0)`
104
110
- `epochs`: number of SOM training iterations (default 20)
111
+ - `eachEpoch`: a function to call back after each epoch, accepting arguments
112
+ `(epochNumber, radius, som)`. For simplicity, this gets additionally called
113
+ once before the first epoch, with `epochNumber` set to zero.
105
114
"""
106
115
function trainGigaSOM (
107
116
som:: Som ,
@@ -114,6 +123,7 @@ function trainGigaSOM(
114
123
rFinal = 0.1 ,
115
124
radiusFun = expRadius (- 5.0 ),
116
125
epochs = 20 ,
126
+ eachEpoch = (e, r, som) -> nothing ,
117
127
)
118
128
119
129
# set the default radius
@@ -125,75 +135,51 @@ function trainGigaSOM(
125
135
# get the SOM neighborhood distances
126
136
dm = somDistFun (som. grid)
127
137
128
- codes = som. codes
138
+ result_som = copy (som)
139
+ result_som. codes = copy (som. codes) # prevent rewriting by reference
140
+
141
+ eachEpoch (0 , rStart, result_som)
129
142
130
- for j = 1 : epochs
131
- @debug " Epoch $j ..."
143
+ for epoch = 1 : epochs
144
+ @debug " Epoch $epoch ..."
132
145
133
146
numerator, denominator = distributedEpoch (
134
147
dInfo,
135
- codes,
136
- knnTreeFun (Array {Float64,2} (transpose (codes)), metric),
148
+ result_som . codes,
149
+ knnTreeFun (Array {Float64,2} (transpose (result_som . codes)), metric),
137
150
)
138
151
139
- r = radiusFun (rStart, rFinal, j , epochs)
152
+ r = radiusFun (rStart, rFinal, epoch , epochs)
140
153
@debug " radius: $r "
141
154
if r <= 0
142
155
@error " Sanity check failed: radius must be positive"
143
156
error (" Radius check" )
144
157
end
145
158
146
159
wEpoch = kernelFun (dm, r)
147
- codes = (wEpoch * numerator) ./ (wEpoch * denominator)
148
- end
160
+ result_som. codes = (wEpoch * numerator) ./ (wEpoch * denominator)
149
161
150
- som. codes = copy (codes)
162
+ eachEpoch (epoch, r, result_som)
163
+ end
151
164
152
- return som
165
+ return result_som
153
166
end
154
167
155
168
"""
156
169
trainGigaSOM(som::Som, train;
157
- kernelFun::Function = gaussianKernel,
158
- metric = Euclidean(),
159
- somDistFun = distMatrix(Chebyshev()),
160
- knnTreeFun = BruteTree,
161
- rStart = 0.0, rFinal=0.1, radiusFun=expRadius(-5.0),
162
- epochs = 20)
170
+ kwargs...)
163
171
164
172
Overload of `trainGigaSOM` for simple DataFrames and matrices. This slices the
165
- data using `DistributedArrays`, sends them the workers, and runs normal
166
- `trainGigaSOM`. Data is ` undistribute`d after the computation.
173
+ data, distributes them to the workers, and runs normal `trainGigaSOM`. Data is
174
+ `undistribute`d after the computation.
167
175
"""
168
- function trainGigaSOM (
169
- som:: Som ,
170
- train;
171
- kernelFun:: Function = gaussianKernel,
172
- metric = Euclidean (),
173
- somDistFun = distMatrix (Chebyshev ()),
174
- knnTreeFun = BruteTree,
175
- rStart = 0.0 ,
176
- rFinal = 0.1 ,
177
- radiusFun = expRadius (- 5.0 ),
178
- epochs = 20 ,
179
- )
176
+ function trainGigaSOM (som:: Som , train; kwargs... )
180
177
181
178
train = Matrix {Float64} (train)
182
179
183
180
# this slices the data into parts and sends them to workers
184
181
dInfo = distribute_array (:GigaSOMtrainDataVar , train, workers ())
185
- som_res = trainGigaSOM (
186
- som,
187
- dInfo,
188
- kernelFun = kernelFun,
189
- metric = metric,
190
- somDistFun = somDistFun,
191
- knnTreeFun = knnTreeFun,
192
- rStart = rStart,
193
- rFinal = rFinal,
194
- radiusFun = radiusFun,
195
- epochs = epochs,
196
- )
182
+ som_res = trainGigaSOM (som, dInfo; kwargs... )
197
183
undistribute (dInfo)
198
184
return som_res
199
185
end
0 commit comments