From 2eff6ab648647ee453041c355feff1062d8f06f8 Mon Sep 17 00:00:00 2001 From: Sarah Gilmore Date: Fri, 10 May 2024 14:40:48 -0400 Subject: [PATCH] Add functionality for converting mlarrow record batches to pyarrow record batches --- matlab/src/matlab/+arrow/+c/mlarrow2pyarrow.m | 17 +++++++++++++---- matlab/src/matlab/+arrow/+c/pyarrow2mlarrow.m | 8 ++++++-- matlab/src/matlab/+arrow/+tabular/RecordBatch.m | 11 +++++++++-- 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/matlab/src/matlab/+arrow/+c/mlarrow2pyarrow.m b/matlab/src/matlab/+arrow/+c/mlarrow2pyarrow.m index 96e36fd8ffd94..2cd3126a1c522 100644 --- a/matlab/src/matlab/+arrow/+c/mlarrow2pyarrow.m +++ b/matlab/src/matlab/+arrow/+c/mlarrow2pyarrow.m @@ -1,13 +1,22 @@ function pyarrowArray = mlarrow2pyarrow(mlarrowArray) - [cArrayWrapper, cSchemaWrapper] = pyrunfile(fullfile(pwd, "+internal/FFIWrapper.py"), ["cArrayWrapper" "cSchemaWrapper"]); + folder = fileparts(mfilename("fullpath")); + + [cArrayWrapper, cSchemaWrapper] = pyrunfile(fullfile(folder, "+internal/FFIWrapper.py"), ["cArrayWrapper" "cSchemaWrapper"]); cArrayAddress = uint64(cArrayWrapper.getAddress()); cSchemaAdress = uint64(cSchemaWrapper.getAddress()); mlarrowArray.exportToC(cArrayAddress, cSchemaAdress); - dummyArray = py.pyarrow.array([1 2]); - importFunc = py.getattr(dummyArray, "_import_from_c"); - pyarrowArray = importFunc(cArrayAddress, cSchemaAdress); + if isa(mlarrowArray, "arrow.array.Array") + dummyArray = py.pyarrow.array([1 2]); + importFunc = py.getattr(dummyArray, "_import_from_c"); + pyarrowArray = importFunc(cArrayAddress, cSchemaAdress); + else + dummyArray = py.pyarrow.array([1 2]); + dummyRB = py.pyarrow.record_batch(py.list({dummyArray}), names={'Var1'}); + importFunc = py.getattr(dummyRB, "_import_from_c"); + pyarrowArray = importFunc(cArrayAddress, cSchemaAdress); + end end diff --git a/matlab/src/matlab/+arrow/+c/pyarrow2mlarrow.m b/matlab/src/matlab/+arrow/+c/pyarrow2mlarrow.m index 3aac81cd54cb6..b5d7bb758b284 100644 --- a/matlab/src/matlab/+arrow/+c/pyarrow2mlarrow.m +++ b/matlab/src/matlab/+arrow/+c/pyarrow2mlarrow.m @@ -5,6 +5,10 @@ exportFunc = py.getattr(pyarrowArray, "_export_to_c"); exportFunc(cArray.Address, cSchema.Address); - mlarrowArray = arrow.array.Array.importFromC(cArray, cSchema); - + if isa(pyarrowArray, "py.pyarrow.lib.Array") + mlarrowArray = arrow.array.Array.importFromC(cArray, cSchema); + else + importer = arrow.c.internal.RecordBatchImporter(); + mlarrowArray = importer.import(cArray, cSchema); + end end \ No newline at end of file diff --git a/matlab/src/matlab/+arrow/+tabular/RecordBatch.m b/matlab/src/matlab/+arrow/+tabular/RecordBatch.m index eb4e089221216..1c90eb6c58d74 100644 --- a/matlab/src/matlab/+arrow/+tabular/RecordBatch.m +++ b/matlab/src/matlab/+arrow/+tabular/RecordBatch.m @@ -144,8 +144,15 @@ function displayScalarObject(obj) end methods(Hidden) - function exportToC(cArrayAddress, cSchemaAddress) - obj.Proxy.exportToC(cArrayAddress, cSchemaAddress); + function exportToC(obj, cArrayAddress, cSchemaAddress) + arguments + obj(1, 1) arrow.tabular.RecordBatch + cArrayAddress(1, 1) uint64 + cSchemaAddress(1, 1) uint64 + end + args = struct(ArrowArrayAddress=cArrayAddress,... + ArrowSchemaAddress=cSchemaAddress); + obj.Proxy.exportToC(args); end end end