Skip to content

Commit 8515a08

Browse files
apacheGH-46877: [MATLAB] Add arrow.tabular.Table.fromRecordBatches static method (apache#46885)
### Rationale for this change This change makes it possible to create an `arrow.tabular.Table` instance from a list of `arrow.tabular.RecordBatch` instances whose `Schema`s are consistent. ### What changes are included in this PR? Added a new static method called `fromRecordBatches` to the MATLAB class `arrow.tabular.Table`. This method should construct an `arrow.tabular.Table` from a variable number of `arrow.tabular.RecordBatch`es. **Usage Example** ```matlab >> rb1 = arrow.recordBatch(table([1:5]', [6:10]')); >> rb2 = arrow.recordBatch(table([11:15]', [16:20]')); >> table = arrow.tabular.Table.fromRecordBatches(rb1, rb2) table = Arrow Table with 10 rows and 2 columns: Schema: Var1: Float64 | Var2: Float64 First Row: 1 | 6 ``` **Error Message Examples** ```matlab % Error message when fromRecordBatches is called with zero input arguments >> arrow.tabular.Table.fromRecordBatches() Error using arrow.tabular.Table.fromRecordBatches (line 154) The fromRecordBatches method requires at least one RecordBatch to be supplied. % Error message when fromRecordBatches is given RecordBatches whose Schemas are inconsistent. >> rb1 = arrow.recordBatch(table(1, 2, VariableNames=["Num1", "Num2"])); >> rb2 = arrow.recordBatch(table(1, "A", VariableNames=["Num1", "Letter1"])); >> arrow.tabular.Table.fromRecordBatches(rb1, rb2) Error using arrow.tabular.Table.fromRecordBatches (line 167) All RecordBatches must have the same Schema. Schema of RecordBatch 2 is Num1: Float64 | Letter1: String Expected RecordBatch Schema to be Num1: Float64 | Num2: Float64 ``` ### Are these changes tested? Yes. Added four new test cases to the MATLAB test class `tTable`: 1. `FromRecordBatchesZeroInputsError` 2. `FromRecordBatchesOneInput` 3. `FromRecordBatchesMultipleInputs` 4. `FromRecordBatchesInconsistentSchemaError` ### Are there any user-facing changes? Yes. Users can now create an `arrow.tabular.Table` instance via the static method `fromRecordBatches`. * GitHub Issue: apache#46877 Lead-authored-by: Sarah Gilmore <[email protected]> Co-authored-by: Sarah Gilmore <[email protected]> Co-authored-by: Kevin Gurney <[email protected]> Signed-off-by: Sarah Gilmore <[email protected]>
1 parent dacec30 commit 8515a08

File tree

5 files changed

+149
-16
lines changed

5 files changed

+149
-16
lines changed

matlab/src/cpp/arrow/matlab/error/error.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,5 +253,6 @@ static const char* IPC_RECORD_BATCH_READ_INVALID_INDEX = "arrow:io:ipc:InvalidIn
253253
static const char* IPC_RECORD_BATCH_READ_FAILED = "arrow:io:ipc:ReadFailed";
254254
static const char* IPC_TABLE_READ_FAILED = "arrow:io:ipc:TableReadFailed";
255255
static const char* IPC_END_OF_STREAM = "arrow:io:ipc:EndOfStream";
256+
static const char* TABLE_MAKE_UNKNOWN_METHOD = "arrow:table:UnknownMakeMethod";
256257

257258
} // namespace arrow::matlab::error

matlab/src/cpp/arrow/matlab/tabular/proxy/table.cc

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@
1818
#include "libmexclass/proxy/ProxyManager.h"
1919

2020
#include "arrow/matlab/array/proxy/array.h"
21+
2122
#include "arrow/matlab/array/proxy/chunked_array.h"
2223
#include "arrow/matlab/array/proxy/wrap.h"
2324

2425
#include "arrow/matlab/error/error.h"
2526
#include "arrow/matlab/tabular/get_row_as_string.h"
27+
#include "arrow/matlab/tabular/proxy/record_batch.h"
2628
#include "arrow/matlab/tabular/proxy/schema.h"
2729
#include "arrow/matlab/tabular/proxy/table.h"
2830

@@ -34,6 +36,8 @@
3436

3537
namespace arrow::matlab::tabular::proxy {
3638

39+
namespace mda = ::matlab::data;
40+
3741
namespace {
3842
libmexclass::error::Error makeEmptyTableError() {
3943
const std::string error_msg =
@@ -70,7 +74,6 @@ Table::Table(std::shared_ptr<arrow::Table> table) : table{table} {
7074
std::shared_ptr<arrow::Table> Table::unwrap() { return table; }
7175

7276
void Table::toString(libmexclass::proxy::method::Context& context) {
73-
namespace mda = ::matlab::data;
7477
MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(const auto utf16_string,
7578
arrow::util::UTF8StringToUTF16(table->ToString()),
7679
context, error::UNICODE_CONVERSION_ERROR_ID);
@@ -79,12 +82,11 @@ void Table::toString(libmexclass::proxy::method::Context& context) {
7982
context.outputs[0] = str_mda;
8083
}
8184

82-
libmexclass::proxy::MakeResult Table::make(
83-
const libmexclass::proxy::FunctionArguments& constructor_arguments) {
85+
namespace {
86+
libmexclass::proxy::MakeResult from_arrays(const mda::StructArray& opts) {
8487
using ArrayProxy = arrow::matlab::array::proxy::Array;
8588
using TableProxy = arrow::matlab::tabular::proxy::Table;
86-
namespace mda = ::matlab::data;
87-
mda::StructArray opts = constructor_arguments[0];
89+
8890
const mda::TypedArray<uint64_t> arrow_array_proxy_ids = opts[0]["ArrayProxyIDs"];
8991
const mda::StringArray column_names = opts[0]["ColumnNames"];
9092

@@ -114,9 +116,64 @@ libmexclass::proxy::MakeResult Table::make(
114116
error::SCHEMA_BUILDER_FINISH_ERROR_ID);
115117
const auto num_rows = arrow_arrays.size() == 0 ? 0 : arrow_arrays[0]->length();
116118
const auto table = arrow::Table::Make(schema, arrow_arrays, num_rows);
117-
auto table_proxy = std::make_shared<TableProxy>(table);
119+
return std::make_shared<TableProxy>(table);
120+
}
121+
122+
libmexclass::proxy::MakeResult from_record_batches(const mda::StructArray& opts) {
123+
using RecordBatchProxy = arrow::matlab::tabular::proxy::RecordBatch;
124+
using TableProxy = arrow::matlab::tabular::proxy::Table;
118125

119-
return table_proxy;
126+
size_t num_rows = 0;
127+
const mda::TypedArray<uint64_t> record_batch_proxy_ids = opts[0]["RecordBatchProxyIDs"];
128+
129+
std::vector<std::shared_ptr<arrow::RecordBatch>> record_batches;
130+
// Retrieve all of the Arrow RecordBatch Proxy instances from the libmexclass
131+
// ProxyManager.
132+
for (const auto& proxy_id : record_batch_proxy_ids) {
133+
auto proxy = libmexclass::proxy::ProxyManager::getProxy(proxy_id);
134+
auto record_batch_proxy = std::static_pointer_cast<RecordBatch>(proxy);
135+
auto record_batch = record_batch_proxy->unwrap();
136+
record_batches.push_back(record_batch);
137+
num_rows += record_batches.back()->num_rows();
138+
}
139+
140+
// The MATLAB client code that calls this function is responsible for pre-validating
141+
// that this function is called with at least one RecordBatch.
142+
auto schema = record_batches[0]->schema();
143+
size_t num_columns = schema->num_fields();
144+
std::vector<std::shared_ptr<ChunkedArray>> columns(num_columns);
145+
146+
size_t num_batches = record_batches.size();
147+
148+
for (size_t i = 0; i < num_columns; ++i) {
149+
std::vector<std::shared_ptr<Array>> column_arrays(num_batches);
150+
for (size_t j = 0; j < num_batches; ++j) {
151+
column_arrays[j] = record_batches[j]->column(i);
152+
}
153+
columns[i] = std::make_shared<ChunkedArray>(column_arrays, schema->field(i)->type());
154+
}
155+
const auto table = arrow::Table::Make(std::move(schema), std::move(columns), num_rows);
156+
return std::make_shared<TableProxy>(table);
157+
}
158+
} // anonymous namespace
159+
160+
libmexclass::proxy::MakeResult Table::make(
161+
const libmexclass::proxy::FunctionArguments& constructor_arguments) {
162+
mda::StructArray opts = constructor_arguments[0];
163+
const mda::StringArray method = opts[0]["Method"];
164+
165+
if (method[0] == u"from_arrays") {
166+
return from_arrays(opts);
167+
} else if (method[0] == u"from_record_batches") {
168+
return from_record_batches(opts);
169+
} else {
170+
const auto method_name_utf16 = std::u16string(method[0]);
171+
MATLAB_ASSIGN_OR_ERROR(const auto method_name_utf8,
172+
arrow::util::UTF16StringToUTF8(method_name_utf16),
173+
error::UNICODE_CONVERSION_ERROR_ID);
174+
const std::string error_msg = "Unknown make method: " + method_name_utf8;
175+
return libmexclass::error::Error{error::TABLE_MAKE_UNKNOWN_METHOD, error_msg};
176+
}
120177
}
121178

122179
void Table::getNumRows(libmexclass::proxy::method::Context& context) {

matlab/src/matlab/+arrow/+tabular/Table.m

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,38 @@ function displayScalarObject(obj)
139139
validateColumnNames(opts.ColumnNames, numColumns);
140140

141141
arrayProxyIDs = getArrayProxyIDs(arrowArrays);
142-
args = struct(ArrayProxyIDs=arrayProxyIDs, ColumnNames=opts.ColumnNames);
142+
args = struct(Method="from_arrays", ArrayProxyIDs=arrayProxyIDs, ColumnNames=opts.ColumnNames);
143+
proxyName = "arrow.tabular.proxy.Table";
144+
proxy = arrow.internal.proxy.create(proxyName, args);
145+
arrowTable = arrow.tabular.Table(proxy);
146+
end
147+
148+
function arrowTable = fromRecordBatches(batches)
149+
arguments(Repeating)
150+
batches(1, 1) arrow.tabular.RecordBatch
151+
end
152+
if numel(batches) == 0
153+
msg = "The fromRecordBatches method requires at least one RecordBatch to be supplied.";
154+
error("arrow:Table:FromRecordBatches:ZeroBatches", msg);
155+
elseif numel(batches) > 1
156+
% Verify that all supplied RecordBatches have a consistent Schema.
157+
firstSchema = batches{1}.Schema;
158+
otherSchemas = cellfun(@(rb) rb.Schema, batches(2:end), UniformOutput=false);
159+
idx = cellfun(@(other) ~isequal(firstSchema, other), otherSchemas, UniformOutput=true);
160+
inconsistentSchemaIndex = find(idx, 1,"first");
161+
if ~isempty(inconsistentSchemaIndex)
162+
inconsistentSchemaIndex = inconsistentSchemaIndex + 1;
163+
expectedSchema = arrow.tabular.internal.display.getSchemaString(firstSchema);
164+
inconsistentSchema = arrow.tabular.internal.display.getSchemaString(batches{inconsistentSchemaIndex}.Schema);
165+
msg = "All RecordBatches must have the same Schema.\n\nSchema of RecordBatch %d is\n\n\t%s\n\nExpected RecordBatch Schema to be\n\n\t%s";
166+
msg = compose(msg, inconsistentSchemaIndex, inconsistentSchema, expectedSchema);
167+
error("arrow:Table:FromRecordBatches:InconsistentSchema", msg);
168+
end
169+
end
170+
171+
% TODO: Rename getArrayProxyIDs to getProxyIDs
172+
proxyIDs = arrow.array.internal.getArrayProxyIDs(batches);
173+
args = struct(Method="from_record_batches", RecordBatchProxyIDs=proxyIDs);
143174
proxyName = "arrow.tabular.proxy.Table";
144175
proxy = arrow.internal.proxy.create(proxyName, args);
145176
arrowTable = arrow.tabular.Table(proxy);

matlab/src/matlab/+arrow/table.m

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,7 @@
2020
% ambiguous name parsing issue with MATLAB table type and arrow.table.
2121
matlabTable {istable} = table.empty(0, 0)
2222
end
23-
2423
arrowArrays = arrow.tabular.internal.decompose(matlabTable);
25-
arrayProxyIDs = arrow.array.internal.getArrayProxyIDs(arrowArrays);
26-
2724
columnNames = string(matlabTable.Properties.VariableNames);
28-
args = struct(ArrayProxyIDs=arrayProxyIDs, ColumnNames=columnNames);
29-
proxyName = "arrow.tabular.proxy.Table";
30-
proxy = arrow.internal.proxy.create(proxyName, args);
31-
32-
arrowTable = arrow.tabular.Table(proxy);
25+
arrowTable = arrow.tabular.Table.fromArrays(arrowArrays{:}, ColumnNames=columnNames);
3326
end

matlab/test/arrow/tabular/tTable.m

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,57 @@ function TestIsEqualFalse(testCase)
664664
testCase.verifyFalse(isequal(t1, t2, t3, t4));
665665
end
666666

667+
function FromRecordBatchesZeroInputsError(testCase)
668+
% Verify the arrow.tabular.Table.fromRecordBatches function
669+
% throws an `arrow:Table:FromRecordBatches:ZeroBatches`
670+
% exception if called with zero input arguments.
671+
import arrow.tabular.Table
672+
fcn = @() Table.fromRecordBatches();
673+
testCase.verifyError(fcn, "arrow:Table:FromRecordBatches:ZeroBatches");
674+
end
675+
676+
function FromRecordBatchesOneInput(testCase)
677+
% Verify the arrow.tabular.Table.fromRecordBatches function
678+
% returns the expected arrow.tabular.Table instance when
679+
% provided a single RecordBatch as input.
680+
import arrow.tabular.Table
681+
matlabTable = table([1; 2], ["A"; "B"], VariableNames=["Number" "Letter"]);
682+
recordBatch = arrow.recordBatch(matlabTable);
683+
arrowTable = Table.fromRecordBatches(recordBatch);
684+
testCase.verifyTable(arrowTable, ["Number", "Letter"], ["arrow.type.Float64Type", "arrow.type.StringType"], matlabTable);
685+
end
686+
687+
function FromRecordBatchesMultipleInputs(testCase)
688+
% Verify the arrow.tabular.Table.fromRecordBatches function
689+
% returns the expected arrow.tabular.Table instance when
690+
% provided mulitple RecordBatches as input.
691+
import arrow.tabular.Table
692+
matlabTable1 = table([1; 2], ["A"; "B"], VariableNames=["Number" "Letter"]);
693+
matlabTable2 = table([10; 20; 30], ["A1"; "B1"; "C1"], VariableNames=["Number" "Letter"]);
694+
matlabTable3 = table([100; 200], ["A2"; "B2"], VariableNames=["Number" "Letter"]);
695+
696+
recordBatch1 = arrow.recordBatch(matlabTable1);
697+
recordBatch2 = arrow.recordBatch(matlabTable2);
698+
recordBatch3 = arrow.recordBatch(matlabTable3);
699+
700+
arrowTable = Table.fromRecordBatches(recordBatch1, recordBatch2, recordBatch3);
701+
testCase.verifyTable(arrowTable, ["Number", "Letter"], ["arrow.type.Float64Type", "arrow.type.StringType"], [matlabTable1; matlabTable2; matlabTable3]);
702+
end
703+
704+
function FromRecordBatchesInconsistentSchemaError(testCase)
705+
% Verify the arrow.tabular.Table.fromRecordBatches function
706+
% throws an `arrow:Table:FromRecordBatches:InconsistentSchema`
707+
% exception if the Schemas of the provided RecordBatches are
708+
% inconsistent.
709+
import arrow.tabular.Table
710+
matlabTable1 = table("A", 1);
711+
matlabTable2 = table(2, "B");
712+
recordBatch1 = arrow.recordBatch(matlabTable1);
713+
recordBatch2 = arrow.recordBatch(matlabTable2);
714+
715+
fcn = @() Table.fromRecordBatches(recordBatch1, recordBatch2);
716+
testCase.verifyError(fcn, "arrow:Table:FromRecordBatches:InconsistentSchema");
717+
end
667718
end
668719

669720
methods

0 commit comments

Comments
 (0)