From fe2d926ef385be58833f0e5e09d1860c63f800e6 Mon Sep 17 00:00:00 2001 From: Sarah Gilmore <74676073+sgilmore10@users.noreply.github.com> Date: Tue, 28 May 2024 09:37:54 -0400 Subject: [PATCH] GH-41803: [MATLAB] Add C Data Interface format import/export functionality for `arrow.tabular.RecordBatch` (#41817) ### Rationale for this change This pull request adds two new APIs for importing and exporting `arrow.tabular.RecordBatch` instances using the C Data Interface format. **Example:** ```matlab >> T = table((1:3)', ["A"; "B"; "C"]); >> expected = arrow.recordBatch(T) expected = Arrow RecordBatch with 3 rows and 2 columns: Schema: Var1: Float64 | Var2: String First Row: 1 | "A" >> cArray = arrow.c.Array(); >> cSchema = arrow.c.Schema(); % Export the RecordBatch to C Data Interface Format >> expected.export(cArray.Address, cSchema.Address); % Import the RecordBatch from C Data Interface Format >> actual = arrow.tabular.RecordBatch.import(cArray, cSchema) actual = Arrow RecordBatch with 3 rows and 2 columns: Schema: Var1: Float64 | Var2: String First Row: 1 | "A" % The RecordBatch is the same after round-tripping to the C Data Interface format >> isequal(actual, expected) ans = logical 1 ``` ### What changes are included in this PR? 1. Added a new method `arrow.tabular.RecordBatch.export` for exporting `RecordBatch` objects to the C Data Interface format. 2. Added a new static method `arrow.tabular.RecordBatch.import` for importing `RecordBatch` objects from the C Data Interface format. 3. Added a new internal class `arrow.c.internal.RecordBatchImporter` for importing `RecordBatch` objects from the C Data Interface format. ### Are these changes tested? Yes. 1. Added a new test file `matlab/test/arrow/c/tRoundTripRecordBatch.m` which has basic round-trip tests for importing and exporting `RecordBatch` objects. ### Are there any user-facing changes? Yes. 1. Two new user-facing methods were added to `arrow.tabular.RecordBatch`. 
The first is `arrow.tabular.RecordBatch.export(cArrowArrayAddress, cArrowSchemaAddress)`. The second is `arrow.tabular.RecordBatch.import(cArray, cSchema)`. These APIs can be used to export/import `RecordBatch` objects using the C Data Interface format. ### Future Directions 1. Add integration tests for sharing data between MATLAB/mlarrow and Python/pyarrow running in the same process using the [MATLAB interface to Python](https://www.mathworks.com/help/matlab/call-python-libraries.html). 2. Add support for the Arrow [C stream interface format](https://arrow.apache.org/docs/format/CStreamInterface.html). ### Notes 1. Thanks to @ kevingurney for the help with this feature! * GitHub Issue: #41803 Authored-by: Sarah Gilmore Signed-off-by: Sarah Gilmore --- .../matlab/c/proxy/record_batch_importer.cc | 66 +++++++ .../matlab/c/proxy/record_batch_importer.h | 37 ++++ matlab/src/cpp/arrow/matlab/proxy/factory.cc | 104 +++++------ .../matlab/tabular/proxy/record_batch.cc | 19 +- .../arrow/matlab/tabular/proxy/record_batch.h | 1 + .../+arrow/+c/+internal/RecordBatchImporter.m | 52 ++++++ .../src/matlab/+arrow/+tabular/RecordBatch.m | 22 +++ matlab/test/arrow/c/tRoundTripRecordBatch.m | 170 ++++++++++++++++++ .../cmake/BuildMatlabArrowInterface.cmake | 3 +- 9 files changed, 420 insertions(+), 54 deletions(-) create mode 100644 matlab/src/cpp/arrow/matlab/c/proxy/record_batch_importer.cc create mode 100644 matlab/src/cpp/arrow/matlab/c/proxy/record_batch_importer.h create mode 100644 matlab/src/matlab/+arrow/+c/+internal/RecordBatchImporter.m create mode 100644 matlab/test/arrow/c/tRoundTripRecordBatch.m diff --git a/matlab/src/cpp/arrow/matlab/c/proxy/record_batch_importer.cc b/matlab/src/cpp/arrow/matlab/c/proxy/record_batch_importer.cc new file mode 100644 index 0000000000000..ed9ba14cfbe01 --- /dev/null +++ b/matlab/src/cpp/arrow/matlab/c/proxy/record_batch_importer.cc @@ -0,0 +1,66 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor 
license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/c/bridge.h" + +#include "arrow/matlab/c/proxy/record_batch_importer.h" +#include "arrow/matlab/error/error.h" +#include "arrow/matlab/tabular/proxy/record_batch.h" + +#include "libmexclass/proxy/ProxyManager.h" + +namespace arrow::matlab::c::proxy { + +RecordBatchImporter::RecordBatchImporter() { + REGISTER_METHOD(RecordBatchImporter, import); +} + +libmexclass::proxy::MakeResult RecordBatchImporter::make( + const libmexclass::proxy::FunctionArguments& constructor_arguments) { + return std::make_shared(); +} + +void RecordBatchImporter::import(libmexclass::proxy::method::Context& context) { + namespace mda = ::matlab::data; + using namespace libmexclass::proxy; + using RecordBatchProxy = arrow::matlab::tabular::proxy::RecordBatch; + + mda::StructArray args = context.inputs[0]; + const mda::TypedArray arrow_array_address_mda = args[0]["ArrowArrayAddress"]; + const mda::TypedArray arrow_schema_address_mda = + args[0]["ArrowSchemaAddress"]; + + const auto arrow_array_address = uint64_t(arrow_array_address_mda[0]); + const auto arrow_schema_address = uint64_t(arrow_schema_address_mda[0]); + + auto arrow_array = reinterpret_cast(arrow_array_address); + auto arrow_schema = reinterpret_cast(arrow_schema_address); + + 
MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(auto record_batch, + arrow::ImportRecordBatch(arrow_array, arrow_schema), + context, error::C_IMPORT_FAILED); + + auto record_batch_proxy = std::make_shared(std::move(record_batch)); + + mda::ArrayFactory factory; + const auto record_batch_proxy_id = ProxyManager::manageProxy(record_batch_proxy); + const auto record_batch_proxy_id_mda = factory.createScalar(record_batch_proxy_id); + + context.outputs[0] = record_batch_proxy_id_mda; +} + +} // namespace arrow::matlab::c::proxy diff --git a/matlab/src/cpp/arrow/matlab/c/proxy/record_batch_importer.h b/matlab/src/cpp/arrow/matlab/c/proxy/record_batch_importer.h new file mode 100644 index 0000000000000..0f697db0d25b0 --- /dev/null +++ b/matlab/src/cpp/arrow/matlab/c/proxy/record_batch_importer.h @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include "libmexclass/proxy/Proxy.h" + +namespace arrow::matlab::c::proxy { + +class RecordBatchImporter : public libmexclass::proxy::Proxy { + public: + RecordBatchImporter(); + + ~RecordBatchImporter() = default; + + static libmexclass::proxy::MakeResult make( + const libmexclass::proxy::FunctionArguments& constructor_arguments); + + protected: + void import(libmexclass::proxy::method::Context& context); +}; + +} // namespace arrow::matlab::c::proxy diff --git a/matlab/src/cpp/arrow/matlab/proxy/factory.cc b/matlab/src/cpp/arrow/matlab/proxy/factory.cc index 9b95fcf128090..53a19da82e334 100644 --- a/matlab/src/cpp/arrow/matlab/proxy/factory.cc +++ b/matlab/src/cpp/arrow/matlab/proxy/factory.cc @@ -27,6 +27,7 @@ #include "arrow/matlab/buffer/proxy/buffer.h" #include "arrow/matlab/c/proxy/array.h" #include "arrow/matlab/c/proxy/array_importer.h" +#include "arrow/matlab/c/proxy/record_batch_importer.h" #include "arrow/matlab/c/proxy/schema.h" #include "arrow/matlab/error/error.h" #include "arrow/matlab/io/csv/proxy/table_reader.h" @@ -54,57 +55,58 @@ namespace arrow::matlab::proxy { libmexclass::proxy::MakeResult Factory::make_proxy( const ClassName& class_name, const FunctionArguments& constructor_arguments) { // clang-format off - REGISTER_PROXY(arrow.array.proxy.Float32Array , arrow::matlab::array::proxy::NumericArray); - REGISTER_PROXY(arrow.array.proxy.Float64Array , arrow::matlab::array::proxy::NumericArray); - REGISTER_PROXY(arrow.array.proxy.UInt8Array , arrow::matlab::array::proxy::NumericArray); - REGISTER_PROXY(arrow.array.proxy.UInt16Array , arrow::matlab::array::proxy::NumericArray); - REGISTER_PROXY(arrow.array.proxy.UInt32Array , arrow::matlab::array::proxy::NumericArray); - REGISTER_PROXY(arrow.array.proxy.UInt64Array , arrow::matlab::array::proxy::NumericArray); - REGISTER_PROXY(arrow.array.proxy.Int8Array , arrow::matlab::array::proxy::NumericArray); - REGISTER_PROXY(arrow.array.proxy.Int16Array , 
arrow::matlab::array::proxy::NumericArray); - REGISTER_PROXY(arrow.array.proxy.Int32Array , arrow::matlab::array::proxy::NumericArray); - REGISTER_PROXY(arrow.array.proxy.Int64Array , arrow::matlab::array::proxy::NumericArray); - REGISTER_PROXY(arrow.array.proxy.BooleanArray , arrow::matlab::array::proxy::BooleanArray); - REGISTER_PROXY(arrow.array.proxy.StringArray , arrow::matlab::array::proxy::StringArray); - REGISTER_PROXY(arrow.array.proxy.StructArray , arrow::matlab::array::proxy::StructArray); - REGISTER_PROXY(arrow.array.proxy.ListArray , arrow::matlab::array::proxy::ListArray); - REGISTER_PROXY(arrow.array.proxy.TimestampArray, arrow::matlab::array::proxy::NumericArray); - REGISTER_PROXY(arrow.array.proxy.Time32Array , arrow::matlab::array::proxy::NumericArray); - REGISTER_PROXY(arrow.array.proxy.Time64Array , arrow::matlab::array::proxy::NumericArray); - REGISTER_PROXY(arrow.array.proxy.Date32Array , arrow::matlab::array::proxy::NumericArray); - REGISTER_PROXY(arrow.array.proxy.Date64Array , arrow::matlab::array::proxy::NumericArray); - REGISTER_PROXY(arrow.array.proxy.ChunkedArray , arrow::matlab::array::proxy::ChunkedArray); - REGISTER_PROXY(arrow.buffer.proxy.Buffer , arrow::matlab::buffer::proxy::Buffer); - REGISTER_PROXY(arrow.tabular.proxy.RecordBatch , arrow::matlab::tabular::proxy::RecordBatch); - REGISTER_PROXY(arrow.tabular.proxy.Table , arrow::matlab::tabular::proxy::Table); - REGISTER_PROXY(arrow.tabular.proxy.Schema , arrow::matlab::tabular::proxy::Schema); - REGISTER_PROXY(arrow.type.proxy.Field , arrow::matlab::type::proxy::Field); - REGISTER_PROXY(arrow.type.proxy.Float32Type , arrow::matlab::type::proxy::PrimitiveCType); - REGISTER_PROXY(arrow.type.proxy.Float64Type , arrow::matlab::type::proxy::PrimitiveCType); - REGISTER_PROXY(arrow.type.proxy.UInt8Type , arrow::matlab::type::proxy::PrimitiveCType); - REGISTER_PROXY(arrow.type.proxy.UInt16Type , arrow::matlab::type::proxy::PrimitiveCType); - REGISTER_PROXY(arrow.type.proxy.UInt32Type , 
arrow::matlab::type::proxy::PrimitiveCType); - REGISTER_PROXY(arrow.type.proxy.UInt64Type , arrow::matlab::type::proxy::PrimitiveCType); - REGISTER_PROXY(arrow.type.proxy.Int8Type , arrow::matlab::type::proxy::PrimitiveCType); - REGISTER_PROXY(arrow.type.proxy.Int16Type , arrow::matlab::type::proxy::PrimitiveCType); - REGISTER_PROXY(arrow.type.proxy.Int32Type , arrow::matlab::type::proxy::PrimitiveCType); - REGISTER_PROXY(arrow.type.proxy.Int64Type , arrow::matlab::type::proxy::PrimitiveCType); - REGISTER_PROXY(arrow.type.proxy.BooleanType , arrow::matlab::type::proxy::PrimitiveCType); - REGISTER_PROXY(arrow.type.proxy.StringType , arrow::matlab::type::proxy::StringType); - REGISTER_PROXY(arrow.type.proxy.TimestampType , arrow::matlab::type::proxy::TimestampType); - REGISTER_PROXY(arrow.type.proxy.Time32Type , arrow::matlab::type::proxy::Time32Type); - REGISTER_PROXY(arrow.type.proxy.Time64Type , arrow::matlab::type::proxy::Time64Type); - REGISTER_PROXY(arrow.type.proxy.Date32Type , arrow::matlab::type::proxy::Date32Type); - REGISTER_PROXY(arrow.type.proxy.Date64Type , arrow::matlab::type::proxy::Date64Type); - REGISTER_PROXY(arrow.type.proxy.StructType , arrow::matlab::type::proxy::StructType); - REGISTER_PROXY(arrow.type.proxy.ListType , arrow::matlab::type::proxy::ListType); - REGISTER_PROXY(arrow.io.feather.proxy.Writer , arrow::matlab::io::feather::proxy::Writer); - REGISTER_PROXY(arrow.io.feather.proxy.Reader , arrow::matlab::io::feather::proxy::Reader); - REGISTER_PROXY(arrow.io.csv.proxy.TableWriter , arrow::matlab::io::csv::proxy::TableWriter); - REGISTER_PROXY(arrow.io.csv.proxy.TableReader , arrow::matlab::io::csv::proxy::TableReader); - REGISTER_PROXY(arrow.c.proxy.Array , arrow::matlab::c::proxy::Array); - REGISTER_PROXY(arrow.c.proxy.ArrayImporter , arrow::matlab::c::proxy::ArrayImporter); - REGISTER_PROXY(arrow.c.proxy.Schema , arrow::matlab::c::proxy::Schema); + REGISTER_PROXY(arrow.array.proxy.Float32Array , 
arrow::matlab::array::proxy::NumericArray); + REGISTER_PROXY(arrow.array.proxy.Float64Array , arrow::matlab::array::proxy::NumericArray); + REGISTER_PROXY(arrow.array.proxy.UInt8Array , arrow::matlab::array::proxy::NumericArray); + REGISTER_PROXY(arrow.array.proxy.UInt16Array , arrow::matlab::array::proxy::NumericArray); + REGISTER_PROXY(arrow.array.proxy.UInt32Array , arrow::matlab::array::proxy::NumericArray); + REGISTER_PROXY(arrow.array.proxy.UInt64Array , arrow::matlab::array::proxy::NumericArray); + REGISTER_PROXY(arrow.array.proxy.Int8Array , arrow::matlab::array::proxy::NumericArray); + REGISTER_PROXY(arrow.array.proxy.Int16Array , arrow::matlab::array::proxy::NumericArray); + REGISTER_PROXY(arrow.array.proxy.Int32Array , arrow::matlab::array::proxy::NumericArray); + REGISTER_PROXY(arrow.array.proxy.Int64Array , arrow::matlab::array::proxy::NumericArray); + REGISTER_PROXY(arrow.array.proxy.BooleanArray , arrow::matlab::array::proxy::BooleanArray); + REGISTER_PROXY(arrow.array.proxy.StringArray , arrow::matlab::array::proxy::StringArray); + REGISTER_PROXY(arrow.array.proxy.StructArray , arrow::matlab::array::proxy::StructArray); + REGISTER_PROXY(arrow.array.proxy.ListArray , arrow::matlab::array::proxy::ListArray); + REGISTER_PROXY(arrow.array.proxy.TimestampArray , arrow::matlab::array::proxy::NumericArray); + REGISTER_PROXY(arrow.array.proxy.Time32Array , arrow::matlab::array::proxy::NumericArray); + REGISTER_PROXY(arrow.array.proxy.Time64Array , arrow::matlab::array::proxy::NumericArray); + REGISTER_PROXY(arrow.array.proxy.Date32Array , arrow::matlab::array::proxy::NumericArray); + REGISTER_PROXY(arrow.array.proxy.Date64Array , arrow::matlab::array::proxy::NumericArray); + REGISTER_PROXY(arrow.array.proxy.ChunkedArray , arrow::matlab::array::proxy::ChunkedArray); + REGISTER_PROXY(arrow.buffer.proxy.Buffer , arrow::matlab::buffer::proxy::Buffer); + REGISTER_PROXY(arrow.tabular.proxy.RecordBatch , arrow::matlab::tabular::proxy::RecordBatch); + 
REGISTER_PROXY(arrow.tabular.proxy.Table , arrow::matlab::tabular::proxy::Table); + REGISTER_PROXY(arrow.tabular.proxy.Schema , arrow::matlab::tabular::proxy::Schema); + REGISTER_PROXY(arrow.type.proxy.Field , arrow::matlab::type::proxy::Field); + REGISTER_PROXY(arrow.type.proxy.Float32Type , arrow::matlab::type::proxy::PrimitiveCType); + REGISTER_PROXY(arrow.type.proxy.Float64Type , arrow::matlab::type::proxy::PrimitiveCType); + REGISTER_PROXY(arrow.type.proxy.UInt8Type , arrow::matlab::type::proxy::PrimitiveCType); + REGISTER_PROXY(arrow.type.proxy.UInt16Type , arrow::matlab::type::proxy::PrimitiveCType); + REGISTER_PROXY(arrow.type.proxy.UInt32Type , arrow::matlab::type::proxy::PrimitiveCType); + REGISTER_PROXY(arrow.type.proxy.UInt64Type , arrow::matlab::type::proxy::PrimitiveCType); + REGISTER_PROXY(arrow.type.proxy.Int8Type , arrow::matlab::type::proxy::PrimitiveCType); + REGISTER_PROXY(arrow.type.proxy.Int16Type , arrow::matlab::type::proxy::PrimitiveCType); + REGISTER_PROXY(arrow.type.proxy.Int32Type , arrow::matlab::type::proxy::PrimitiveCType); + REGISTER_PROXY(arrow.type.proxy.Int64Type , arrow::matlab::type::proxy::PrimitiveCType); + REGISTER_PROXY(arrow.type.proxy.BooleanType , arrow::matlab::type::proxy::PrimitiveCType); + REGISTER_PROXY(arrow.type.proxy.StringType , arrow::matlab::type::proxy::StringType); + REGISTER_PROXY(arrow.type.proxy.TimestampType , arrow::matlab::type::proxy::TimestampType); + REGISTER_PROXY(arrow.type.proxy.Time32Type , arrow::matlab::type::proxy::Time32Type); + REGISTER_PROXY(arrow.type.proxy.Time64Type , arrow::matlab::type::proxy::Time64Type); + REGISTER_PROXY(arrow.type.proxy.Date32Type , arrow::matlab::type::proxy::Date32Type); + REGISTER_PROXY(arrow.type.proxy.Date64Type , arrow::matlab::type::proxy::Date64Type); + REGISTER_PROXY(arrow.type.proxy.StructType , arrow::matlab::type::proxy::StructType); + REGISTER_PROXY(arrow.type.proxy.ListType , arrow::matlab::type::proxy::ListType); + 
REGISTER_PROXY(arrow.io.feather.proxy.Writer , arrow::matlab::io::feather::proxy::Writer); + REGISTER_PROXY(arrow.io.feather.proxy.Reader , arrow::matlab::io::feather::proxy::Reader); + REGISTER_PROXY(arrow.io.csv.proxy.TableWriter , arrow::matlab::io::csv::proxy::TableWriter); + REGISTER_PROXY(arrow.io.csv.proxy.TableReader , arrow::matlab::io::csv::proxy::TableReader); + REGISTER_PROXY(arrow.c.proxy.Array , arrow::matlab::c::proxy::Array); + REGISTER_PROXY(arrow.c.proxy.ArrayImporter , arrow::matlab::c::proxy::ArrayImporter); + REGISTER_PROXY(arrow.c.proxy.Schema , arrow::matlab::c::proxy::Schema); + REGISTER_PROXY(arrow.c.proxy.RecordBatchImporter , arrow::matlab::c::proxy::RecordBatchImporter); // clang-format on return libmexclass::error::Error{error::UNKNOWN_PROXY_ERROR_ID, diff --git a/matlab/src/cpp/arrow/matlab/tabular/proxy/record_batch.cc b/matlab/src/cpp/arrow/matlab/tabular/proxy/record_batch.cc index 298ac4b595139..f3cee25a3a8ee 100644 --- a/matlab/src/cpp/arrow/matlab/tabular/proxy/record_batch.cc +++ b/matlab/src/cpp/arrow/matlab/tabular/proxy/record_batch.cc @@ -15,8 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-#include "libmexclass/proxy/ProxyManager.h" - +#include "arrow/c/bridge.h" #include "arrow/matlab/array/proxy/array.h" #include "arrow/matlab/array/proxy/wrap.h" @@ -66,6 +65,7 @@ RecordBatch::RecordBatch(std::shared_ptr record_batch) REGISTER_METHOD(RecordBatch, getColumnByName); REGISTER_METHOD(RecordBatch, getSchema); REGISTER_METHOD(RecordBatch, getRowAsString); + REGISTER_METHOD(RecordBatch, exportToC); } std::shared_ptr RecordBatch::unwrap() { return record_batch; } @@ -259,4 +259,19 @@ void RecordBatch::getRowAsString(libmexclass::proxy::method::Context& context) { context.outputs[0] = factory.createScalar(row_str_utf16); } +void RecordBatch::exportToC(libmexclass::proxy::method::Context& context) { + namespace mda = ::matlab::data; + mda::StructArray opts = context.inputs[0]; + const mda::TypedArray array_address_mda = opts[0]["ArrowArrayAddress"]; + const mda::TypedArray schema_address_mda = opts[0]["ArrowSchemaAddress"]; + + auto arrow_array = reinterpret_cast(uint64_t(array_address_mda[0])); + auto arrow_schema = + reinterpret_cast(uint64_t(schema_address_mda[0])); + + MATLAB_ERROR_IF_NOT_OK_WITH_CONTEXT( + arrow::ExportRecordBatch(*record_batch, arrow_array, arrow_schema), context, + error::C_EXPORT_FAILED); +} + } // namespace arrow::matlab::tabular::proxy diff --git a/matlab/src/cpp/arrow/matlab/tabular/proxy/record_batch.h b/matlab/src/cpp/arrow/matlab/tabular/proxy/record_batch.h index c8285c9b095d5..4a1675a8a438a 100644 --- a/matlab/src/cpp/arrow/matlab/tabular/proxy/record_batch.h +++ b/matlab/src/cpp/arrow/matlab/tabular/proxy/record_batch.h @@ -43,6 +43,7 @@ class RecordBatch : public libmexclass::proxy::Proxy { void getColumnByName(libmexclass::proxy::method::Context& context); void getSchema(libmexclass::proxy::method::Context& context); void getRowAsString(libmexclass::proxy::method::Context& context); + void exportToC(libmexclass::proxy::method::Context& context); std::shared_ptr record_batch; }; diff --git 
a/matlab/src/matlab/+arrow/+c/+internal/RecordBatchImporter.m b/matlab/src/matlab/+arrow/+c/+internal/RecordBatchImporter.m new file mode 100644 index 0000000000000..120763bb46e7b --- /dev/null +++ b/matlab/src/matlab/+arrow/+c/+internal/RecordBatchImporter.m @@ -0,0 +1,52 @@ +%RECORDBATCHIMPORTER Imports Arrow RecordBatch using the C Data Interface +% Format. + +% Licensed to the Apache Software Foundation (ASF) under one or more +% contributor license agreements. See the NOTICE file distributed with +% this work for additional information regarding copyright ownership. +% The ASF licenses this file to you under the Apache License, Version +% 2.0 (the "License"); you may not use this file except in compliance +% with the License. You may obtain a copy of the License at +% +% http://www.apache.org/licenses/LICENSE-2.0 +% +% Unless required by applicable law or agreed to in writing, software +% distributed under the License is distributed on an "AS IS" BASIS, +% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +% implied. See the License for the specific language governing +% permissions and limitations under the License. + +classdef RecordBatchImporter + + properties (Hidden, SetAccess=private, GetAccess=public) + Proxy + end + + methods + + function obj = RecordBatchImporter() + proxyName = "arrow.c.proxy.RecordBatchImporter"; + proxy = arrow.internal.proxy.create(proxyName, struct()); + obj.Proxy = proxy; + end + + function recordBatch = import(obj, cArray, cSchema) + arguments + obj(1, 1) arrow.c.internal.RecordBatchImporter + cArray(1, 1) arrow.c.Array + cSchema(1, 1) arrow.c.Schema + end + args = struct(... + ArrowArrayAddress=cArray.Address,... + ArrowSchemaAddress=cSchema.Address... 
+ ); + proxyID = obj.Proxy.import(args); + proxyName = "arrow.tabular.proxy.RecordBatch"; + proxy = libmexclass.proxy.Proxy(Name=proxyName, ID=proxyID); + recordBatch = arrow.tabular.RecordBatch(proxy); + end + + end + +end + diff --git a/matlab/src/matlab/+arrow/+tabular/RecordBatch.m b/matlab/src/matlab/+arrow/+tabular/RecordBatch.m index 0225f3d771181..da5c1fc1c3764 100644 --- a/matlab/src/matlab/+arrow/+tabular/RecordBatch.m +++ b/matlab/src/matlab/+arrow/+tabular/RecordBatch.m @@ -102,6 +102,19 @@ function tf = isequal(obj, varargin) tf = arrow.tabular.internal.isequal(obj, varargin{:}); end + + function export(obj, cArrowArrayAddress, cArrowSchemaAddress) + arguments + obj(1, 1) arrow.tabular.RecordBatch + cArrowArrayAddress(1, 1) uint64 + cArrowSchemaAddress(1, 1) uint64 + end + args = struct(... + ArrowArrayAddress=cArrowArrayAddress,... + ArrowSchemaAddress=cArrowSchemaAddress... + ); + obj.Proxy.exportToC(args); + end end methods (Access = private) @@ -141,5 +154,14 @@ function displayScalarObject(obj) proxy = arrow.internal.proxy.create(proxyName, args); recordBatch = arrow.tabular.RecordBatch(proxy); end + + function recordBatch = import(cArray, cSchema) + arguments + cArray(1, 1) arrow.c.Array + cSchema(1, 1) arrow.c.Schema + end + importer = arrow.c.internal.RecordBatchImporter(); + recordBatch = importer.import(cArray, cSchema); + end end end diff --git a/matlab/test/arrow/c/tRoundTripRecordBatch.m b/matlab/test/arrow/c/tRoundTripRecordBatch.m new file mode 100644 index 0000000000000..5d95aecbe1603 --- /dev/null +++ b/matlab/test/arrow/c/tRoundTripRecordBatch.m @@ -0,0 +1,170 @@ +%TROUNDTRIPRECORDBATCH Tests for roundtripping RecordBatches using +% the C Data Interface format. + +% Licensed to the Apache Software Foundation (ASF) under one or more +% contributor license agreements. See the NOTICE file distributed with +% this work for additional information regarding copyright ownership. 
+% The ASF licenses this file to you under the Apache License, Version +% 2.0 (the "License"); you may not use this file except in compliance +% with the License. You may obtain a copy of the License at +% +% http://www.apache.org/licenses/LICENSE-2.0 +% +% Unless required by applicable law or agreed to in writing, software +% distributed under the License is distributed on an "AS IS" BASIS, +% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +% implied. See the License for the specific language governing +% permissions and limitations under the License. +classdef tRoundTripRecordBatch < matlab.unittest.TestCase + + methods (Test) + function ZeroColumnRecordBatch(testCase) + expected = arrow.recordBatch(table()); + + cArray = arrow.c.Array(); + cSchema = arrow.c.Schema(); + expected.export(cArray.Address, cSchema.Address); + actual = arrow.tabular.RecordBatch.import(cArray, cSchema); + + testCase.verifyEqual(actual, expected); + end + + function ZeroRowRecordBatch(testCase) + doubleArray = arrow.array([]); + stringArray = arrow.array(string.empty(0, 0)); + expected = arrow.tabular.RecordBatch.fromArrays(doubleArray, stringArray); + + cArray = arrow.c.Array(); + cSchema = arrow.c.Schema(); + expected.export(cArray.Address, cSchema.Address); + actual = arrow.tabular.RecordBatch.import(cArray, cSchema); + + testCase.verifyEqual(actual, expected); + end + + function OneRowRecordBatch(testCase) + varNames = ["Col1" "Col2" "Col3"]; + t = table(1, "A", false, VariableNames=varNames); + expected = arrow.recordBatch(t); + + cArray = arrow.c.Array(); + cSchema = arrow.c.Schema(); + expected.export(cArray.Address, cSchema.Address); + actual = arrow.tabular.RecordBatch.import(cArray, cSchema); + + testCase.verifyEqual(actual, expected); + end + + function MultiRowRecordBatch(testCase) + varNames = ["Col1" "Col2" "Col3"]; + t = table((1:3)', ["A"; "B"; "C"], [false; true; false],... 
+ VariableNames=varNames); + expected = arrow.recordBatch(t); + + cArray = arrow.c.Array(); + cSchema = arrow.c.Schema(); + expected.export(cArray.Address, cSchema.Address); + actual = arrow.tabular.RecordBatch.import(cArray, cSchema); + + testCase.verifyEqual(actual, expected); + end + + function ExportErrorWrongInputTypes(testCase) + rb = arrow.recordBatch(table([1; 2; 3])); + fcn = @() rb.export("cArray.Address", "cSchema.Address"); + testCase.verifyError(fcn, "MATLAB:validation:UnableToConvert"); + end + + function ExportTooFewInputs(testCase) + rb = arrow.recordBatch(table([1; 2; 3])); + fcn = @() rb.export(); + testCase.verifyError(fcn, "MATLAB:minrhs"); + end + + function ExportTooManyInputs(testCase) + rb = arrow.recordBatch(table([1; 2; 3])); + fcn = @() rb.export("A", "B", "C"); + testCase.verifyError(fcn, "MATLAB:TooManyInputs"); + end + + function ImportErrorWrongInputTypes(testCase) + cArray = "arrow.c.Array"; + cSchema = "arrow.c.Schema"; + fcn = @() arrow.tabular.RecordBatch.import(cArray, cSchema); + testCase.verifyError(fcn, "MATLAB:validation:UnableToConvert"); + end + + function ImportTooFewInputs(testCase) + fcn = @() arrow.tabular.RecordBatch.import(); + testCase.verifyError(fcn, "MATLAB:minrhs"); + end + + function ImportTooManyInputs(testCase) + fcn = @() arrow.tabular.RecordBatch.import("A", "B", "C"); + testCase.verifyError(fcn, "MATLAB:TooManyInputs"); + end + + function ImportErrorImportFailed(testCase) + cArray = arrow.c.Array(); + cSchema = arrow.c.Schema(); + % An arrow:c:import:ImportFailed error should be thrown + % if the supplied arrow.c.Array and arrow.c.Schema were + % never populated previously from an exported Array. 
+ fcn = @() arrow.tabular.RecordBatch.import(cArray, cSchema); + testCase.verifyError(fcn, "arrow:c:import:ImportFailed"); + end + + function ImportErrorInvalidSchema(testCase) + cArray = arrow.c.Array(); + cSchema = arrow.c.Schema(); + % An arrow:c:import:ImportFailed error should be thrown + % if the supplied arrow.c.Schema was not populated from a + % struct-like type (i.e. StructArray or RecordBatch). + a = arrow.array(1:3); + a.export(cArray.Address, cSchema.Address); + fcn = @() arrow.tabular.RecordBatch.import(cArray, cSchema); + testCase.verifyError(fcn, "arrow:c:import:ImportFailed"); + end + + function ImportFromStructArray(testCase) + % Verify a StructArray exported via the C Data Interface format + % can be imported as a RecordBatch. + field1 = arrow.array(1:3); + + field2 = arrow.array(["A" "B" "C"]); + structArray = arrow.array.StructArray.fromArrays(field1, field2, ... + FieldNames=["Number" "Text"]); + + cArray = arrow.c.Array(); + cSchema = arrow.c.Schema(); + structArray.export(cArray.Address, cSchema.Address) + rb = arrow.tabular.RecordBatch.import(cArray, cSchema); + + expected = arrow.tabular.RecordBatch.fromArrays(field1, field2, ... + ColumnNames=["Number" "Text"]); + + testCase.verifyEqual(rb, expected); + end + + function ExportToStructArray(testCase) + % Verify a RecordBatch exported via the C Data Interface + % format can be imported as a StructArray. + column1 = arrow.array(1:3); + column2 = arrow.array(["A" "B" "C"]); + rb = arrow.tabular.RecordBatch.fromArrays(column1, column2, ... + ColumnNames=["Number" "Text"]); + + cArray = arrow.c.Array(); + cSchema = arrow.c.Schema(); + rb.export(cArray.Address, cSchema.Address) + structArray = arrow.array.Array.import(cArray, cSchema); + + expected = arrow.array.StructArray.fromArrays(column1, column2, ... 
+ FieldNames=["Number" "Text"]); + + testCase.verifyEqual(structArray, expected); + end + + end + +end \ No newline at end of file diff --git a/matlab/tools/cmake/BuildMatlabArrowInterface.cmake b/matlab/tools/cmake/BuildMatlabArrowInterface.cmake index 92e9f59145acc..0a747e648cd84 100644 --- a/matlab/tools/cmake/BuildMatlabArrowInterface.cmake +++ b/matlab/tools/cmake/BuildMatlabArrowInterface.cmake @@ -78,7 +78,8 @@ set(MATLAB_ARROW_LIBMEXCLASS_CLIENT_PROXY_SOURCES "${CMAKE_SOURCE_DIR}/src/cpp/a "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/buffer/proxy/buffer.cc" "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/c/proxy/array.cc" "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/c/proxy/array_importer.cc" - "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/c/proxy/schema.cc") + "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/c/proxy/schema.cc" + "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/c/proxy/record_batch_importer.cc") set(MATLAB_ARROW_LIBMEXCLASS_CLIENT_PROXY_FACTORY_INCLUDE_DIR "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/proxy")