[super ugly maybe working code] use shim.h instead of Tensor #1548
Draft: janeyx99 wants to merge 12 commits into main from try-aoti-shim
Commits (12)
38b50b4  [super ugly not working code] use shim.h instead of Tensor (janeyx99)
4d0cebf  Cleaned up PoC (janeyx99)
abdae1e  Ignore ignore_this (janeyx99)
df85d15  add mock registration prototype (janeyx99)
4120a8b  there is a diff between IValue blah and IValue& blah (janeyx99)
a1944ee  Clean up code, finish other end of void* boxed kernel (janeyx99)
672aeec  [skip ci] saving work on registration (janeyx99)
97a9220  [skip ci] This definitely does not compile (janeyx99)
31c2925  Now the code compiles (janeyx99)
a0500e5  Move commented out notes to the bottom to not distract (janeyx99)
5e2c2d0  Remove dependency on change in core (janeyx99)
cc4d022  Fix memory leak by using RAIIATH (janeyx99)
torchao/csrc/cuda/tensor_core_tiled_layout/libtorch.cpp (274 additions, 0 deletions)
@@ -0,0 +1,274 @@
// in this file, we will implement the stuff in libtorch.h,
// and we are allowed to call unstable stuff from pytorch!

#include "libtorch.h"

#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/DeviceGuard.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/TensorAccessor.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/stack.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/inductor/aoti_runtime/utils.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#include <torch/library.h>
||
class StableLibrary::TorchLibraryOpaque { | ||
public: | ||
// TODO: support other Kinds lol, you'll need to translate between StableLibrary::Kind and Library::Kind | ||
TorchLibraryOpaque(StableLibrary::Kind kind, std::string ns, std::optional<c10::DispatchKey> k, const char* file, uint32_t line) | ||
: library_(torch::Library::Kind::IMPL, std::move(ns), k, file, line) {} | ||
|
||
TorchLibraryOpaque(const TorchLibraryOpaque&) = delete; | ||
TorchLibraryOpaque& operator=(const TorchLibraryOpaque&) = delete; | ||
TorchLibraryOpaque(TorchLibraryOpaque&&) = default; | ||
TorchLibraryOpaque& operator=(TorchLibraryOpaque&&) = default; | ||
~TorchLibraryOpaque() = default; | ||
|
||
void impl(const char* name, torch::CppFunction fn) { | ||
library_.impl(name, std::move(fn)); | ||
} | ||
private: | ||
torch::Library library_; // Actual Library object | ||
}; | ||
|
||
|
||
using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle; | ||
|
||
class VoidStarConverter : public c10::OperatorKernel {
 public:
  VoidStarConverter(void (*fn)(void**, int64_t, int64_t)) : fn_(fn) {}

  void operator()(const c10::OperatorHandle& op, c10::DispatchKeySet keyset, torch::jit::Stack* stack) {
    const auto& schema = op.schema();
    const auto num_returns = schema.returns().size();
    const auto num_arguments = schema.arguments().size();
    // to make this faster, this could be a C array on the stack --> though that may risk a stack overflow
    void** ministack = (void**)malloc((num_arguments + num_returns) * sizeof(void*));
    // std::unique_ptr<void*[]> ministack = std::make_unique<void*[]>(num_arguments + num_returns);

    for (size_t idx = 0; idx < num_arguments; idx++) { // rbarnes will prefer a c10::irange instead of this loop!
      const c10::IValue& arg = torch::jit::peek(stack, idx, num_arguments);
      if (arg.isInt()) {
        ministack[idx] = reinterpret_cast<void*>(arg.toInt());
      } else if (arg.isTensor()) {
        AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(std::move(const_cast<at::Tensor&>(arg.toTensor())));
        ministack[idx] = reinterpret_cast<void*>(ath);
      } else {
        TORCH_CHECK(false, "Other types of IValues not yet handled!");
      }
    }

    // the second function takes a stack of void*, casts them to the schema
    // types, runs the user function, and writes results back onto the void* stack
    fn_(ministack, num_arguments, num_returns);

    // now pop all inputs from the stack; if we popped earlier, the Tensors
    // would go out of scope before calling the function
    torch::jit::drop(stack, num_arguments);

    // read the outputs from the end of the ministack and wrap them back into IValues
    for (size_t idx = 0; idx < num_returns; idx++) {
      const c10::TypePtr& ret_type = schema.returns()[idx].type();
      if (*ret_type == *c10::getTypePtr<at::Tensor>()) {
        auto ret_raiiath = RAIIATH(reinterpret_cast<AtenTensorHandle>(ministack[num_arguments + idx]));
        at::Tensor out = *torch::aot_inductor::tensor_handle_to_tensor_pointer(ret_raiiath.get());
        torch::jit::push(stack, c10::IValue(out));
      } else {
        TORCH_CHECK(false, "Only Tensor return types are currently supported!");
      }
    }

    free(ministack);
  }

 private:
  void (*fn_)(void**, int64_t, int64_t);
};
StableLibrary::StableLibrary(StableLibrary::Kind kind, std::string ns, std::optional<c10::DispatchKey> k, const char* file, uint32_t line)
    : lib_(new TorchLibraryOpaque(StableLibrary::Kind::IMPL, std::move(ns), k, file, line)) {}

StableLibrary& StableLibrary::impl(const char* name, void (*fn)(void**, int64_t, int64_t)) {
  this->lib_->impl(name, torch::CppFunction::makeFromBoxedFunctor(std::make_unique<VoidStarConverter>(fn)));
  return *this;
}
// notes from trying to understand stuff + iteration
/**
// step 1: from here, call the ATH func
// step 2: make ATH func also boxed and call it
// step 3: move abstract code to libtorch
void boxed_dequantize_tensor_core_tiled_layout(const c10::OperatorHandle &op,
                                               torch::jit::Stack *stack) {

  // function pt1 here should take in IValues, pass a malloc'd stack into the
  // second function
  // need a translation from IValues to ATH to void*s!

  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = schema.arguments().size();
  void **ministack = (void**)malloc((num_arguments + num_returns) * sizeof(void *));

  for (auto idx = 0; idx < num_arguments; idx++) {
    const c10::IValue& arg = torch::jit::peek(stack, idx, num_arguments);
    if (arg.isInt()) {
      ministack[idx] = reinterpret_cast<void *>(arg.toInt());
    } else if (arg.isTensor()) {
      const at::Tensor& tensor = arg.toTensor();
      AtenTensorHandle ath = torch::aot_inductor::tensor_pointer_to_tensor_handle(&tensor);
      ministack[idx] = reinterpret_cast<void *>(ath);
    } else {
      TORCH_CHECK(false, "Other types of IValues not yet handled!");
    }
  }

  // second function is going to take a stack of void*, cast them to our
  // schema values for now, and run the function and modify the void* stack
  voidyvoid_boxed_ATH_dequantize_tensor_core_tiled_layout(ministack, num_arguments,
                                                          num_returns);

  // now pop all inputs on stack. if we pop earlier, Tensors would go out of scope
  // before calling the function
  torch::jit::drop(stack, num_arguments);

  // read the output from the end of the stack and wrap that back into
  // IValue from void*?
  for (auto idx = 0; idx < num_returns; idx++) {
    const c10::TypePtr& ret_type = schema.returns()[idx].type();
    if (*ret_type == *c10::getTypePtr<at::Tensor>()) {
      AtenTensorHandle ret_ath = reinterpret_cast<AtenTensorHandle>(ministack[num_arguments + idx]);
      at::Tensor out = *torch::aot_inductor::tensor_handle_to_tensor_pointer(ret_ath);
      torch::jit::push(stack, c10::IValue(out));
    } else {
      TORCH_CHECK(false, "Only Tensor return types are currently supported!");
    }
  }

  free(ministack);
}
void boxed_unpack_tensor_core_tiled_layout(const c10::OperatorHandle &op,
                                           torch::jit::Stack *stack) {

  // function pt1 here should take in IValues, pass a malloc'd stack into the
  // second function
  // need a translation from IValues to ATH to void*s!

  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = schema.arguments().size();
  void **ministack = (void**)malloc((num_arguments + num_returns) * sizeof(void *));

  for (auto idx = 0; idx < num_arguments; idx++) {
    const c10::IValue& arg = torch::jit::peek(stack, idx, num_arguments);
    if (arg.isInt()) {
      ministack[idx] = reinterpret_cast<void *>(arg.toInt());
    } else if (arg.isTensor()) {
      const at::Tensor& tensor = arg.toTensor();
      AtenTensorHandle ath = torch::aot_inductor::tensor_pointer_to_tensor_handle(&tensor);
      ministack[idx] = reinterpret_cast<void *>(ath);
    } else {
      TORCH_CHECK(false, "Other types of IValues not yet handled!");
    }
  }

  // second function is going to take a stack of void*, cast them to our
  // schema values for now, and run the function and modify the void* stack
  voidyvoid_boxed_ATH_unpack_tensor_core_tiled_layout(ministack, num_arguments,
                                                      num_returns);

  // now pop all inputs on stack. if we pop earlier, Tensors would go out of scope
  // before calling the function
  torch::jit::drop(stack, num_arguments);

  // read the output from the end of the stack and wrap that back into
  // IValue from void*?
  for (auto idx = 0; idx < num_returns; idx++) {
    const c10::TypePtr& ret_type = schema.returns()[idx].type();
    if (*ret_type == *c10::getTypePtr<at::Tensor>()) {
      AtenTensorHandle ret_ath = reinterpret_cast<AtenTensorHandle>(ministack[num_arguments + idx]);
      at::Tensor out = *torch::aot_inductor::tensor_handle_to_tensor_pointer(ret_ath);
      torch::jit::push(stack, c10::IValue(out));
    } else {
      TORCH_CHECK(false, "Only Tensor return types are currently supported!");
    }
  }

  free(ministack);
}
void boxed_void_function(const c10::OperatorHandle &op, torch::jit::Stack *stack) {

  // function pt1 here should take in IValues, pass a malloc'd stack into the
  // second function
  // need a translation from IValues to ATH to void*s!

  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = schema.arguments().size();
  void **ministack = (void**)malloc((num_arguments + num_returns) * sizeof(void *));

  for (auto idx = 0; idx < num_arguments; idx++) {
    const c10::IValue& arg = torch::jit::peek(stack, idx, num_arguments);
    if (arg.isInt()) {
      ministack[idx] = reinterpret_cast<void *>(arg.toInt());
    } else if (arg.isTensor()) {
      const at::Tensor& tensor = arg.toTensor();
      AtenTensorHandle ath = torch::aot_inductor::tensor_pointer_to_tensor_handle(&tensor);
      ministack[idx] = reinterpret_cast<void *>(ath);
    } else {
      TORCH_CHECK(false, "Other types of IValues not yet handled!");
    }
  }

  // second function is going to take a stack of void*, cast them to our
  // schema values for now, and run the function and modify the void* stack
  voidyvoid_boxed_ATH_unpack_tensor_core_tiled_layout(ministack, num_arguments,
                                                      num_returns);

  // now pop all inputs on stack. if we pop earlier, Tensors would go out of scope
  // before calling the function
  torch::jit::drop(stack, num_arguments);

  // read the output from the end of the stack and wrap that back into
  // IValue from void*?
  for (auto idx = 0; idx < num_returns; idx++) {
    const c10::TypePtr& ret_type = schema.returns()[idx].type();
    if (*ret_type == *c10::getTypePtr<at::Tensor>()) {
      AtenTensorHandle ret_ath = reinterpret_cast<AtenTensorHandle>(ministack[num_arguments + idx]);
      at::Tensor out = *torch::aot_inductor::tensor_handle_to_tensor_pointer(ret_ath);
      torch::jit::push(stack, c10::IValue(out));
    } else {
      TORCH_CHECK(false, "Only Tensor return types are currently supported!");
    }
  }

  free(ministack);
}
TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
  // m.impl("torchao::unpack_tensor_core_tiled_layout",
  //        &_unpack_tensor_core_tiled_layout);
  m.impl("torchao::unpack_tensor_core_tiled_layout",
         torch::CppFunction::makeFromBoxedFunction<
             boxed_unpack_tensor_core_tiled_layout>());
  // m.impl("torchao::dequantize_tensor_core_tiled_layout",
  //        &_dequantize_tensor_core_tiled_layout);
  m.impl("torchao::dequantize_tensor_core_tiled_layout",
         torch::CppFunction::makeFromBoxedFunction<
             boxed_dequantize_tensor_core_tiled_layout>());
}

*/
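
For context, here is a minimal sketch (not part of this diff) of what a kernel on the other side of VoidStarConverter could look like. It assumes an op taking (Tensor, int) and returning one Tensor, and follows the convention above: argument slots come first (ints passed by value, tensors as AtenTensorHandle), and outputs are written into the slots starting at num_args. The names voidyvoid_boxed_my_op and ath_my_op_impl are hypothetical.

// Hypothetical AtenTensorHandle-level implementation, assumed to be defined
// elsewhere in libtorch-agnostic code.
AtenTensorHandle ath_my_op_impl(AtenTensorHandle input, int64_t group_size);

// A kernel matching the void(void**, int64_t, int64_t) signature that
// StableLibrary::impl expects and VoidStarConverter calls.
void voidyvoid_boxed_my_op(void** stack, int64_t num_args, int64_t num_outputs) {
  // slot 0: a Tensor argument, passed as an AtenTensorHandle
  AtenTensorHandle input = reinterpret_cast<AtenTensorHandle>(stack[0]);
  // slot 1: an int argument, passed by value inside the void*
  int64_t group_size = reinterpret_cast<int64_t>(stack[1]);

  // run the libtorch-agnostic implementation
  AtenTensorHandle result = ath_my_op_impl(input, group_size);

  // write the single Tensor output into the first output slot; the converter
  // reads stack[num_args + idx], wraps it in RAIIATH, and pushes an IValue
  stack[num_args + 0] = reinterpret_cast<void*>(result);

  // NOTE: who frees the input handle (created via new_tensor_handle in the
  // converter) is an ownership question this draft still has to settle.
}
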
libtorch.h (104 additions, 0 deletions)
@@ -0,0 +1,104 @@
// this file can only have stable stuff! Akin to shim.h

#include <c10/util/BFloat16.h>
#include <c10/macros/Macros.h>    // used for C10_UID, verified to be header-only
#include <c10/core/DispatchKey.h> // used for DispatchKey, enum verified to be header-only
#include <torch/csrc/inductor/aoti_torch/c/shim.h>

#include <optional>
#include <string>
class StableLibrary final {
 private:
  class TorchLibraryOpaque;
  using TorchLibraryHandle = TorchLibraryOpaque*;
  TorchLibraryHandle lib_; // pimpl handle: a pointer to a real Library

 public:
  // a kind
  enum Kind {
    // DEF, // from TORCH_LIBRARY (no qualifier)
    IMPL,
    // FRAGMENT,
  };

  // constructor
  /// \private
  ///
  /// Use STABLE_TORCH_LIBRARY or STABLE_TORCH_LIBRARY_IMPL() instead of using these
  /// constructors directly
  StableLibrary(
      Kind kind,
      std::string ns,
      std::optional<c10::DispatchKey> k,
      const char* file,
      uint32_t line);

  StableLibrary(const StableLibrary&) = delete;
  StableLibrary& operator=(const StableLibrary&) = delete;
  StableLibrary(StableLibrary&&) = default;
  StableLibrary& operator=(StableLibrary&&) = default;
  ~StableLibrary() = default;

  StableLibrary& impl(const char* name, void (*fn)(void**, int64_t, int64_t));
};
class StableTorchLibraryInit final {
 private:
  using InitFn = void(StableLibrary&);
  StableLibrary lib_;

 public:
  StableTorchLibraryInit(
      StableLibrary::Kind kind,
      InitFn* fn,
      const char* ns,
      std::optional<c10::DispatchKey> k,
      const char* file,
      uint32_t line)
      : lib_(kind, ns, k, file, line) {
    fn(lib_);
  }
};
#define STABLE_TORCH_LIBRARY_IMPL(ns, k, m) _STABLE_TORCH_LIBRARY_IMPL(ns, k, m, C10_UID)

#define _STABLE_TORCH_LIBRARY_IMPL(ns, k, m, uid)                           \
  static void C10_CONCATENATE(                                              \
      STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(StableLibrary&);   \
  static const StableTorchLibraryInit C10_CONCATENATE(                      \
      STABLE_TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_, uid)(            \
      StableLibrary::IMPL,                                                  \
      &C10_CONCATENATE(STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid),  \
      #ns,                                                                  \
      std::make_optional(c10::DispatchKey::k),                              \
      __FILE__,                                                             \
      __LINE__);                                                            \
  void C10_CONCATENATE(                                                     \
      STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(StableLibrary & m)
// notes while figuring out templating
/**
#define TORCH_LIBRARY_IMPL(ns, k, m) _TORCH_LIBRARY_IMPL(ns, k, m, C10_UID)

#define _TORCH_LIBRARY_IMPL(ns, k, m, uid)                                   \
  static void TORCH_LIBRARY_IMPL_init_torchao_CUDA_uid(torch::Library&);    \
  static const torch::detail::TorchLibraryInit                              \
      TORCH_LIBRARY_IMPL_static_init_torchao_CUDA_uid(                      \
          torch::Library::IMPL,                                             \
          (c10::impl::dispatch_key_allowlist_check(c10::DispatchKey::CUDA)  \
               ? &TORCH_LIBRARY_IMPL_init_torchao_CUDA_uid                  \
               : [](torch::Library&) -> void {}),                           \
          torchao,                                                          \
          std::make_optional(c10::DispatchKey::CUDA),                       \
          __FILE__,                                                         \
          __LINE__);                                                        \
  TORCH_LIBRARY_IMPL_init_torchao_CUDA_uid(torch::Library & m) {

}
*/
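
For reference, a short usage sketch (not part of this diff) of the STABLE_TORCH_LIBRARY_IMPL macro defined above, registering the hypothetical void**-convention kernel from the earlier sketch. It assumes the op schema (e.g. torchao::my_op(Tensor input, int group_size) -> Tensor) has already been declared elsewhere via a regular TORCH_LIBRARY block.

// Hedged usage sketch: StableLibrary::impl wraps the raw function pointer in a
// VoidStarConverter so the dispatcher can invoke it as a boxed kernel.
STABLE_TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
  m.impl("torchao::my_op", &voidyvoid_boxed_my_op);
}
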
Review comment: I am guessing that, from a design perspective, this is the interface layer with the registration API that is shipped with different versions of libtorch? But the user code, for custom ops, never relies on at::Tensor?
Reply: Yeah, this is to allow for registering libtorch-agnostic custom ops, which cannot use at::Tensor or IValue, etc., in their schema.
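
To illustrate that point, a kernel written this way would touch tensors only through the C ABI in shim.h. Below is a rough sketch, assuming the aoti_torch_get_data_ptr and aoti_torch_get_sizes accessors from torch/csrc/inductor/aoti_torch/c/shim.h and a hypothetical CUDA launcher; error codes are ignored for brevity.

#include <torch/csrc/inductor/aoti_torch/c/shim.h>

// hypothetical device-side launcher, defined elsewhere
void launch_my_op_kernel(const float* in, float* out, int64_t numel);

// operates on AtenTensorHandle via shim.h only; never includes ATen headers
void ath_my_op_into(AtenTensorHandle input, AtenTensorHandle output) {
  void* in_ptr = nullptr;
  void* out_ptr = nullptr;
  int64_t* sizes = nullptr;
  aoti_torch_get_data_ptr(input, &in_ptr);   // raw data pointer, via the C ABI
  aoti_torch_get_data_ptr(output, &out_ptr);
  aoti_torch_get_sizes(input, &sizes);       // sizes array, via the C ABI
  // assume a contiguous 1-D float tensor for this sketch
  launch_my_op_kernel(static_cast<const float*>(in_ptr),
                      static_cast<float*>(out_ptr),
                      sizes[0]);
}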