@@ -394,9 +394,7 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
394
394
preference.setAttribute (CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
395
395
#endif
396
396
397
- auto & allocator = *::c10::cuda::CUDACachingAllocator::get ();
398
- auto workspace = allocator.allocate (workspaceSize);
399
- TORCH_CHECK (workspace.get () != nullptr , " OOM trying to allocate workspace for cublaslt" );
397
+ auto workspace = at::empty (workspaceSize, at::TensorOptions ().dtype (at::kByte ).device (at::kCUDA ));
400
398
401
399
cublasLtMatmulHeuristicResult_t heuristicResult = {};
402
400
int returnedResult = 0 ;
@@ -429,7 +427,7 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
429
427
c,
430
428
Cdesc.descriptor (),
431
429
&heuristicResult.algo ,
432
- workspace.mutable_get (),
430
+ workspace.mutable_data_ptr (),
433
431
workspaceSize,
434
432
at::cuda::getCurrentCUDAStream ());
435
433
TORCH_CHECK (
@@ -1290,9 +1288,7 @@ void gemm_and_bias(
1290
1288
preference.setAttribute (CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, d_alignment);
1291
1289
#endif
1292
1290
1293
- auto & allocator = *::c10::cuda::CUDACachingAllocator::get ();
1294
- auto workspace = allocator.allocate (workspaceSize);
1295
- TORCH_CHECK (workspace.get () != nullptr , " OOM trying to allocate workspace for cublaslt" );
1291
+ auto workspace = at::empty (workspaceSize, at::TensorOptions ().dtype (at::kByte ).device (at::kCUDA ));
1296
1292
1297
1293
cublasLtMatmulHeuristicResult_t heuristicResult = {};
1298
1294
int returnedResult = 0 ;
@@ -1326,7 +1322,7 @@ void gemm_and_bias(
1326
1322
result_ptr,
1327
1323
Cdesc.descriptor (),
1328
1324
&heuristicResult.algo ,
1329
- workspace.mutable_get (),
1325
+ workspace.mutable_data_ptr (),
1330
1326
workspaceSize,
1331
1327
at::cuda::getCurrentCUDAStream ());
1332
1328
TORCH_CHECK (
@@ -1474,9 +1470,7 @@ void scaled_gemm(
1474
1470
computeDesc.setAttribute (CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType (bias_dtype));
1475
1471
}
1476
1472
size_t workspaceSize = _getWorkspaceSize ();
1477
- auto & allocator = *::c10::cuda::CUDACachingAllocator::get ();
1478
- auto workspace = allocator.allocate (workspaceSize);
1479
- TORCH_CHECK (workspace.get () != nullptr , " OOM trying to allocate workspace for cublaslt" );
1473
+ auto workspace = at::empty (workspaceSize, at::TensorOptions ().dtype (at::kByte ).device (at::kCUDA ));
1480
1474
1481
1475
CuBlasLtMatmulPreference preference;
1482
1476
preference.setAttribute (CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
@@ -1560,7 +1554,7 @@ void scaled_gemm(
1560
1554
result_ptr,
1561
1555
Ddesc.descriptor (),
1562
1556
&heuristicResult.algo ,
1563
- workspace.mutable_get (),
1557
+ workspace.mutable_data_ptr (),
1564
1558
workspaceSize,
1565
1559
at::cuda::getCurrentCUDAStream ());
1566
1560
TORCH_CHECK (
@@ -1631,8 +1625,8 @@ void int8_gemm(
1631
1625
CuBlasLtMatmulPreference preference;
1632
1626
size_t workspaceSize = _getWorkspaceSize ();
1633
1627
preference.setAttribute (CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
1634
- auto & allocator = *:: c10::cuda::CUDACachingAllocator::get ( );
1635
- auto workspace = allocator. allocate (workspaceSize);
1628
+ auto workspace = at::empty (workspaceSize, at::TensorOptions (). dtype (at:: kByte ). device (at:: kCUDA ) );
1629
+
1636
1630
cublasLtMatmulHeuristicResult_t heuristicResult = {};
1637
1631
int returnedResult = 0 ;
1638
1632
TORCH_CUDABLAS_CHECK (cublasLtMatmulAlgoGetHeuristic (
@@ -1670,7 +1664,7 @@ void int8_gemm(
1670
1664
nullptr , // Heuristics don't seem to work for int8
1671
1665
#endif
1672
1666
#ifdef USE_ROCM
1673
- workspace.mutable_get (),
1667
+ workspace.mutable_data_ptr (),
1674
1668
#else
1675
1669
nullptr , // Non-zero workspace doesn't seem to work.
1676
1670
#endif
0 commit comments