Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 243bfb4

Browse files
authored
Quick fix for device placement. (#725)
1 parent 644f16b commit 243bfb4

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

TrainingLoop/Metrics.swift

+4-2
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,9 @@ public struct TopKAccuracyMeasurer: MetricsMeasurer {
137137
correctGuessCount += Int32(
138138
Tensor<Int32>(
139139
_Raw.inTopKV2(
140-
predictions: predictionsReshaped, targets: labelsReshaped, k: Tensor<Int32>(k))).sum()
140+
predictions: predictionsReshaped, targets: labelsReshaped,
141+
k: Tensor<Int32>(k, on: predictions.device))
142+
).sum()
141143
.scalar ?? 0)
142144
totalGuessCount += Int32(labels.shape.reduce(1, *))
143145
}
@@ -170,7 +172,7 @@ public struct MCCMeasurer: MetricsMeasurer {
170172
groundTruths = []
171173
}
172174

173-
/// Appends boolean values computed from `predictions` and `labels`
175+
/// Appends boolean values computed from `predictions` and `labels`
174176
/// to self.predictions and self.groundTruths.
175177
public mutating func accumulate<Output, Target>(
176178
loss: Tensor<Float>?, predictions: Output?, labels: Target?

0 commit comments

Comments
 (0)