
Commit 9f816b5

there is a diff between IValue blah and IValue& blah

1 parent b469f12 commit 9f816b5

File tree

1 file changed: +36, -106 lines changed


torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu

@@ -358,8 +358,10 @@ void voidyvoid_boxed_ATH_dequantize_tensor_core_tiled_layout(void **stack,
   // here, void* is my StableIValue
   // function is going to take a stack of void*, cast them to our
   // schema values for now, and run the function and modify the void* stack
-  int64_t innerKTiles = *reinterpret_cast<int64_t *>(stack[3]);
-  int64_t group_size = *reinterpret_cast<int64_t *>(stack[2]);
+  int64_t innerKTiles = reinterpret_cast<int64_t>(stack[3]);
+  int64_t group_size = reinterpret_cast<int64_t>(stack[2]);
+  TORCH_WARN(innerKTiles);
+  TORCH_WARN(group_size);
   AtenTensorHandle scales_and_zeros_ath =
       reinterpret_cast<AtenTensorHandle>(stack[1]);
   AtenTensorHandle packed_w_ath = reinterpret_cast<AtenTensorHandle>(stack[0]);
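This hunk is the commit's subject: the caller now packs each int64_t argument's value directly into its void* slot rather than storing a pointer to a stack local, so the callee casts the slot itself back to int64_t instead of dereferencing it. A minimal standalone sketch of the two packing conventions, with illustrative names only (not code from this file):

#include <cstdint>
#include <cstdio>

int main() {
  int64_t innerKTiles = 8;

  // Pack by value: the void* slot *is* the integer, so reading it back is
  // a plain cast with no dereference and no lifetime concerns.
  void *slot = reinterpret_cast<void *>(innerKTiles);
  int64_t by_value = reinterpret_cast<int64_t>(slot);

  // Pack by pointer: the slot holds the address of a local, which is only
  // safe to dereference while that local is still alive.
  void *ptr_slot = static_cast<void *>(&innerKTiles);
  int64_t by_pointer = *static_cast<int64_t *>(ptr_slot);

  std::printf("%lld %lld\n", (long long)by_value, (long long)by_pointer);
  return 0;
}

The by-value convention also sidesteps a lifetime hazard: in the old code each slot pointed at a local in the caller's frame, which is only valid while that frame is alive.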
@@ -380,37 +382,49 @@ void boxed_dequantize_tensor_core_tiled_layout(const c10::OperatorHandle &op,
   // function pt1 here should take in IValues, pass a malloc'd stack into the
   // second function
   // need a translation from IValues to ATH to void*s!
-  int64_t innerKTiles = torch::jit::pop(stack).toInt();
-  int64_t group_size = torch::jit::pop(stack).toInt();
-  const at::Tensor &scales_and_zeros = torch::jit::pop(stack).toTensor();
-  AtenTensorHandle scales_and_zeros_ath =
-      torch::aot_inductor::tensor_pointer_to_tensor_handle(&scales_and_zeros);
-  const at::Tensor &packed_w = torch::jit::pop(stack).toTensor();
-  AtenTensorHandle packed_w_ath =
-      torch::aot_inductor::tensor_pointer_to_tensor_handle(&packed_w);
-
-  int64_t num_args = 4;
-  int64_t num_outputs = 1;
-  void **ministack = (void**)malloc((num_args + num_outputs) * sizeof(void *));
-  ministack[3] = reinterpret_cast<void *>(&innerKTiles);
-  ministack[2] = reinterpret_cast<void *>(&group_size);
-  ministack[1] = reinterpret_cast<void *>(scales_and_zeros_ath);
-  ministack[0] = reinterpret_cast<void *>(packed_w_ath);
+
+  const auto& schema = op.schema();
+  const auto num_returns = schema.returns().size();
+  const auto num_arguments = schema.arguments().size();
+  TORCH_CHECK(num_arguments==4);
+  TORCH_CHECK(num_returns==1);
+  void **ministack = (void**)malloc((num_arguments + num_returns) * sizeof(void *));
+
+  for (auto idx = 0; idx < num_arguments; idx++) {
+    TORCH_WARN(idx);
+    const c10::IValue& arg = torch::jit::peek(stack, idx, num_arguments);
+    if (arg.isInt()) {
+      ministack[idx] = reinterpret_cast<void *>(arg.toInt());
+    } else if (arg.isTensor()) {
+      TORCH_WARN("am tensor!")
+      const at::Tensor& tensor = arg.toTensor();
+      AtenTensorHandle ath = torch::aot_inductor::tensor_pointer_to_tensor_handle(&tensor);
+      ministack[idx] = reinterpret_cast<void *>(ath);
+    } else {
+      TORCH_CHECK(false, "Other types of IValues not handled!");
+    }
+  }
+  TORCH_WARN("done with forloop no problems!")
 
   // second function is going to take a stack of void*, cast them to our
   // schema values for now, and run the function and modify the void* stack
-  voidyvoid_boxed_ATH_dequantize_tensor_core_tiled_layout(ministack, num_args,
-                                                          num_outputs);
+  voidyvoid_boxed_ATH_dequantize_tensor_core_tiled_layout(ministack, num_arguments,
+                                                          num_returns);
 
   // now read the output from the end of the stack and wrap that back into
   // IValue from void*?
 
   AtenTensorHandle out_ath =
-      reinterpret_cast<AtenTensorHandle>(ministack[num_args]);
-
+      reinterpret_cast<AtenTensorHandle>(ministack[num_arguments]);
+
   free(ministack);
+
   at::Tensor out =
       *torch::aot_inductor::tensor_handle_to_tensor_pointer(out_ath);
+
+  // now pop everything. if we pop earlier, Tensors would go out of scope
+  // before calling the function
+  torch::jit::drop(stack, num_arguments);
   torch::jit::push(stack, c10::IValue(out));
 
   // so above is our stack of IValues, but we cannot have these IValues because
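Two things change in this wrapper: it derives argument and return counts from op.schema() instead of hard-coding them, and it peeks at the IValues rather than popping them, deferring torch::jit::drop until after the call so the stack keeps the tensors alive while the void* function runs. A toy sketch of the peek-then-drop lifetime pattern, using std::shared_ptr as a stand-in for IValue ownership (the peek/drop helpers here are simplified imitations, not the torch::jit API):

#include <cstdio>
#include <memory>
#include <vector>

// Toy stand-in for torch::jit::Stack: values own their payloads, the way
// IValues on the real stack keep tensors alive.
using Value = std::shared_ptr<int>;
using Stack = std::vector<Value>;

// Shaped like torch::jit::peek(stack, i, N): borrow argument i of the top N.
const Value &peek(const Stack &st, size_t i, size_t n) {
  return st[st.size() - n + i];
}

// Shaped like torch::jit::drop(stack, N): discard the top N values.
void drop(Stack &st, size_t n) { st.resize(st.size() - n); }

int main() {
  Stack stack{std::make_shared<int>(1), std::make_shared<int>(2)};
  const size_t num_arguments = 2;

  // Borrow raw pointers while the stack still owns the values (peek)...
  int *arg0 = peek(stack, 0, num_arguments).get();
  int *arg1 = peek(stack, 1, num_arguments).get();

  // ...use them while the owners are still on the stack...
  std::printf("%d %d\n", *arg0, *arg1);

  // ...and only then release them. Popping before the call instead would
  // destroy the only owning references and leave arg0/arg1 dangling.
  drop(stack, num_arguments);
  return 0;
}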
@@ -488,90 +502,6 @@ at::Tensor _unpack_tensor_core_tiled_layout(const at::Tensor &packed_w,
   return out;
 }
 
-void voidyvoid_boxed_ATH_dequantize_tensor_core_tiled_layout(void **stack,
-                                                             int64_t num_args,
-                                                             int64_t num_outputs) {
-  // here, void* is my StableIValue
-  // function is going to take a stack of void*, cast them to our
-  // schema values for now, and run the function and modify the void* stack
-  int64_t innerKTiles = *reinterpret_cast<int64_t *>(stack[3]);
-  int64_t group_size = *reinterpret_cast<int64_t *>(stack[2]);
-  AtenTensorHandle scales_and_zeros_ath =
-      reinterpret_cast<AtenTensorHandle>(stack[1]);
-  AtenTensorHandle packed_w_ath = reinterpret_cast<AtenTensorHandle>(stack[0]);
-
-  AtenTensorHandle ath_res = _ATH_dequantize_tensor_core_tiled_layout(
-      packed_w_ath, scales_and_zeros_ath, group_size, innerKTiles);
-
-  void *out = reinterpret_cast<void *>(ath_res);
-  stack[num_args] = out;
-}
-
-// step 1: from here, call the ATH func
-// step 2: make ATH func also boxed and call it
-// step 3: move abstract code to libtorch
-void boxed_dequantize_tensor_core_tiled_layout(const c10::OperatorHandle &op,
-                                               torch::jit::Stack *stack) {
-
-  // function pt1 here should take in IValues, pass a malloc'd stack into the
-  // second function
-  // need a translation from IValues to ATH to void*s!
-  int64_t innerKTiles = torch::jit::pop(stack).toInt();
-  int64_t group_size = torch::jit::pop(stack).toInt();
-  const at::Tensor &scales_and_zeros = torch::jit::pop(stack).toTensor();
-  AtenTensorHandle scales_and_zeros_ath =
-      torch::aot_inductor::tensor_pointer_to_tensor_handle(&scales_and_zeros);
-  const at::Tensor &packed_w = torch::jit::pop(stack).toTensor();
-  AtenTensorHandle packed_w_ath =
-      torch::aot_inductor::tensor_pointer_to_tensor_handle(&packed_w);
-
-  int64_t num_args = 4;
-  int64_t num_outputs = 1;
-  void **ministack = (void**)malloc((num_args + num_outputs) * sizeof(void *));
-  ministack[3] = reinterpret_cast<void *>(&innerKTiles);
-  ministack[2] = reinterpret_cast<void *>(&group_size);
-  ministack[1] = reinterpret_cast<void *>(scales_and_zeros_ath);
-  ministack[0] = reinterpret_cast<void *>(packed_w_ath);
-
-  // second function is going to take a stack of void*, cast them to our
-  // schema values for now, and run the function and modify the void* stack
-  voidyvoid_boxed_ATH_dequantize_tensor_core_tiled_layout(ministack, num_args,
-                                                          num_outputs);
-
-  // now read the output from the end of the stack and wrap that back into
-  // IValue from void*?
-
-  AtenTensorHandle out_ath =
-      reinterpret_cast<AtenTensorHandle>(ministack[num_args]);
-
-  free(ministack);
-  at::Tensor out =
-      *torch::aot_inductor::tensor_handle_to_tensor_pointer(out_ath);
-  torch::jit::push(stack, c10::IValue(out));
-
-  // so above is our stack of IValues, but we cannot have these IValues because
-  // they are NOT ABI stable! So we need another version of "boxed" with void*s.
-  // and that is what is going to happen below
-
-  // what the old function used to be:
-  // int64_t innerKTiles = torch::jit::pop(stack).toInt();
-  // int64_t group_size = torch::jit::pop(stack).toInt();
-  // const at::Tensor &scales_and_zeros = torch::jit::pop(stack).toTensor();
-  // const at::Tensor &packed_w = torch::jit::pop(stack).toTensor();
-
-  // AtenTensorHandle packed_w_ath =
-  //     torch::aot_inductor::tensor_pointer_to_tensor_handle(&packed_w);
-  // AtenTensorHandle scales_and_zeros_ath =
-  //     torch::aot_inductor::tensor_pointer_to_tensor_handle(&scales_and_zeros);
-
-  // AtenTensorHandle ath_res = _ATH_dequantize_tensor_core_tiled_layout(
-  //     packed_w_ath, scales_and_zeros_ath, group_size, innerKTiles);
-
-  // at::Tensor out =
-  //     *torch::aot_inductor::tensor_handle_to_tensor_pointer(ath_res);
-  // torch::jit::push(stack, c10::IValue(out));
-}
-
 
 TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
   m.impl("torchao::unpack_tensor_core_tiled_layout",
