Skip to content

Commit

Permalink
Ensure all buffers are properly aligned for AOT code.
Browse files Browse the repository at this point in the history
Previously we didn't really ensure that alignment requirements were met which could cause issues. This corrects that oversight.

PiperOrigin-RevId: 591014171
  • Loading branch information
allight authored and copybara-github committed Dec 14, 2023
1 parent 107faff commit d361acb
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 12 deletions.
1 change: 1 addition & 0 deletions xls/jit/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ cc_binary(
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"//xls/common:init_xls",
"//xls/common/file:filesystem",
"//xls/common/logging",
Expand Down
14 changes: 10 additions & 4 deletions xls/jit/aot_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "absl/flags/flag.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
Expand Down Expand Up @@ -181,14 +182,14 @@ const xls::aot_compile::FunctionTypeLayout& GetFunctionTypeLayout() {
absl::StatusOr<::xls::Value> {{wrapper_fn_name}}({{wrapper_params}}) {
{{arg_buffer_decls}}
uint8_t* arg_buffers[] = {{arg_buffer_collector}};
uint8_t result_buffer[{{result_size}}];
alignas({{result_buffer_align}}) uint8_t result_buffer[{{result_size}}];
GetFunctionTypeLayout().ArgValuesToNativeLayout(
{{{param_names}}}, absl::MakeSpan(arg_buffers, {{arg_count}}));
uint8_t* output_buffers[1] = {result_buffer};
std::vector<uint8_t> temp_buffers({{temp_buffer_size}});
alignas({{temp_buffer_align}}) uint8_t temp_buffer[{{temp_buffer_size}}];
::xls::InterpreterEvents events;
{{extern_fn}}(arg_buffers, output_buffers, temp_buffers.data(),
{{extern_fn}}(arg_buffers, output_buffers, temp_buffer,
&events, /*unused=*/nullptr, /*continuation_point=*/0);
return GetFunctionTypeLayout().NativeLayoutResultToValue(result_buffer);
Expand All @@ -212,6 +213,8 @@ absl::StatusOr<::xls::Value> {{wrapper_fn_name}}({{wrapper_params}}) {
ResultLayoutSerialization(f, type_converter);
substitution_map["{{temp_buffer_size}}"] =
absl::StrCat(object_code.temp_buffer_size);
substitution_map["{{temp_buffer_align}}"] =
absl::StrCat(object_code.temp_buffer_alignment);

if (namespaces.empty()) {
substitution_map["{{open_ns}}"] = "";
Expand All @@ -234,7 +237,8 @@ absl::StatusOr<::xls::Value> {{wrapper_fn_name}}({{wrapper_params}}) {
params.push_back(absl::StrCat("const ::xls::Value& ", param->name()));
param_names.push_back(std::string(param->name()));
arg_buffer_decls.push_back(
absl::StrFormat(" uint8_t %s_buffer[%d];", param->name(),
absl::StrFormat(" alignas(%d) uint8_t %s_buffer[%d];",
object_code.parameter_alignments[i], param->name(),
object_code.parameter_buffer_sizes[i]));
arg_buffer_names.push_back(absl::StrCat(param->name(), "_buffer"));
}
Expand All @@ -245,6 +249,8 @@ absl::StatusOr<::xls::Value> {{wrapper_fn_name}}({{wrapper_params}}) {
substitution_map["{{arg_buffer_collector}}"] =
absl::StrFormat("{%s}", absl::StrJoin(arg_buffer_names, ", "));
substitution_map["{{result_size}}"] = absl::StrCat(return_type_bytes);
substitution_map["{{result_buffer_align}}"] =
absl::StrCat(object_code.return_buffer_alignment);
substitution_map["{{arg_count}}"] = absl::StrCat(params.size());

std::string type_textproto;
Expand Down
18 changes: 10 additions & 8 deletions xls/jit/function_jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,14 @@ absl::StatusOr<JitObjectCode> FunctionJit::CreateObjectCode(
.parameter_buffer_sizes = std::vector<int64_t>(
jit->jitted_function_base_.input_buffer_sizes().cbegin(),
jit->jitted_function_base_.input_buffer_sizes().cend()),
.parameter_alignments = std::vector<int64_t>(
jit->jitted_function_base_.input_buffer_abi_alignments().cbegin(),
jit->jitted_function_base_.input_buffer_abi_alignments().cend()),
.return_buffer_size = jit->jitted_function_base_.output_buffer_sizes()[0],
.return_buffer_alignment =
jit->jitted_function_base_.output_buffer_abi_alignments()[0],
.temp_buffer_size = jit->GetTempBufferSize(),
.temp_buffer_alignment = jit->GetTempBufferAlignment(),
};
}

Expand Down Expand Up @@ -147,12 +153,10 @@ absl::Status FunctionJit::RunWithViews(absl::Span<uint8_t* const> args,
return absl::OkStatus();
}

template
absl::Status FunctionJit::RunWithViews</*kForceZeroCopy=*/true>(
template absl::Status FunctionJit::RunWithViews</*kForceZeroCopy=*/true>(
absl::Span<uint8_t* const> args, absl::Span<uint8_t> result_buffer,
InterpreterEvents* events);
template
absl::Status FunctionJit::RunWithViews</*kForceZeroCopy=*/false>(
template absl::Status FunctionJit::RunWithViews</*kForceZeroCopy=*/false>(
absl::Span<uint8_t* const> args, absl::Span<uint8_t> result_buffer,
InterpreterEvents* events);

Expand All @@ -166,12 +170,10 @@ void FunctionJit::InvokeUnalignedJitFunction(
/*user_data=*/nullptr, runtime(), /*continuation_point=*/0);
}

template
void FunctionJit::InvokeUnalignedJitFunction</*kForceZeroCopy=*/false>(
template void FunctionJit::InvokeUnalignedJitFunction</*kForceZeroCopy=*/false>(
absl::Span<const uint8_t* const> arg_buffers, uint8_t* output_buffer,
InterpreterEvents* events);
template
void FunctionJit::InvokeUnalignedJitFunction</*kForceZeroCopy=*/true>(
template void FunctionJit::InvokeUnalignedJitFunction</*kForceZeroCopy=*/true>(
absl::Span<const uint8_t* const> arg_buffers, uint8_t* output_buffer,
InterpreterEvents* events);

Expand Down
7 changes: 7 additions & 0 deletions xls/jit/function_jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,13 @@ struct JitObjectCode {

// Size of the buffers for the parameters and result.
std::vector<int64_t> parameter_buffer_sizes;
std::vector<int64_t> parameter_alignments;
int64_t return_buffer_size;
int64_t return_buffer_alignment;

// Minimum size of the temporary buffer passed to the jitted function.
int64_t temp_buffer_size;
int64_t temp_buffer_alignment;
};

// This class provides a facility to execute XLS functions (on the host) by
Expand Down Expand Up @@ -180,6 +183,10 @@ class FunctionJit {
return jitted_function_base_.temp_buffer_size();
}

int64_t GetTempBufferAlignment() const {
return jitted_function_base_.temp_buffer_alignment();
}

// Returns the name of the jitted function.
std::string_view GetJittedFunctionName() const {
return jitted_function_base_.function_name();
Expand Down

0 comments on commit d361acb

Please sign in to comment.