#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere

- // #include <ATen/ATen.h>
- // #include <ATen/core/Tensor.h>
+ #include <ATen/ATen.h>
+ #include <ATen/core/Tensor.h>
#include <ATen/DeviceGuard.h>
#include <ATen/core/TensorAccessor.h>
#include <ATen/core/ivalue.h>
@@ -332,25 +332,6 @@ AtenTensorHandle _ATH_dequantize_tensor_core_tiled_layout(
  return out;
}

- // output is [n][k] (int32 dtype)
- // input is [n / 8][k / (InnerKTiles * 16)][32][innerKTiles / 2]
- // scales_and_zeros is [numQGroups][n][2]
- // qGroupSize is 32, 64, 128 or 256
- // at::Tensor
- // _dequantize_tensor_core_tiled_layout(const at::Tensor &packed_w,
- //                                      const at::Tensor &scales_and_zeros,
- //                                      int64_t group_size, int64_t innerKTiles) {
-
- //   AtenTensorHandle packed_w_ath =
- //       torch::aot_inductor::tensor_pointer_to_tensor_handle(&packed_w);
- //   AtenTensorHandle scales_and_zeros_ath =
- //       torch::aot_inductor::tensor_pointer_to_tensor_handle(&scales_and_zeros);
-
- //   AtenTensorHandle ath_res = _ATH_dequantize_tensor_core_tiled_layout(
- //       packed_w_ath, scales_and_zeros_ath, group_size, innerKTiles);
-
- //   return *torch::aot_inductor::tensor_handle_to_tensor_pointer(ath_res);
- // }

void voidyvoid_boxed_ATH_dequantize_tensor_core_tiled_layout(void **stack,
                                                             int64_t num_args,
@@ -360,8 +341,6 @@ void voidyvoid_boxed_ATH_dequantize_tensor_core_tiled_layout(void **stack,
  // schema values for now, and run the function and modify the void* stack
  int64_t innerKTiles = reinterpret_cast<int64_t>(stack[3]);
  int64_t group_size = reinterpret_cast<int64_t>(stack[2]);
-   TORCH_WARN(innerKTiles);
-   TORCH_WARN(group_size);
  AtenTensorHandle scales_and_zeros_ath =
      reinterpret_cast<AtenTensorHandle>(stack[1]);
  AtenTensorHandle packed_w_ath = reinterpret_cast<AtenTensorHandle>(stack[0]);
@@ -386,68 +365,44 @@ void boxed_dequantize_tensor_core_tiled_layout(const c10::OperatorHandle &op,
  const auto &schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = schema.arguments().size();
-   TORCH_CHECK(num_arguments==4);
-   TORCH_CHECK(num_returns==1);
  void **ministack = (void **)malloc((num_arguments + num_returns) * sizeof(void *));
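  // ministack layout: slots [0, num_arguments) hold the inputs (ints stored by
  // value, tensors as AtenTensorHandle); the trailing num_returns slots are
  // filled in with the outputs by the callee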

  for (auto idx = 0; idx < num_arguments; idx++) {
-     TORCH_WARN(idx);
    const c10::IValue& arg = torch::jit::peek(stack, idx, num_arguments);
    if (arg.isInt()) {
      ministack[idx] = reinterpret_cast<void *>(arg.toInt());
    } else if (arg.isTensor()) {
-       TORCH_WARN("am tensor!")
      const at::Tensor& tensor = arg.toTensor();
      AtenTensorHandle ath = torch::aot_inductor::tensor_pointer_to_tensor_handle(&tensor);
      ministack[idx] = reinterpret_cast<void *>(ath);
    } else {
-       TORCH_CHECK(false, "Other types of IValues not handled!");
+       TORCH_CHECK(false, "Other types of IValues not yet handled!");
    }
  }
-   TORCH_WARN("done with forloop no problems!")

  // second function is going to take a stack of void*, cast them to our
  // schema values for now, and run the function and modify the void* stack
  voidyvoid_boxed_ATH_dequantize_tensor_core_tiled_layout(ministack, num_arguments,
                                                          num_returns);
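  // at this point the callee has consumed the raw inputs and written its
  // result handle(s) into the trailing slots of ministack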

-   // now read the output from the end of the stack and wrap that back into
-   // IValue from void*?
-
-   AtenTensorHandle out_ath =
-       reinterpret_cast<AtenTensorHandle>(ministack[num_arguments]);
-
-   free(ministack);
-
-   at::Tensor out =
-       *torch::aot_inductor::tensor_handle_to_tensor_pointer(out_ath);
-
-   // now pop everything. if we pop earlier, Tensors would go out of scope
+   // now pop all the inputs off the stack. if we popped earlier, the Tensors would go out of scope
  // before calling the function
  torch::jit::drop(stack, num_arguments);
-   torch::jit::push(stack, c10::IValue(out));
-
-   // so above is our stack of IValues, but we cannot have these IValues because
-   // they are NOT ABI stable! So we need another version of "boxed" with void*s.
-   // and that is what is going to happen below
-
-   // what the old function used to be:
-   // int64_t innerKTiles = torch::jit::pop(stack).toInt();
-   // int64_t group_size = torch::jit::pop(stack).toInt();
-   // const at::Tensor &scales_and_zeros = torch::jit::pop(stack).toTensor();
-   // const at::Tensor &packed_w = torch::jit::pop(stack).toTensor();

-   // AtenTensorHandle packed_w_ath =
-   //     torch::aot_inductor::tensor_pointer_to_tensor_handle(&packed_w);
-   // AtenTensorHandle scales_and_zeros_ath =
-   //     torch::aot_inductor::tensor_pointer_to_tensor_handle(&scales_and_zeros);
-
-   // AtenTensorHandle ath_res = _ATH_dequantize_tensor_core_tiled_layout(
-   //     packed_w_ath, scales_and_zeros_ath, group_size, innerKTiles);
+   // read the outputs from the end of the ministack and wrap each one back
+   // into an IValue
+   for (auto idx = 0; idx < num_returns; idx++) {
+     const c10::TypePtr& ret_type = schema.returns()[idx].type();
+     if (*ret_type == *c10::getTypePtr<at::Tensor>()) {
+       AtenTensorHandle ret_ath = reinterpret_cast<AtenTensorHandle>(ministack[num_arguments + idx]);
+       at::Tensor out = *torch::aot_inductor::tensor_handle_to_tensor_pointer(ret_ath);
+       torch::jit::push(stack, c10::IValue(out));
+     } else {
+       TORCH_CHECK(false, "Only Tensor return types are currently supported!");
+     }
+   }

-   // at::Tensor out =
-   //     *torch::aot_inductor::tensor_handle_to_tensor_pointer(ath_res);
-   // torch::jit::push(stack, c10::IValue(out));
+   free(ministack);
}
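For context, a minimal sketch (not part of this diff) of how a boxed wrapper like boxed_dequantize_tensor_core_tiled_layout could be registered with the dispatcher; the torchao namespace, the op name, and the CUDA dispatch key are assumptions and would need to match the op's actual schema declaration.

#include <torch/library.h>

TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
  // register the hand-written boxed function for this op; the dispatcher then
  // invokes it with the OperatorHandle and the IValue stack
  m.impl("dequantize_tensor_core_tiled_layout",
         torch::CppFunction::makeFromBoxedFunction<
             &boxed_dequantize_tensor_core_tiled_layout>());
}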