Commit 2017d7b (parent 9f816b5)

Clean up code, finish other end of void* boxed kernel

2 files changed: +19, -63 lines

test/test_ops.py (+2, -1)
@@ -1,4 +1,5 @@
 import itertools
+import sys
 
 import pytest
 import torch
@@ -614,4 +615,4 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_fact
 
 
 if __name__ == "__main__":
-    pytest.main([__file__])
+    pytest.main(sys.argv)
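Note on the change above: switching from pytest.main([__file__]) to pytest.main(sys.argv) presumably lets flags given on the command line (for example a -k filter when invoking python test/test_ops.py -k some_test) flow through to pytest; since sys.argv[0] is the script path itself, the file is still collected as before.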

torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu (+17, -62)
@@ -1,7 +1,7 @@
 #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere
 
-// #include <ATen/ATen.h>
-// #include <ATen/core/Tensor.h>
+#include <ATen/ATen.h>
+#include <ATen/core/Tensor.h>
 #include <ATen/DeviceGuard.h>
 #include <ATen/core/TensorAccessor.h>
 #include <ATen/core/ivalue.h>
@@ -332,25 +332,6 @@ AtenTensorHandle _ATH_dequantize_tensor_core_tiled_layout(
   return out;
 }
 
-// output is [n][k] (int32 dtype)
-// input is [n / 8][k / (InnerKTiles * 16)][32][innerKTiles / 2]
-// scales_and_zeros is [numQGroups][n][2]
-// qGroupSize is 32, 64, 128 or 256
-// at::Tensor
-// _dequantize_tensor_core_tiled_layout(const at::Tensor &packed_w,
-//                                      const at::Tensor &scales_and_zeros,
-//                                      int64_t group_size, int64_t innerKTiles) {
-
-//   AtenTensorHandle packed_w_ath =
-//       torch::aot_inductor::tensor_pointer_to_tensor_handle(&packed_w);
-//   AtenTensorHandle scales_and_zeros_ath =
-//       torch::aot_inductor::tensor_pointer_to_tensor_handle(&scales_and_zeros);
-
-//   AtenTensorHandle ath_res = _ATH_dequantize_tensor_core_tiled_layout(
-//       packed_w_ath, scales_and_zeros_ath, group_size, innerKTiles);
-
-//   return *torch::aot_inductor::tensor_handle_to_tensor_pointer(ath_res);
-// }
 
 void voidyvoid_boxed_ATH_dequantize_tensor_core_tiled_layout(void **stack,
                                                              int64_t num_args,
@@ -360,8 +341,6 @@ void voidyvoid_boxed_ATH_dequantize_tensor_core_tiled_layout(void **stack,
   // schema values for now, and run the function and modify the void* stack
   int64_t innerKTiles = reinterpret_cast<int64_t>(stack[3]);
   int64_t group_size = reinterpret_cast<int64_t>(stack[2]);
-  TORCH_WARN(innerKTiles);
-  TORCH_WARN(group_size);
   AtenTensorHandle scales_and_zeros_ath =
       reinterpret_cast<AtenTensorHandle>(stack[1]);
   AtenTensorHandle packed_w_ath = reinterpret_cast<AtenTensorHandle>(stack[0]);
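As far as it can be read off this diff, the void*-stack convention between the two functions is positional: arguments sit at stack[0] through stack[num_args - 1] (ints are reinterpret_cast directly into the pointer value, Tensors travel as AtenTensorHandle), and each return value is written back at stack[num_args + i], which is where the caller below picks it up.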
@@ -386,68 +365,44 @@ void boxed_dequantize_tensor_core_tiled_layout(const c10::OperatorHandle &op,
   const auto& schema = op.schema();
   const auto num_returns = schema.returns().size();
   const auto num_arguments = schema.arguments().size();
-  TORCH_CHECK(num_arguments==4);
-  TORCH_CHECK(num_returns==1);
   void **ministack = (void**)malloc((num_arguments + num_returns) * sizeof(void *));
 
   for (auto idx = 0; idx < num_arguments; idx++) {
-    TORCH_WARN(idx);
     const c10::IValue& arg = torch::jit::peek(stack, idx, num_arguments);
     if (arg.isInt()) {
       ministack[idx] = reinterpret_cast<void *>(arg.toInt());
     } else if (arg.isTensor()) {
-      TORCH_WARN("am tensor!")
       const at::Tensor& tensor = arg.toTensor();
       AtenTensorHandle ath = torch::aot_inductor::tensor_pointer_to_tensor_handle(&tensor);
       ministack[idx] = reinterpret_cast<void *>(ath);
     } else {
-      TORCH_CHECK(false, "Other types of IValues not handled!");
+      TORCH_CHECK(false, "Other types of IValues not yet handled!");
     }
   }
-  TORCH_WARN("done with forloop no problems!")
 
   // second function is going to take a stack of void*, cast them to our
   // schema values for now, and run the function and modify the void* stack
   voidyvoid_boxed_ATH_dequantize_tensor_core_tiled_layout(ministack, num_arguments,
                                                           num_returns);
 
-  // now read the output from the end of the stack and wrap that back into
-  // IValue from void*?
-
-  AtenTensorHandle out_ath =
-      reinterpret_cast<AtenTensorHandle>(ministack[num_arguments]);
-
-  free(ministack);
-
-  at::Tensor out =
-      *torch::aot_inductor::tensor_handle_to_tensor_pointer(out_ath);
-
-  // now pop everything. if we pop earlier, Tensors would go out of scope
+  // now pop all inputs on stack. if we pop earlier, Tensors would go out of scope
   // before calling the function
   torch::jit::drop(stack, num_arguments);
-  torch::jit::push(stack, c10::IValue(out));
-
-  // so above is our stack of IValues, but we cannot have these IValues because
-  // they are NOT ABI stable! So we need another version of "boxed" with void*s.
-  // and that is what is going to happen below
-
-  // what the old function used to be:
-  // int64_t innerKTiles = torch::jit::pop(stack).toInt();
-  // int64_t group_size = torch::jit::pop(stack).toInt();
-  // const at::Tensor &scales_and_zeros = torch::jit::pop(stack).toTensor();
-  // const at::Tensor &packed_w = torch::jit::pop(stack).toTensor();
 
-  // AtenTensorHandle packed_w_ath =
-  //     torch::aot_inductor::tensor_pointer_to_tensor_handle(&packed_w);
-  // AtenTensorHandle scales_and_zeros_ath =
-  //     torch::aot_inductor::tensor_pointer_to_tensor_handle(&scales_and_zeros);
-
-  // AtenTensorHandle ath_res = _ATH_dequantize_tensor_core_tiled_layout(
-  //     packed_w_ath, scales_and_zeros_ath, group_size, innerKTiles);
+  // read the output from the end of the stack and wrap that back into
+  // IValue from void*?
+  for (auto idx = 0; idx < num_returns; idx++) {
+    const c10::TypePtr& ret_type = schema.returns()[idx].type();
+    if (*ret_type == *c10::getTypePtr<at::Tensor>()) {
+      AtenTensorHandle ret_ath = reinterpret_cast<AtenTensorHandle>(ministack[num_arguments + idx]);
+      at::Tensor out = *torch::aot_inductor::tensor_handle_to_tensor_pointer(ret_ath);
+      torch::jit::push(stack, c10::IValue(out));
+    } else {
+      TORCH_CHECK(false, "Only Tensor return types are currently supported!");
+    }
+  }
 
-  // at::Tensor out =
-  //     *torch::aot_inductor::tensor_handle_to_tensor_pointer(ath_res);
-  // torch::jit::push(stack, c10::IValue(out));
+  free(ministack);
 }
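The commit shows the boxed wrapper itself but not where it is registered. For orientation, here is a minimal sketch of how a boxed function with this signature is typically hooked into the dispatcher; the torchao namespace and the operator name dequantize_tensor_core_tiled_layout are assumptions inferred from the file name, not taken from this diff.

// Hypothetical registration sketch, not part of this commit.
// makeFromBoxedFunction wraps a function with the boxed signature
//   void(const c10::OperatorHandle&, torch::jit::Stack*)
// so the dispatcher can call it without any generated unboxing glue.
#include <torch/library.h>

TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
  m.impl("dequantize_tensor_core_tiled_layout",
         torch::CppFunction::makeFromBoxedFunction<
             &boxed_dequantize_tensor_core_tiled_layout>());
}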
