Skip to content

Commit b934138

Browse files
authored
positionIds are not weights (#251)
1 parent 651ec4a commit b934138

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

Libraries/MLXVLM/Models/Paligemma.swift

+3-3
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ private enum Vision {
336336
@ModuleInfo(key: "position_embedding") var positionEmbedding: Embedding
337337

338338
let positions: Int
339-
let positionIds: MLXArray
339+
let _positionIds: MLXArray
340340

341341
public init(_ config: PaliGemmaConfiguration.VisionConfiguration) {
342342
self._patchEmbedding.wrappedValue = Conv2d(
@@ -348,13 +348,13 @@ private enum Vision {
348348
self._positionEmbedding.wrappedValue = Embedding(
349349
embeddingCount: positions, dimensions: config.hiddenSize
350350
)
351-
self.positionIds = MLXArray(0 ..< positions)[.newAxis, 0...]
351+
self._positionIds = MLXArray(0 ..< positions)[.newAxis, 0...]
352352
}
353353

354354
public func callAsFunction(_ x: MLXArray) -> MLXArray {
355355
var patchEmbeddings = self.patchEmbedding(x)
356356
patchEmbeddings = patchEmbeddings.flattened(start: 1, end: 2)
357-
let embeddings = patchEmbeddings + self.positionEmbedding(self.positionIds)
357+
let embeddings = patchEmbeddings + self.positionEmbedding(self._positionIds)
358358
return embeddings
359359
}
360360
}

0 commit comments

Comments
 (0)