Skip to content

Commit 250b583

Browse files
cocoa-xupolvalente
andauthored
chore: bump mlx to v0.25.1 (#77)
* chore: bump mlx to v0.23.0 Signed-off-by: Cocoa <[email protected]> * update to 0.25.1 and fix to_blob --------- Signed-off-by: Cocoa <[email protected]> Co-authored-by: Paulo Valente <[email protected]>
1 parent 0bf7e72 commit 250b583

File tree

2 files changed

+8
-18
lines changed

2 files changed

+8
-18
lines changed

c_src/emlx_nif.cpp

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -275,23 +275,15 @@ NIF(to_blob) {
275275
int limit = 0;
276276
bool has_received_limit = (argc == 2);
277277

278+
// Evaluate to ensure data is available
279+
t->eval();
280+
278281
if (has_received_limit) {
279282
PARAM(1, int, param_limit);
280283
limit = param_limit;
281284
byte_size = limit * t->itemsize();
282285
}
283286

284-
// Flatten and slice if needed
285-
mlx::core::array flattened = mlx::core::flatten(*t);
286-
mlx::core::array reshaped =
287-
(has_received_limit && byte_size < t->nbytes())
288-
? mlx::core::slice(flattened, std::vector<int>{0},
289-
std::vector<int>{limit})
290-
: flattened;
291-
292-
// Evaluate to ensure data is available
293-
mlx::core::eval(reshaped);
294-
295287
// Create result binary
296288
void *result_data = (void *)enif_make_new_binary(env, byte_size, &result);
297289

@@ -300,14 +292,12 @@ NIF(to_blob) {
300292
// https://github.com/ml-explore/mlx/discussions/1608#discussioncomment-11332071
301293
//
302294
// Set up contiguous iterator
303-
std::vector<int> slice_sizes(reshaped.shape().begin(),
304-
reshaped.shape().end());
305-
ContiguousIterator iterator(slice_sizes, reshaped.strides(),
306-
reshaped.ndim());
295+
std::vector<int> slice_sizes(t->shape().begin(), t->shape().end());
296+
ContiguousIterator iterator(slice_sizes, t->strides(), t->ndim());
307297

308298
// Copy data element by element using iterator
309-
size_t element_size = reshaped.itemsize();
310-
const char *src_data = static_cast<const char *>(reshaped.data<void>());
299+
size_t element_size = t->itemsize();
300+
const char *src_data = static_cast<const char *>(t->data<void>());
311301
char *dst_data = static_cast<char *>(result_data);
312302

313303
size_t num_elements = byte_size / element_size;

mix.exs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ defmodule EMLX.MixProject do
33

44
@app :emlx
55
@version "0.1.2"
6-
@mlx_version "0.22.1"
6+
@mlx_version "0.25.1"
77

88
require Logger
99

0 commit comments

Comments
 (0)