|
18 | 18 | #include "arrow/io/file.h"
|
19 | 19 | #include "arrow/matlab/error/error.h"
|
20 | 20 | #include "arrow/matlab/io/ipc/proxy/record_batch_file_reader.h"
|
| 21 | +#include "arrow/matlab/tabular/proxy/record_batch.h" |
21 | 22 | #include "arrow/matlab/tabular/proxy/schema.h"
|
22 | 23 | #include "arrow/util/utf8.h"
|
23 | 24 |
|
24 | 25 | #include "libmexclass/proxy/ProxyManager.h"
|
25 | 26 |
|
26 | 27 | namespace arrow::matlab::io::ipc::proxy {
|
27 | 28 |
|
| 29 | +namespace { |
| 30 | + libmexclass::error::Error makeInvalidNumericIndexError(const int32_t matlab_index, |
| 31 | + const int32_t num_batches) { |
| 32 | + std::stringstream error_message_stream; |
| 33 | + error_message_stream << "Invalid record batch index: "; |
| 34 | + error_message_stream << matlab_index; |
| 35 | + error_message_stream << ". Record batch index must be between 1 and the number of record batches ("; |
| 36 | + error_message_stream << num_batches; |
| 37 | + error_message_stream << ")."; |
| 38 | + return libmexclass::error::Error{error::IPC_RECORD_BATCH_READ_INVALID_INDEX, error_message_stream.str()}; |
| 39 | + } |
| 40 | +} |
| 41 | + |
28 | 42 | RecordBatchFileReader::RecordBatchFileReader(const std::shared_ptr<arrow::ipc::RecordBatchFileReader> reader)
|
29 | 43 | : reader{std::move(reader)} {
|
30 | 44 | REGISTER_METHOD(RecordBatchFileReader, getNumRecordBatches);
|
31 | 45 | REGISTER_METHOD(RecordBatchFileReader, getSchema);
|
32 |
| - |
33 | 46 | }
|
34 | 47 |
|
35 | 48 | libmexclass::proxy::MakeResult RecordBatchFileReader::make(const libmexclass::proxy::FunctionArguments& constructor_arguments) {
|
@@ -75,9 +88,33 @@ void RecordBatchFileReader::getSchema(libmexclass::proxy::method::Context& conte
|
75 | 88 | mda::ArrayFactory factory;
|
76 | 89 | const auto schema_proxy_id_mda = factory.createScalar(schema_proxy_id);
|
77 | 90 | context.outputs[0] = schema_proxy_id_mda;
|
78 |
| - |
79 | 91 | }
|
80 | 92 |
|
| 93 | +void RecordBatchFileReader::readRecordBatchAtIndex(libmexclass::proxy::method::Context& context) { |
| 94 | + namespace mda = ::matlab::data; |
| 95 | + using RecordBatchProxy = arrow::matlab::tabular::proxy::RecordBatch; |
| 96 | + |
| 97 | + mda::StructArray opts = context.inputs[0]; |
| 98 | + const mda::TypedArray<int32_t> matlab_index_mda = opts[0]["Index"]; |
| 99 | + |
| 100 | + const auto matlab_index = matlab_index_mda[0]; |
| 101 | + const auto num_record_batches = reader->num_record_batches(); |
| 102 | + if (matlab_index < 1 || matlab_index > num_record_batches) { |
| 103 | + context.error = makeInvalidNumericIndexError(matlab_index, num_record_batches); |
| 104 | + return; |
| 105 | + } |
| 106 | + const auto arrow_index = matlab_index - 1; |
| 107 | + |
| 108 | + MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(const auto record_batch, reader->ReadRecordBatch(arrow_index), |
| 109 | + context, error::IPC_RECORD_BATCH_READ_FAILED); |
| 110 | + |
| 111 | + auto record_batch_proxy = std::make_shared<RecordBatchProxy>(std::move(record_batch)); |
| 112 | + const auto record_batch_proxy_id = libmexclass::proxy::ProxyManager::manageProxy(record_batch_proxy); |
| 113 | + |
| 114 | + mda::ArrayFactory factory; |
| 115 | + const auto record_batch_proxyy_id_mda = factory.createScalar(record_batch_proxy_id); |
| 116 | + context.outputs[0] = record_batch_proxyy_id_mda; |
| 117 | +} |
81 | 118 |
|
82 | 119 |
|
83 | 120 |
|
|
0 commit comments