From 107faff7d620dc046d70e64c0c161a1194b498cb Mon Sep 17 00:00:00 2001 From: Alex Light Date: Thu, 14 Dec 2023 11:50:18 -0800 Subject: [PATCH] Support calling function-jit with unaligned buffers Add support for detecting and automatically memcpy-ing unaligned buffers in the function jit. This makes the function jit significantly easier to use. PiperOrigin-RevId: 591009156 --- xls/jit/function_base_jit.cc | 58 ++++++++++++++++++++++++++++++------ xls/jit/function_base_jit.h | 9 +++--- xls/jit/function_jit.cc | 25 ++++++++++++++-- xls/jit/function_jit.h | 36 +++++++++++++++++----- xls/jit/function_jit_test.cc | 42 ++++++++++++++++++++++---- 5 files changed, 141 insertions(+), 29 deletions(-) diff --git a/xls/jit/function_base_jit.cc b/xls/jit/function_base_jit.cc index 9bbc5513f7..4734c79f08 100644 --- a/xls/jit/function_base_jit.cc +++ b/xls/jit/function_base_jit.cc @@ -15,6 +15,8 @@ #include #include +#include +#include #include #include #include @@ -1375,28 +1377,66 @@ bool IsAligned(const void* ptr, int64_t align) { absl::Status VerifyOffsetAlignments(uint8_t const* const* const ptrs, absl::Span alignments) { for (int64_t i = 0; i < alignments.size(); ++i) { - XLS_RET_CHECK_EQ(absl::bit_cast(ptrs[i]) % alignments[i], 0) - << "value at index " << i << " is not aligned to " << alignments[i] - << ". Value is 0x" << std::hex << absl::bit_cast(ptrs[i]); + if (absl::bit_cast(ptrs[i]) % alignments[i] != 0) { + return absl::InvalidArgumentError( + absl::StrFormat("element %d of input vector does not have alignment " + "of %d. Pointer is %p", + i, alignments[i], ptrs[i])); + } } return absl::OkStatus(); } } // namespace +template int64_t JittedFunctionBase::RunUnalignedJittedFunction( const uint8_t* const* inputs, uint8_t* const* outputs, void* temp_buffer, InterpreterEvents* events, void* user_data, JitRuntime* jit_runtime, int64_t continuation) const { - // TODO(allight): Create an entry point to run these even if the arguments are - // not aligned correctly. - XLS_DCHECK_OK(VerifyOffsetAlignments(inputs, input_buffer_abi_alignments())); - XLS_DCHECK_OK( - VerifyOffsetAlignments(outputs, output_buffer_abi_alignments())); - XLS_DCHECK(IsAligned(temp_buffer, temp_buffer_alignment_)); + if constexpr (kForceZeroCopy) { + XLS_DCHECK_OK( + VerifyOffsetAlignments(inputs, input_buffer_abi_alignments())); + XLS_DCHECK_OK( + VerifyOffsetAlignments(outputs, output_buffer_abi_alignments())); + XLS_DCHECK(IsAligned(temp_buffer, temp_buffer_alignment_)); + } else { + if (!VerifyOffsetAlignments(inputs, input_buffer_abi_alignments()).ok() || + !VerifyOffsetAlignments(outputs, output_buffer_abi_alignments()).ok() || + !IsAligned(temp_buffer, temp_buffer_alignment_)) { + JitArgumentSet aligned_input(CreateInputBuffer()); + JitArgumentSet aligned_output(CreateOutputBuffer()); + JitTempBuffer temp(CreateTempBuffer()); + memcpy(temp.get(), temp_buffer, temp_buffer_size_); + for (int i = 0; i < input_buffer_sizes().size(); ++i) { + memcpy(aligned_input.pointers()[i], inputs[i], input_buffer_sizes()[i]); + } + auto result = + RunJittedFunction(aligned_input, aligned_output, temp, events, + user_data, jit_runtime, continuation); + memcpy(temp_buffer, temp.get(), temp_buffer_size_); + for (int i = 0; i < output_buffer_sizes().size(); ++i) { + memcpy(outputs[i], aligned_output.pointers()[i], + output_buffer_sizes()[i]); + } + return result; + } + } return function_(inputs, outputs, temp_buffer, events, user_data, jit_runtime, continuation); } +template int64_t +JittedFunctionBase::RunUnalignedJittedFunction( + const uint8_t* const* inputs, uint8_t* const* outputs, void* temp_buffer, + InterpreterEvents* events, void* user_data, JitRuntime* jit_runtime, + int64_t continuation) const; + +template int64_t +JittedFunctionBase::RunUnalignedJittedFunction( + const uint8_t* const* inputs, uint8_t* const* outputs, void* temp_buffer, + InterpreterEvents* events, void* user_data, JitRuntime* jit_runtime, + int64_t continuation) const; + std::optional JittedFunctionBase::RunPackedJittedFunction( const uint8_t* const* inputs, uint8_t* const* outputs, void* temp_buffer, InterpreterEvents* events, void* user_data, JitRuntime* jit_runtime, diff --git a/xls/jit/function_base_jit.h b/xls/jit/function_base_jit.h index 22d47d0f9d..28aa22990e 100644 --- a/xls/jit/function_base_jit.h +++ b/xls/jit/function_base_jit.h @@ -114,11 +114,10 @@ class JittedFunctionBase { JitRuntime* jit_runtime, int64_t continuation_point) const; - // Execute the jitted function using inputs not created by this function - // (after verifying some invariants). - // - // TODO(allight): 2023-12-05: We should have a shim that ensures everythings - // aligned ideally. + // Execute the jitted function using inputs not created by this function. + // If kForceZeroCopy is false the inputs will be memcpy'd if needed to aligned + // temporary buffers. + template int64_t RunUnalignedJittedFunction(const uint8_t* const* inputs, uint8_t* const* outputs, void* temp_buffer, InterpreterEvents* events, void* user_data, diff --git a/xls/jit/function_jit.cc b/xls/jit/function_jit.cc index 504d23a9da..2e50e4bc17 100644 --- a/xls/jit/function_jit.cc +++ b/xls/jit/function_jit.cc @@ -125,6 +125,7 @@ absl::StatusOr> FunctionJit::Run( return Run(positional_args); } +template absl::Status FunctionJit::RunWithViews(absl::Span args, absl::Span result_buffer, InterpreterEvents* events) { @@ -141,17 +142,37 @@ absl::Status FunctionJit::RunWithViews(absl::Span args, GetReturnTypeSize())); } - InvokeUnalignedJitFunction(args, result_buffer.data(), events); + InvokeUnalignedJitFunction(args, result_buffer.data(), + events); return absl::OkStatus(); } +template +absl::Status FunctionJit::RunWithViews( + absl::Span args, absl::Span result_buffer, + InterpreterEvents* events); +template +absl::Status FunctionJit::RunWithViews( + absl::Span args, absl::Span result_buffer, + InterpreterEvents* events); + +template void FunctionJit::InvokeUnalignedJitFunction( absl::Span arg_buffers, uint8_t* output_buffer, InterpreterEvents* events) { uint8_t* output_buffers[1] = {output_buffer}; - jitted_function_base_.RunUnalignedJittedFunction( + jitted_function_base_.RunUnalignedJittedFunction( arg_buffers.data(), output_buffers, temp_buffer_.get(), events, /*user_data=*/nullptr, runtime(), /*continuation_point=*/0); } +template +void FunctionJit::InvokeUnalignedJitFunction( + absl::Span arg_buffers, uint8_t* output_buffer, + InterpreterEvents* events); +template +void FunctionJit::InvokeUnalignedJitFunction( + absl::Span arg_buffers, uint8_t* output_buffer, + InterpreterEvents* events); + } // namespace xls diff --git a/xls/jit/function_jit.h b/xls/jit/function_jit.h index b9f4ad0689..f0e4b7d07c 100644 --- a/xls/jit/function_jit.h +++ b/xls/jit/function_jit.h @@ -87,6 +87,7 @@ class FunctionJit { // TODO(https://github.com/google/xls/issues/506): 2021-10-13 Figure out // if we want a way to return events in the view and packed view interfaces // (or if their performance-focused execution means events are unimportant). + template absl::Status RunWithViews(absl::Span args, absl::Span result_buffer, InterpreterEvents* events); @@ -130,15 +131,19 @@ class FunctionJit { // Same as RunWithPackedViews but expects a View rather than a PackedView. template absl::Status RunWithUnpackedViews(ArgsT... args) { - const uint8_t* arg_buffers[sizeof...(ArgsT)]; - uint8_t* result_buffer; - - // Walk the type tree to get each arg's data buffer into our view/arg list. - PackArgBuffers(arg_buffers, &result_buffer, args...); + return RunWithUnpackedViewsCommon( + args...); + } - InterpreterEvents events; - InvokeUnalignedJitFunction(arg_buffers, result_buffer, &events); - return InterpreterEventsToStatus(events); + // Same as RunWithPackedViews but expects a View rather than a PackedView. + // Guaranteed to run without copying the arguments first. The arguments must + // be aligned correctly. + // NOTE: Alignment is determined by LLVM and might change with little warning. + // TODO(allight): 2023-12-6 We need to make this more usable safely. + template + absl::Status RunWithUnpackedViewsZeroCopy(ArgsT... args) { + return RunWithUnpackedViewsCommon( + args...); } // Returns the function that the JIT executes. @@ -198,6 +203,20 @@ class FunctionJit { Function* xls_function, int64_t opt_level, bool emit_object_code, JitObserver* observer); + template + absl::Status RunWithUnpackedViewsCommon(ArgsT... args) { + const uint8_t* arg_buffers[sizeof...(ArgsT)]; + uint8_t* result_buffer; + + // Walk the type tree to get each arg's data buffer into our view/arg list. + PackArgBuffers(arg_buffers, &result_buffer, args...); + + InterpreterEvents events; + InvokeUnalignedJitFunction(arg_buffers, result_buffer, + &events); + return InterpreterEventsToStatus(events); + } + // Builds a function which wraps the natively compiled XLS function `callee` // (as built by xls::BuildFunction) with another function which accepts the // input arguments as an array of pointers to buffers and the output as a @@ -236,6 +255,7 @@ class FunctionJit { } // Invokes the jitted function with the given argument and outputs. + template void InvokeUnalignedJitFunction(absl::Span arg_buffers, uint8_t* output_buffer, InterpreterEvents* events); diff --git a/xls/jit/function_jit_test.cc b/xls/jit/function_jit_test.cc index 457e83431b..466425e4d1 100644 --- a/xls/jit/function_jit_test.cc +++ b/xls/jit/function_jit_test.cc @@ -933,7 +933,39 @@ fn f(x: bits[1], y: bits[8]) -> (bits[1], bits[8], bits[16]) { } } -// TODO(allight): 2023-12-08: This should be supported. +TEST(FunctionJitTest, MisalignedPointerCopied) { + Package package("my_package"); + + FunctionBuilder fb("test", &package); + fb.Add(fb.Param("x", package.GetBitsType(256)), + fb.Param("y", package.GetBitsType(256))); + XLS_ASSERT_OK_AND_ASSIGN(Function * function, fb.Build()); + XLS_ASSERT_OK_AND_ASSIGN(auto jit, FunctionJit::Create(function)); + Bits ret_bits = + bits_ops::Concat({UBits(0, 256 - 65), UBits(1, 1), UBits(0, 64)}); + + alignas(16) std::array x_view{ + 0xAB, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }; + alignas(16) std::array y_view{ + 0xAB, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }; + alignas(16) std::array ret_view{}; + { + InterpreterEvents events; + EXPECT_THAT(jit->RunWithViews( + {x_view.data() + 1, y_view.data() + 1}, + absl::MakeSpan(ret_view).subspan(1), &events), + status_testing::IsOk()); + EXPECT_EQ(Bits::FromBytes(absl::MakeSpan(ret_view).subspan(1), 256), + ret_bits); + } +} + TEST(FunctionJitDeathTest, MisalignedPointerCaught) { #ifndef NDEBUG Package package("my_package"); @@ -958,11 +990,11 @@ TEST(FunctionJitDeathTest, MisalignedPointerCaught) { ASSERT_DEATH( { InterpreterEvents events; - auto unused = - jit->RunWithViews({x_view.data() + 1, y_view.data() + 1}, - absl::MakeSpan(ret_view).subspan(1), &events); + auto unused = jit->RunWithViews( + {x_view.data() + 1, y_view.data() + 1}, + absl::MakeSpan(ret_view).subspan(1), &events); }, - ".*is not aligned to [0-9]+.*"); + ".*does not have alignment of [0-9]+.*"); #else GTEST_SKIP() << "Checking only performed in dbg mode."; #endif