@@ -358,8 +358,10 @@ void voidyvoid_boxed_ATH_dequantize_tensor_core_tiled_layout(void **stack,
   // here, void* is my StableIValue
   // function is going to take a stack of void*, cast them to our
   // schema values for now, and run the function and modify the void* stack
-  int64_t innerKTiles = *reinterpret_cast<int64_t *>(stack[3]);
-  int64_t group_size = *reinterpret_cast<int64_t *>(stack[2]);
+  int64_t innerKTiles = reinterpret_cast<int64_t>(stack[3]);
+  int64_t group_size = reinterpret_cast<int64_t>(stack[2]);
+  TORCH_WARN(innerKTiles);
+  TORCH_WARN(group_size);
   AtenTensorHandle scales_and_zeros_ath =
       reinterpret_cast<AtenTensorHandle>(stack[1]);
   AtenTensorHandle packed_w_ath = reinterpret_cast<AtenTensorHandle>(stack[0]);
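
The hunk above swaps the integer-passing convention for the void* stack: instead of storing a pointer to a stack-allocated int64_t in the slot, the value itself is packed into the void*. A minimal standalone sketch contrasting the two conventions (names are illustrative, not from this commit; assumes 64-bit pointers):

```cpp
// Standalone sketch of the two int64_t-in-void* conventions; not PR code.
#include <cstdint>
#include <cstdio>

int main() {
  int64_t innerKTiles = 8;

  // Old convention: the slot holds a pointer to the integer, so the pointee
  // must stay alive until the callee reads it.
  void *slot_by_ptr = reinterpret_cast<void *>(&innerKTiles);
  int64_t unpacked_by_ptr = *reinterpret_cast<int64_t *>(slot_by_ptr);

  // New convention: the integer value itself is stuffed into the void* slot,
  // so there is no lifetime to worry about (assumes pointers are 64-bit).
  void *slot_by_value = reinterpret_cast<void *>(innerKTiles);
  int64_t unpacked_by_value = reinterpret_cast<int64_t>(slot_by_value);

  std::printf("%lld %lld\n", (long long)unpacked_by_ptr,
              (long long)unpacked_by_value);
  return 0;
}
```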
@@ -380,37 +382,49 @@ void boxed_dequantize_tensor_core_tiled_layout(const c10::OperatorHandle &op,
   // function pt1 here should take in IValues, pass a malloc'd stack into the
   // second function
   // need a translation from IValues to ATH to void*s!
-  int64_t innerKTiles = torch::jit::pop(stack).toInt();
-  int64_t group_size = torch::jit::pop(stack).toInt();
-  const at::Tensor &scales_and_zeros = torch::jit::pop(stack).toTensor();
-  AtenTensorHandle scales_and_zeros_ath =
-      torch::aot_inductor::tensor_pointer_to_tensor_handle(&scales_and_zeros);
-  const at::Tensor &packed_w = torch::jit::pop(stack).toTensor();
-  AtenTensorHandle packed_w_ath =
-      torch::aot_inductor::tensor_pointer_to_tensor_handle(&packed_w);
-
-  int64_t num_args = 4;
-  int64_t num_outputs = 1;
-  void **ministack = (void **)malloc((num_args + num_outputs) * sizeof(void *));
-  ministack[3] = reinterpret_cast<void *>(&innerKTiles);
-  ministack[2] = reinterpret_cast<void *>(&group_size);
-  ministack[1] = reinterpret_cast<void *>(scales_and_zeros_ath);
-  ministack[0] = reinterpret_cast<void *>(packed_w_ath);
+
+  const auto &schema = op.schema();
+  const auto num_returns = schema.returns().size();
+  const auto num_arguments = schema.arguments().size();
+  TORCH_CHECK(num_arguments == 4);
+  TORCH_CHECK(num_returns == 1);
+  void **ministack = (void **)malloc((num_arguments + num_returns) * sizeof(void *));
+
+  for (auto idx = 0; idx < num_arguments; idx++) {
+    TORCH_WARN(idx);
+    const c10::IValue &arg = torch::jit::peek(stack, idx, num_arguments);
+    if (arg.isInt()) {
+      ministack[idx] = reinterpret_cast<void *>(arg.toInt());
+    } else if (arg.isTensor()) {
+      TORCH_WARN("am tensor!");
+      const at::Tensor &tensor = arg.toTensor();
+      AtenTensorHandle ath = torch::aot_inductor::tensor_pointer_to_tensor_handle(&tensor);
+      ministack[idx] = reinterpret_cast<void *>(ath);
+    } else {
+      TORCH_CHECK(false, "Other types of IValues not handled!");
+    }
+  }
+  TORCH_WARN("done with forloop no problems!");

   // second function is going to take a stack of void*, cast them to our
   // schema values for now, and run the function and modify the void* stack
-  voidyvoid_boxed_ATH_dequantize_tensor_core_tiled_layout(ministack, num_args,
-                                                          num_outputs);
+  voidyvoid_boxed_ATH_dequantize_tensor_core_tiled_layout(ministack, num_arguments,
+                                                          num_returns);

   // now read the output from the end of the stack and wrap that back into
   // IValue from void*?

   AtenTensorHandle out_ath =
-      reinterpret_cast<AtenTensorHandle>(ministack[num_args]);
-
+      reinterpret_cast<AtenTensorHandle>(ministack[num_arguments]);
+
   free(ministack);
+
   at::Tensor out =
       *torch::aot_inductor::tensor_handle_to_tensor_pointer(out_ath);
+
+  // now pop everything. if we pop earlier, Tensors would go out of scope
+  // before calling the function
+  torch::jit::drop(stack, num_arguments);
   torch::jit::push(stack, c10::IValue(out));

   // so above is our stack of IValues, but we cannot have these IValues because
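
The rewritten wrapper above no longer hard-codes the four arguments: it walks the operator schema, peeks each IValue (so the Tensors stay alive on the JIT stack until torch::jit::drop runs after the call), and packs Int values directly into the void* slot while Tensors travel as AtenTensorHandle. A hedged, op-agnostic sketch of just that translation step, reusing only the calls that already appear in the diff (the helper name and exact headers are assumptions):

```cpp
// Hypothetical helper (not in this commit): build the malloc'd void**
// "StableIValue" stack for any op whose schema only has Int and Tensor
// arguments. Mirrors the for-loop in boxed_dequantize_tensor_core_tiled_layout;
// assumes 64-bit pointers so an int64_t fits in a void*.
#include <cstdlib>

#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/stack.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>

static void **ivalues_to_stable_stack(const c10::OperatorHandle &op,
                                      torch::jit::Stack *stack) {
  const auto &schema = op.schema();
  const auto num_arguments = schema.arguments().size();
  const auto num_returns = schema.returns().size();
  // One extra slot per return so the callee can write its outputs at the end.
  void **ministack =
      (void **)malloc((num_arguments + num_returns) * sizeof(void *));
  for (size_t idx = 0; idx < num_arguments; idx++) {
    // peek, not pop: the IValues stay on the JIT stack and keep their Tensors
    // alive; the caller drops them only after the void**-boxed call returns.
    c10::IValue &arg = torch::jit::peek(stack, idx, num_arguments);
    if (arg.isInt()) {
      // ints travel by value inside the pointer slot
      ministack[idx] = reinterpret_cast<void *>(arg.toInt());
    } else if (arg.isTensor()) {
      // tensors travel as an AtenTensorHandle borrowed from the IValue
      at::Tensor &tensor = arg.toTensor();
      ministack[idx] = reinterpret_cast<void *>(
          torch::aot_inductor::tensor_pointer_to_tensor_handle(&tensor));
    } else {
      TORCH_CHECK(false, "only Int and Tensor are handled in this sketch");
    }
  }
  return ministack;
}
```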
@@ -488,90 +502,6 @@ at::Tensor _unpack_tensor_core_tiled_layout(const at::Tensor &packed_w,
   return out;
 }

-void voidyvoid_boxed_ATH_dequantize_tensor_core_tiled_layout(void **stack,
-                                                              int64_t num_args,
-                                                              int64_t num_outputs) {
-  // here, void* is my StableIValue
-  // function is going to take a stack of void*, cast them to our
-  // schema values for now, and run the function and modify the void* stack
-  int64_t innerKTiles = *reinterpret_cast<int64_t *>(stack[3]);
-  int64_t group_size = *reinterpret_cast<int64_t *>(stack[2]);
-  AtenTensorHandle scales_and_zeros_ath =
-      reinterpret_cast<AtenTensorHandle>(stack[1]);
-  AtenTensorHandle packed_w_ath = reinterpret_cast<AtenTensorHandle>(stack[0]);
-
-  AtenTensorHandle ath_res = _ATH_dequantize_tensor_core_tiled_layout(
-      packed_w_ath, scales_and_zeros_ath, group_size, innerKTiles);
-
-  void *out = reinterpret_cast<void *>(ath_res);
-  stack[num_args] = out;
-}
-
-// step 1: from here, call the ATH func
-// step 2: make ATH func also boxed and call it
-// step 3: move abstract code to libtorch
-void boxed_dequantize_tensor_core_tiled_layout(const c10::OperatorHandle &op,
-                                               torch::jit::Stack *stack) {
-
-  // function pt1 here should take in IValues, pass a malloc'd stack into the
-  // second function
-  // need a translation from IValues to ATH to void*s!
-  int64_t innerKTiles = torch::jit::pop(stack).toInt();
-  int64_t group_size = torch::jit::pop(stack).toInt();
-  const at::Tensor &scales_and_zeros = torch::jit::pop(stack).toTensor();
-  AtenTensorHandle scales_and_zeros_ath =
-      torch::aot_inductor::tensor_pointer_to_tensor_handle(&scales_and_zeros);
-  const at::Tensor &packed_w = torch::jit::pop(stack).toTensor();
-  AtenTensorHandle packed_w_ath =
-      torch::aot_inductor::tensor_pointer_to_tensor_handle(&packed_w);
-
-  int64_t num_args = 4;
-  int64_t num_outputs = 1;
-  void **ministack = (void **)malloc((num_args + num_outputs) * sizeof(void *));
-  ministack[3] = reinterpret_cast<void *>(&innerKTiles);
-  ministack[2] = reinterpret_cast<void *>(&group_size);
-  ministack[1] = reinterpret_cast<void *>(scales_and_zeros_ath);
-  ministack[0] = reinterpret_cast<void *>(packed_w_ath);
-
-  // second function is going to take a stack of void*, cast them to our
-  // schema values for now, and run the function and modify the void* stack
-  voidyvoid_boxed_ATH_dequantize_tensor_core_tiled_layout(ministack, num_args,
-                                                          num_outputs);
-
-  // now read the output from the end of the stack and wrap that back into
-  // IValue from void*?
-
-  AtenTensorHandle out_ath =
-      reinterpret_cast<AtenTensorHandle>(ministack[num_args]);
-
-  free(ministack);
-  at::Tensor out =
-      *torch::aot_inductor::tensor_handle_to_tensor_pointer(out_ath);
-  torch::jit::push(stack, c10::IValue(out));
-
-  // so above is our stack of IValues, but we cannot have these IValues because
-  // they are NOT ABI stable! So we need another version of "boxed" with void*s.
-  // and that is what is going to happen below
-
-  // what the old function used to be:
-  // int64_t innerKTiles = torch::jit::pop(stack).toInt();
-  // int64_t group_size = torch::jit::pop(stack).toInt();
-  // const at::Tensor &scales_and_zeros = torch::jit::pop(stack).toTensor();
-  // const at::Tensor &packed_w = torch::jit::pop(stack).toTensor();
-
-  // AtenTensorHandle packed_w_ath =
-  //     torch::aot_inductor::tensor_pointer_to_tensor_handle(&packed_w);
-  // AtenTensorHandle scales_and_zeros_ath =
-  //     torch::aot_inductor::tensor_pointer_to_tensor_handle(&scales_and_zeros);
-
-  // AtenTensorHandle ath_res = _ATH_dequantize_tensor_core_tiled_layout(
-  //     packed_w_ath, scales_and_zeros_ath, group_size, innerKTiles);
-
-  // at::Tensor out =
-  //     *torch::aot_inductor::tensor_handle_to_tensor_pointer(ath_res);
-  // torch::jit::push(stack, c10::IValue(out));
-}
-

 TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
   m.impl("torchao::unpack_tensor_core_tiled_layout",