Skip to content

Commit 6e438e6

Browse files
adamreeve and pitrou authored
GH-43142: [C++][Parquet] Refactor Encryptor API to use arrow::util::span instead of raw pointers (apache#43195)
### Rationale for this change

See apache#43142. This is a follow-up to apache#43071, which refactored the Decryptor API and added extra checks to prevent segfaults. This PR makes similar changes to the Encryptor API for consistency and better maintainability.

### What changes are included in this PR?

* Change `AesEncryptor::Encrypt` and `Encryptor::Encrypt` to use `arrow::util::span` instead of raw pointers.
* Replace the `AesEncryptor::CiphertextSizeDelta` method with a `CiphertextLength` method that checks for overflow and abstracts the size difference behaviour away from consumer code for improved readability.

### Are these changes tested?

* This is mostly a refactoring of existing code so is covered by existing tests.

### Are there any user-facing changes?

No

* GitHub Issue: apache#43142

Lead-authored-by: Adam Reeve <[email protected]>
Co-authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
1 parent c777ac8 commit 6e438e6

10 files changed

+218
-164
lines changed

cpp/src/parquet/column_writer.cc

+8-7
Original file line numberDiff line numberDiff line change
@@ -303,9 +303,10 @@ class SerializedPageWriter : public PageWriter {
303303
if (data_encryptor_.get()) {
304304
UpdateEncryption(encryption::kDictionaryPage);
305305
PARQUET_THROW_NOT_OK(encryption_buffer_->Resize(
306-
data_encryptor_->CiphertextSizeDelta() + output_data_len, false));
307-
output_data_len = data_encryptor_->Encrypt(compressed_data->data(), output_data_len,
308-
encryption_buffer_->mutable_data());
306+
data_encryptor_->CiphertextLength(output_data_len), false));
307+
output_data_len =
308+
data_encryptor_->Encrypt(compressed_data->span_as<uint8_t>(),
309+
encryption_buffer_->mutable_span_as<uint8_t>());
309310
output_data_buffer = encryption_buffer_->data();
310311
}
311312

@@ -395,11 +396,11 @@ class SerializedPageWriter : public PageWriter {
395396

396397
if (data_encryptor_.get()) {
397398
PARQUET_THROW_NOT_OK(encryption_buffer_->Resize(
398-
data_encryptor_->CiphertextSizeDelta() + output_data_len, false));
399+
data_encryptor_->CiphertextLength(output_data_len), false));
399400
UpdateEncryption(encryption::kDataPage);
400-
output_data_len = data_encryptor_->Encrypt(compressed_data->data(),
401-
static_cast<int32_t>(output_data_len),
402-
encryption_buffer_->mutable_data());
401+
output_data_len =
402+
data_encryptor_->Encrypt(compressed_data->span_as<uint8_t>(),
403+
encryption_buffer_->mutable_span_as<uint8_t>());
403404
output_data_buffer = encryption_buffer_->data();
404405
}
405406

cpp/src/parquet/encryption/encryption_internal.cc

+140-91
Large diffs are not rendered by default.

cpp/src/parquet/encryption/encryption_internal.h

+11-7
Original file line numberDiff line numberDiff line change
@@ -61,18 +61,22 @@ class PARQUET_EXPORT AesEncryptor {
6161

6262
~AesEncryptor();
6363

64-
/// Size difference between plaintext and ciphertext, for this cipher.
65-
int CiphertextSizeDelta();
64+
/// The size of the ciphertext, for this cipher and the specified plaintext length.
65+
[[nodiscard]] int32_t CiphertextLength(int64_t plaintext_len) const;
6666

6767
/// Encrypts plaintext with the key and aad. Key length is passed only for validation.
6868
/// If different from value in constructor, exception will be thrown.
69-
int Encrypt(const uint8_t* plaintext, int plaintext_len, const uint8_t* key,
70-
int key_len, const uint8_t* aad, int aad_len, uint8_t* ciphertext);
69+
int Encrypt(::arrow::util::span<const uint8_t> plaintext,
70+
::arrow::util::span<const uint8_t> key,
71+
::arrow::util::span<const uint8_t> aad,
72+
::arrow::util::span<uint8_t> ciphertext);
7173

7274
/// Encrypts plaintext footer, in order to compute footer signature (tag).
73-
int SignedFooterEncrypt(const uint8_t* footer, int footer_len, const uint8_t* key,
74-
int key_len, const uint8_t* aad, int aad_len,
75-
const uint8_t* nonce, uint8_t* encrypted_footer);
75+
int SignedFooterEncrypt(::arrow::util::span<const uint8_t> footer,
76+
::arrow::util::span<const uint8_t> key,
77+
::arrow::util::span<const uint8_t> aad,
78+
::arrow::util::span<const uint8_t> nonce,
79+
::arrow::util::span<uint8_t> encrypted_footer);
7680

7781
void WipeOut();
7882

cpp/src/parquet/encryption/encryption_internal_nossl.cc

+10-8
Original file line numberDiff line numberDiff line change
@@ -29,24 +29,26 @@ class AesEncryptor::AesEncryptorImpl {};
2929

3030
AesEncryptor::~AesEncryptor() {}
3131

32-
int AesEncryptor::SignedFooterEncrypt(const uint8_t* footer, int footer_len,
33-
const uint8_t* key, int key_len, const uint8_t* aad,
34-
int aad_len, const uint8_t* nonce,
35-
uint8_t* encrypted_footer) {
32+
int AesEncryptor::SignedFooterEncrypt(::arrow::util::span<const uint8_t> footer,
33+
::arrow::util::span<const uint8_t> key,
34+
::arrow::util::span<const uint8_t> aad,
35+
::arrow::util::span<const uint8_t> nonce,
36+
::arrow::util::span<uint8_t> encrypted_footer) {
3637
ThrowOpenSSLRequiredException();
3738
return -1;
3839
}
3940

4041
void AesEncryptor::WipeOut() { ThrowOpenSSLRequiredException(); }
4142

42-
int AesEncryptor::CiphertextSizeDelta() {
43+
int32_t AesEncryptor::CiphertextLength(int64_t plaintext_len) const {
4344
ThrowOpenSSLRequiredException();
4445
return -1;
4546
}
4647

47-
int AesEncryptor::Encrypt(const uint8_t* plaintext, int plaintext_len, const uint8_t* key,
48-
int key_len, const uint8_t* aad, int aad_len,
49-
uint8_t* ciphertext) {
48+
int AesEncryptor::Encrypt(::arrow::util::span<const uint8_t> plaintext,
49+
::arrow::util::span<const uint8_t> key,
50+
::arrow::util::span<const uint8_t> aad,
51+
::arrow::util::span<uint8_t> ciphertext) {
5052
ThrowOpenSSLRequiredException();
5153
return -1;
5254
}

cpp/src/parquet/encryption/encryption_internal_test.cc

+8-12
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,12 @@ class TestAesEncryption : public ::testing::Test {
3737

3838
AesEncryptor encryptor(cipher_type, key_length_, metadata, write_length);
3939

40-
int expected_ciphertext_len =
41-
static_cast<int>(plain_text_.size()) + encryptor.CiphertextSizeDelta();
40+
int32_t expected_ciphertext_len =
41+
encryptor.CiphertextLength(static_cast<int64_t>(plain_text_.size()));
4242
std::vector<uint8_t> ciphertext(expected_ciphertext_len, '\0');
4343

44-
int ciphertext_length =
45-
encryptor.Encrypt(str2bytes(plain_text_), static_cast<int>(plain_text_.size()),
46-
str2bytes(key_), static_cast<int>(key_.size()), str2bytes(aad_),
47-
static_cast<int>(aad_.size()), ciphertext.data());
44+
int ciphertext_length = encryptor.Encrypt(str2span(plain_text_), str2span(key_),
45+
str2span(aad_), ciphertext);
4846

4947
ASSERT_EQ(ciphertext_length, expected_ciphertext_len);
5048

@@ -87,14 +85,12 @@ class TestAesEncryption : public ::testing::Test {
8785

8886
AesEncryptor encryptor(cipher_type, key_length_, metadata, write_length);
8987

90-
int expected_ciphertext_len =
91-
static_cast<int>(plain_text_.size()) + encryptor.CiphertextSizeDelta();
88+
int32_t expected_ciphertext_len =
89+
encryptor.CiphertextLength(static_cast<int64_t>(plain_text_.size()));
9290
std::vector<uint8_t> ciphertext(expected_ciphertext_len, '\0');
9391

94-
int ciphertext_length =
95-
encryptor.Encrypt(str2bytes(plain_text_), static_cast<int>(plain_text_.size()),
96-
str2bytes(key_), static_cast<int>(key_.size()), str2bytes(aad_),
97-
static_cast<int>(aad_.size()), ciphertext.data());
92+
int ciphertext_length = encryptor.Encrypt(str2span(plain_text_), str2span(key_),
93+
str2span(aad_), ciphertext);
9894

9995
AesDecryptor decryptor(cipher_type, key_length_, metadata, write_length);
10096

cpp/src/parquet/encryption/internal_file_encryptor.cc

+6-5
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,13 @@ Encryptor::Encryptor(encryption::AesEncryptor* aes_encryptor, const std::string&
3131
aad_(aad),
3232
pool_(pool) {}
3333

34-
int Encryptor::CiphertextSizeDelta() { return aes_encryptor_->CiphertextSizeDelta(); }
34+
int32_t Encryptor::CiphertextLength(int64_t plaintext_len) const {
35+
return aes_encryptor_->CiphertextLength(plaintext_len);
36+
}
3537

36-
int Encryptor::Encrypt(const uint8_t* plaintext, int plaintext_len, uint8_t* ciphertext) {
37-
return aes_encryptor_->Encrypt(plaintext, plaintext_len, str2bytes(key_),
38-
static_cast<int>(key_.size()), str2bytes(aad_),
39-
static_cast<int>(aad_.size()), ciphertext);
38+
int Encryptor::Encrypt(::arrow::util::span<const uint8_t> plaintext,
39+
::arrow::util::span<uint8_t> ciphertext) {
40+
return aes_encryptor_->Encrypt(plaintext, str2span(key_), str2span(aad_), ciphertext);
4041
}
4142

4243
// InternalFileEncryptor

cpp/src/parquet/encryption/internal_file_encryptor.h

+4-2
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,10 @@ class PARQUET_EXPORT Encryptor {
4343
void UpdateAad(const std::string& aad) { aad_ = aad; }
4444
::arrow::MemoryPool* pool() { return pool_; }
4545

46-
int CiphertextSizeDelta();
47-
int Encrypt(const uint8_t* plaintext, int plaintext_len, uint8_t* ciphertext);
46+
[[nodiscard]] int32_t CiphertextLength(int64_t plaintext_len) const;
47+
48+
int Encrypt(::arrow::util::span<const uint8_t> plaintext,
49+
::arrow::util::span<uint8_t> ciphertext);
4850

4951
bool EncryptColumnMetaData(
5052
bool encrypted_footer,

cpp/src/parquet/encryption/key_toolkit_internal.cc

+7-8
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,14 @@ std::string EncryptKeyLocally(const std::string& key_bytes, const std::string& m
3232
static_cast<int>(master_key.size()), false,
3333
false /*write_length*/);
3434

35-
int encrypted_key_len =
36-
static_cast<int>(key_bytes.size()) + key_encryptor.CiphertextSizeDelta();
35+
int32_t encrypted_key_len =
36+
key_encryptor.CiphertextLength(static_cast<int64_t>(key_bytes.size()));
3737
std::string encrypted_key(encrypted_key_len, '\0');
38-
encrypted_key_len = key_encryptor.Encrypt(
39-
reinterpret_cast<const uint8_t*>(key_bytes.data()),
40-
static_cast<int>(key_bytes.size()),
41-
reinterpret_cast<const uint8_t*>(master_key.data()),
42-
static_cast<int>(master_key.size()), reinterpret_cast<const uint8_t*>(aad.data()),
43-
static_cast<int>(aad.size()), reinterpret_cast<uint8_t*>(&encrypted_key[0]));
38+
::arrow::util::span<uint8_t> encrypted_key_span(
39+
reinterpret_cast<uint8_t*>(&encrypted_key[0]), encrypted_key_len);
40+
41+
encrypted_key_len = key_encryptor.Encrypt(str2span(key_bytes), str2span(master_key),
42+
str2span(aad), encrypted_key_span);
4443

4544
return ::arrow::util::base64_encode(
4645
::std::string_view(encrypted_key.data(), encrypted_key_len));

cpp/src/parquet/metadata.cc

+17-17
Original file line numberDiff line numberDiff line change
@@ -640,11 +640,13 @@ class FileMetaData::FileMetaDataImpl {
640640
uint32_t serialized_len = metadata_len_;
641641
ThriftSerializer serializer;
642642
serializer.SerializeToBuffer(metadata_.get(), &serialized_len, &serialized_data);
643+
::arrow::util::span<const uint8_t> serialized_data_span(serialized_data,
644+
serialized_len);
643645

644646
// encrypt with nonce
645-
auto nonce = const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(signature));
646-
auto tag = const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(signature)) +
647-
encryption::kNonceLength;
647+
::arrow::util::span<const uint8_t> nonce(reinterpret_cast<const uint8_t*>(signature),
648+
encryption::kNonceLength);
649+
auto tag = reinterpret_cast<const uint8_t*>(signature) + encryption::kNonceLength;
648650

649651
std::string key = file_decryptor_->GetFooterKey();
650652
std::string aad = encryption::CreateFooterAad(file_decryptor_->file_aad());
@@ -653,13 +655,11 @@ class FileMetaData::FileMetaDataImpl {
653655
file_decryptor_->algorithm(), static_cast<int>(key.size()), true,
654656
false /*write_length*/, nullptr);
655657

656-
std::shared_ptr<Buffer> encrypted_buffer = std::static_pointer_cast<ResizableBuffer>(
657-
AllocateBuffer(file_decryptor_->pool(),
658-
aes_encryptor->CiphertextSizeDelta() + serialized_len));
658+
std::shared_ptr<Buffer> encrypted_buffer = AllocateBuffer(
659+
file_decryptor_->pool(), aes_encryptor->CiphertextLength(serialized_len));
659660
uint32_t encrypted_len = aes_encryptor->SignedFooterEncrypt(
660-
serialized_data, serialized_len, str2bytes(key), static_cast<int>(key.size()),
661-
str2bytes(aad), static_cast<int>(aad.size()), nonce,
662-
encrypted_buffer->mutable_data());
661+
serialized_data_span, str2span(key), str2span(aad), nonce,
662+
encrypted_buffer->mutable_span_as<uint8_t>());
663663
// Delete AES encryptor object. It was created only to verify the footer signature.
664664
aes_encryptor->WipeOut();
665665
delete aes_encryptor;
@@ -701,12 +701,12 @@ class FileMetaData::FileMetaDataImpl {
701701
uint8_t* serialized_data;
702702
uint32_t serialized_len;
703703
serializer.SerializeToBuffer(metadata_.get(), &serialized_len, &serialized_data);
704+
::arrow::util::span<const uint8_t> serialized_data_span(serialized_data,
705+
serialized_len);
704706

705707
// encrypt the footer key
706-
std::vector<uint8_t> encrypted_data(encryptor->CiphertextSizeDelta() +
707-
serialized_len);
708-
unsigned encrypted_len =
709-
encryptor->Encrypt(serialized_data, serialized_len, encrypted_data.data());
708+
std::vector<uint8_t> encrypted_data(encryptor->CiphertextLength(serialized_len));
709+
int encrypted_len = encryptor->Encrypt(serialized_data_span, encrypted_data);
710710

711711
// write unencrypted footer
712712
PARQUET_THROW_NOT_OK(dst->Write(serialized_data, serialized_len));
@@ -1559,11 +1559,11 @@ class ColumnChunkMetaDataBuilder::ColumnChunkMetaDataBuilderImpl {
15591559

15601560
serializer.SerializeToBuffer(&column_chunk_->meta_data, &serialized_len,
15611561
&serialized_data);
1562+
::arrow::util::span<const uint8_t> serialized_data_span(serialized_data,
1563+
serialized_len);
15621564

1563-
std::vector<uint8_t> encrypted_data(encryptor->CiphertextSizeDelta() +
1564-
serialized_len);
1565-
unsigned encrypted_len =
1566-
encryptor->Encrypt(serialized_data, serialized_len, encrypted_data.data());
1565+
std::vector<uint8_t> encrypted_data(encryptor->CiphertextLength(serialized_len));
1566+
int encrypted_len = encryptor->Encrypt(serialized_data_span, encrypted_data);
15671567

15681568
const char* temp =
15691569
const_cast<const char*>(reinterpret_cast<char*>(encrypted_data.data()));

cpp/src/parquet/thrift_internal.h

+7-7
Original file line numberDiff line numberDiff line change
@@ -417,8 +417,8 @@ class ThriftDeserializer {
417417
throw ParquetException(ss.str());
418418
}
419419
// decrypt
420-
auto decrypted_buffer = std::static_pointer_cast<ResizableBuffer>(AllocateBuffer(
421-
decryptor->pool(), decryptor->PlaintextLength(static_cast<int32_t>(clen))));
420+
auto decrypted_buffer = AllocateBuffer(
421+
decryptor->pool(), decryptor->PlaintextLength(static_cast<int32_t>(clen)));
422422
::arrow::util::span<const uint8_t> cipher_buf(buf, clen);
423423
uint32_t decrypted_buffer_len =
424424
decryptor->Decrypt(cipher_buf, decrypted_buffer->mutable_span_as<uint8_t>());
@@ -525,13 +525,13 @@ class ThriftSerializer {
525525
}
526526
}
527527

528-
int64_t SerializeEncryptedObj(ArrowOutputStream* out, uint8_t* out_buffer,
528+
int64_t SerializeEncryptedObj(ArrowOutputStream* out, const uint8_t* out_buffer,
529529
uint32_t out_length, Encryptor* encryptor) {
530-
auto cipher_buffer = std::static_pointer_cast<ResizableBuffer>(AllocateBuffer(
531-
encryptor->pool(),
532-
static_cast<int64_t>(encryptor->CiphertextSizeDelta() + out_length)));
530+
auto cipher_buffer =
531+
AllocateBuffer(encryptor->pool(), encryptor->CiphertextLength(out_length));
532+
::arrow::util::span<const uint8_t> out_span(out_buffer, out_length);
533533
int cipher_buffer_len =
534-
encryptor->Encrypt(out_buffer, out_length, cipher_buffer->mutable_data());
534+
encryptor->Encrypt(out_span, cipher_buffer->mutable_span_as<uint8_t>());
535535

536536
PARQUET_THROW_NOT_OK(out->Write(cipher_buffer->data(), cipher_buffer_len));
537537
return static_cast<int64_t>(cipher_buffer_len);

0 commit comments

Comments
 (0)