Skip to content

Commit

Permalink
Support calling function-jit with unaligned buffers
Browse files Browse the repository at this point in the history
Add support for detecting and automatically memcpy-ing unaligned buffers in the function jit. This makes the function jit significantly easier to use.

PiperOrigin-RevId: 591009156
  • Loading branch information
allight authored and copybara-github committed Dec 14, 2023
1 parent d601b7b commit 107faff
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 29 deletions.
58 changes: 49 additions & 9 deletions xls/jit/function_base_jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <ios>
#include <iterator>
#include <limits>
Expand Down Expand Up @@ -1375,28 +1377,66 @@ bool IsAligned(const void* ptr, int64_t align) {
absl::Status VerifyOffsetAlignments(uint8_t const* const* const ptrs,
absl::Span<int64_t const> alignments) {
for (int64_t i = 0; i < alignments.size(); ++i) {
XLS_RET_CHECK_EQ(absl::bit_cast<uintptr_t>(ptrs[i]) % alignments[i], 0)
<< "value at index " << i << " is not aligned to " << alignments[i]
<< ". Value is 0x" << std::hex << absl::bit_cast<uintptr_t>(ptrs[i]);
if (absl::bit_cast<uintptr_t>(ptrs[i]) % alignments[i] != 0) {
return absl::InvalidArgumentError(
absl::StrFormat("element %d of input vector does not have alignment "
"of %d. Pointer is %p",
i, alignments[i], ptrs[i]));
}
}
return absl::OkStatus();
}
} // namespace

template <bool kForceZeroCopy>
int64_t JittedFunctionBase::RunUnalignedJittedFunction(
const uint8_t* const* inputs, uint8_t* const* outputs, void* temp_buffer,
InterpreterEvents* events, void* user_data, JitRuntime* jit_runtime,
int64_t continuation) const {
// TODO(allight): Create an entry point to run these even if the arguments are
// not aligned correctly.
XLS_DCHECK_OK(VerifyOffsetAlignments(inputs, input_buffer_abi_alignments()));
XLS_DCHECK_OK(
VerifyOffsetAlignments(outputs, output_buffer_abi_alignments()));
XLS_DCHECK(IsAligned(temp_buffer, temp_buffer_alignment_));
if constexpr (kForceZeroCopy) {
XLS_DCHECK_OK(
VerifyOffsetAlignments(inputs, input_buffer_abi_alignments()));
XLS_DCHECK_OK(
VerifyOffsetAlignments(outputs, output_buffer_abi_alignments()));
XLS_DCHECK(IsAligned(temp_buffer, temp_buffer_alignment_));
} else {
if (!VerifyOffsetAlignments(inputs, input_buffer_abi_alignments()).ok() ||
!VerifyOffsetAlignments(outputs, output_buffer_abi_alignments()).ok() ||
!IsAligned(temp_buffer, temp_buffer_alignment_)) {
JitArgumentSet aligned_input(CreateInputBuffer());
JitArgumentSet aligned_output(CreateOutputBuffer());
JitTempBuffer temp(CreateTempBuffer());
memcpy(temp.get(), temp_buffer, temp_buffer_size_);
for (int i = 0; i < input_buffer_sizes().size(); ++i) {
memcpy(aligned_input.pointers()[i], inputs[i], input_buffer_sizes()[i]);
}
auto result =
RunJittedFunction(aligned_input, aligned_output, temp, events,
user_data, jit_runtime, continuation);
memcpy(temp_buffer, temp.get(), temp_buffer_size_);
for (int i = 0; i < output_buffer_sizes().size(); ++i) {
memcpy(outputs[i], aligned_output.pointers()[i],
output_buffer_sizes()[i]);
}
return result;
}
}
return function_(inputs, outputs, temp_buffer, events, user_data, jit_runtime,
continuation);
}

template int64_t
JittedFunctionBase::RunUnalignedJittedFunction</*kForceZeroCopy=*/true>(
const uint8_t* const* inputs, uint8_t* const* outputs, void* temp_buffer,
InterpreterEvents* events, void* user_data, JitRuntime* jit_runtime,
int64_t continuation) const;

template int64_t
JittedFunctionBase::RunUnalignedJittedFunction</*kForceZeroCopy=*/false>(
const uint8_t* const* inputs, uint8_t* const* outputs, void* temp_buffer,
InterpreterEvents* events, void* user_data, JitRuntime* jit_runtime,
int64_t continuation) const;

std::optional<int64_t> JittedFunctionBase::RunPackedJittedFunction(
const uint8_t* const* inputs, uint8_t* const* outputs, void* temp_buffer,
InterpreterEvents* events, void* user_data, JitRuntime* jit_runtime,
Expand Down
9 changes: 4 additions & 5 deletions xls/jit/function_base_jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,10 @@ class JittedFunctionBase {
JitRuntime* jit_runtime,
int64_t continuation_point) const;

// Execute the jitted function using inputs not created by this function
// (after verifying some invariants).
//
// TODO(allight): 2023-12-05: We should have a shim that ensures everythings
// aligned ideally.
// Execute the jitted function using inputs not created by this function.
// If kForceZeroCopy is false the inputs will be memcpy'd if needed to aligned
// temporary buffers.
template<bool kForceZeroCopy = false>
int64_t RunUnalignedJittedFunction(const uint8_t* const* inputs,
uint8_t* const* outputs, void* temp_buffer,
InterpreterEvents* events, void* user_data,
Expand Down
25 changes: 23 additions & 2 deletions xls/jit/function_jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ absl::StatusOr<InterpreterResult<Value>> FunctionJit::Run(
return Run(positional_args);
}

template <bool kForceZeroCopy>
absl::Status FunctionJit::RunWithViews(absl::Span<uint8_t* const> args,
absl::Span<uint8_t> result_buffer,
InterpreterEvents* events) {
Expand All @@ -141,17 +142,37 @@ absl::Status FunctionJit::RunWithViews(absl::Span<uint8_t* const> args,
GetReturnTypeSize()));
}

InvokeUnalignedJitFunction(args, result_buffer.data(), events);
InvokeUnalignedJitFunction<kForceZeroCopy>(args, result_buffer.data(),
events);
return absl::OkStatus();
}

template
absl::Status FunctionJit::RunWithViews</*kForceZeroCopy=*/true>(
absl::Span<uint8_t* const> args, absl::Span<uint8_t> result_buffer,
InterpreterEvents* events);
template
absl::Status FunctionJit::RunWithViews</*kForceZeroCopy=*/false>(
absl::Span<uint8_t* const> args, absl::Span<uint8_t> result_buffer,
InterpreterEvents* events);

template <bool kForceZeroCopy>
void FunctionJit::InvokeUnalignedJitFunction(
absl::Span<const uint8_t* const> arg_buffers, uint8_t* output_buffer,
InterpreterEvents* events) {
uint8_t* output_buffers[1] = {output_buffer};
jitted_function_base_.RunUnalignedJittedFunction(
jitted_function_base_.RunUnalignedJittedFunction<kForceZeroCopy>(
arg_buffers.data(), output_buffers, temp_buffer_.get(), events,
/*user_data=*/nullptr, runtime(), /*continuation_point=*/0);
}

template
void FunctionJit::InvokeUnalignedJitFunction</*kForceZeroCopy=*/false>(
absl::Span<const uint8_t* const> arg_buffers, uint8_t* output_buffer,
InterpreterEvents* events);
template
void FunctionJit::InvokeUnalignedJitFunction</*kForceZeroCopy=*/true>(
absl::Span<const uint8_t* const> arg_buffers, uint8_t* output_buffer,
InterpreterEvents* events);

} // namespace xls
36 changes: 28 additions & 8 deletions xls/jit/function_jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class FunctionJit {
// TODO(https://github.com/google/xls/issues/506): 2021-10-13 Figure out
// if we want a way to return events in the view and packed view interfaces
// (or if their performance-focused execution means events are unimportant).
template <bool kForceZeroCopy = false>
absl::Status RunWithViews(absl::Span<uint8_t* const> args,
absl::Span<uint8_t> result_buffer,
InterpreterEvents* events);
Expand Down Expand Up @@ -130,15 +131,19 @@ class FunctionJit {
// Same as RunWithPackedViews but expects a View rather than a PackedView.
template <typename... ArgsT>
absl::Status RunWithUnpackedViews(ArgsT... args) {
const uint8_t* arg_buffers[sizeof...(ArgsT)];
uint8_t* result_buffer;

// Walk the type tree to get each arg's data buffer into our view/arg list.
PackArgBuffers(arg_buffers, &result_buffer, args...);
return RunWithUnpackedViewsCommon</*kForceZeroCopy=*/false, ArgsT...>(
args...);
}

InterpreterEvents events;
InvokeUnalignedJitFunction(arg_buffers, result_buffer, &events);
return InterpreterEventsToStatus(events);
// Same as RunWithPackedViews but expects a View rather than a PackedView.
// Guaranteed to run without copying the arguments first. The arguments must
// be aligned correctly.
// NOTE: Alignment is determined by LLVM and might change with little warning.
// TODO(allight): 2023-12-6 We need to make this more usable safely.
template <typename... ArgsT>
absl::Status RunWithUnpackedViewsZeroCopy(ArgsT... args) {
return RunWithUnpackedViewsCommon</*kForceZeroCopy=*/true, ArgsT...>(
args...);
}

// Returns the function that the JIT executes.
Expand Down Expand Up @@ -198,6 +203,20 @@ class FunctionJit {
Function* xls_function, int64_t opt_level, bool emit_object_code,
JitObserver* observer);

template <bool kForceZeroCopy, typename... ArgsT>
absl::Status RunWithUnpackedViewsCommon(ArgsT... args) {
const uint8_t* arg_buffers[sizeof...(ArgsT)];
uint8_t* result_buffer;

// Walk the type tree to get each arg's data buffer into our view/arg list.
PackArgBuffers(arg_buffers, &result_buffer, args...);

InterpreterEvents events;
InvokeUnalignedJitFunction<kForceZeroCopy>(arg_buffers, result_buffer,
&events);
return InterpreterEventsToStatus(events);
}

// Builds a function which wraps the natively compiled XLS function `callee`
// (as built by xls::BuildFunction) with another function which accepts the
// input arguments as an array of pointers to buffers and the output as a
Expand Down Expand Up @@ -236,6 +255,7 @@ class FunctionJit {
}

// Invokes the jitted function with the given argument and outputs.
template <bool kForceZeroCopy = false>
void InvokeUnalignedJitFunction(absl::Span<const uint8_t* const> arg_buffers,
uint8_t* output_buffer,
InterpreterEvents* events);
Expand Down
42 changes: 37 additions & 5 deletions xls/jit/function_jit_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,39 @@ fn f(x: bits[1], y: bits[8]) -> (bits[1], bits[8], bits[16]) {
}
}

// TODO(allight): 2023-12-08: This should be supported.
TEST(FunctionJitTest, MisalignedPointerCopied) {
Package package("my_package");

FunctionBuilder fb("test", &package);
fb.Add(fb.Param("x", package.GetBitsType(256)),
fb.Param("y", package.GetBitsType(256)));
XLS_ASSERT_OK_AND_ASSIGN(Function * function, fb.Build());
XLS_ASSERT_OK_AND_ASSIGN(auto jit, FunctionJit::Create(function));
Bits ret_bits =
bits_ops::Concat({UBits(0, 256 - 65), UBits(1, 1), UBits(0, 64)});

alignas(16) std::array<uint8_t, 1 + (256 / 8)> x_view{
0xAB, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
};
alignas(16) std::array<uint8_t, 1 + (256 / 8)> y_view{
0xAB, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
};
alignas(16) std::array<uint8_t, 1 + (256 / 8)> ret_view{};
{
InterpreterEvents events;
EXPECT_THAT(jit->RunWithViews</*kForceZeroCopy=*/false>(
{x_view.data() + 1, y_view.data() + 1},
absl::MakeSpan(ret_view).subspan(1), &events),
status_testing::IsOk());
EXPECT_EQ(Bits::FromBytes(absl::MakeSpan(ret_view).subspan(1), 256),
ret_bits);
}
}

TEST(FunctionJitDeathTest, MisalignedPointerCaught) {
#ifndef NDEBUG
Package package("my_package");
Expand All @@ -958,11 +990,11 @@ TEST(FunctionJitDeathTest, MisalignedPointerCaught) {
ASSERT_DEATH(
{
InterpreterEvents events;
auto unused =
jit->RunWithViews({x_view.data() + 1, y_view.data() + 1},
absl::MakeSpan(ret_view).subspan(1), &events);
auto unused = jit->RunWithViews</*kForceZeroCopy=*/true>(
{x_view.data() + 1, y_view.data() + 1},
absl::MakeSpan(ret_view).subspan(1), &events);
},
".*is not aligned to [0-9]+.*");
".*does not have alignment of [0-9]+.*");
#else
GTEST_SKIP() << "Checking only performed in dbg mode.";
#endif
Expand Down

0 comments on commit 107faff

Please sign in to comment.