Commit cf58de5

eqy authored and pytorchmergebot committed
[cuBLASLt][Memtracker] Allocate temporary cuBLASLt workspaces using tensors rather than going to the caching allocator directly (pytorch#139442)

CC @zdevito @janeyx99

This isn't ideal, but cuBLASLt workspaces are not currently cached, so this additional untracked allocation causes `test_cuda_tracker_equivalence` to fail with a large enough workspace size, e.g., `CUBLAS_LT_WORKSPACE_SIZE=32768`. One solution is to use byte tensors for the workspace instead of going directly to the caching allocator.

Pull Request resolved: pytorch#139442
Approved by: https://github.com/Aidyn-A, https://github.com/albanD, https://github.com/janeyx99
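For context, here is a minimal sketch of the pattern this commit adopts: routing a scratch allocation through `at::empty` so it is attributed to a tensor and therefore visible to the CUDA memory tracker, instead of calling the caching allocator directly. The helper name `allocate_lt_workspace` and the `holder` out-parameter are illustrative assumptions, not code from the PR:

```cpp
#include <ATen/ATen.h>

// Sketch only: allocate a cuBLASLt scratch buffer as a byte tensor so the
// allocation goes through the normal tensor path (and is visible to memory
// tracking tools), rather than via CUDACachingAllocator::allocate(), which
// produces a raw allocation the tracker cannot attribute to any tensor.
void* allocate_lt_workspace(size_t workspaceSize, at::Tensor& holder) {
  // A kByte tensor of length workspaceSize is exactly workspaceSize bytes;
  // at::empty leaves it uninitialized, which is fine for a scratch buffer.
  holder = at::empty({static_cast<int64_t>(workspaceSize)},
                     at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
  // mutable_data_ptr() stands in for the DataPtr-based workspace.mutable_get();
  // the tensor keeps the storage alive for as long as `holder` is in scope.
  return holder.mutable_data_ptr();
}
```

The tensor must outlive the `cublasLtMatmul` call that uses the pointer, which is why the diff below keeps the workspace in a local `auto workspace = at::empty(...)` tensor rather than a raw pointer.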
1 parent b7b5657 commit cf58de5

File tree

1 file changed: 9 additions, 15 deletions


aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 9 additions & 15 deletions
```diff
@@ -394,9 +394,7 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
 #endif
 
-  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
-  auto workspace = allocator.allocate(workspaceSize);
-  TORCH_CHECK(workspace.get() != nullptr, "OOM trying to allocate workspace for cublaslt");
+  auto workspace = at::empty(workspaceSize, at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
 
   cublasLtMatmulHeuristicResult_t heuristicResult = {};
   int returnedResult = 0;
@@ -429,7 +427,7 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
       c,
       Cdesc.descriptor(),
       &heuristicResult.algo,
-      workspace.mutable_get(),
+      workspace.mutable_data_ptr(),
       workspaceSize,
       at::cuda::getCurrentCUDAStream());
   TORCH_CHECK(
@@ -1290,9 +1288,7 @@ void gemm_and_bias(
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, d_alignment);
 #endif
 
-  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
-  auto workspace = allocator.allocate(workspaceSize);
-  TORCH_CHECK(workspace.get() != nullptr, "OOM trying to allocate workspace for cublaslt");
+  auto workspace = at::empty(workspaceSize, at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
 
   cublasLtMatmulHeuristicResult_t heuristicResult = {};
   int returnedResult = 0;
@@ -1326,7 +1322,7 @@ void gemm_and_bias(
       result_ptr,
       Cdesc.descriptor(),
       &heuristicResult.algo,
-      workspace.mutable_get(),
+      workspace.mutable_data_ptr(),
       workspaceSize,
       at::cuda::getCurrentCUDAStream());
   TORCH_CHECK(
@@ -1474,9 +1470,7 @@ void scaled_gemm(
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
   }
   size_t workspaceSize = _getWorkspaceSize();
-  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
-  auto workspace = allocator.allocate(workspaceSize);
-  TORCH_CHECK(workspace.get() != nullptr, "OOM trying to allocate workspace for cublaslt");
+  auto workspace = at::empty(workspaceSize, at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
 
   CuBlasLtMatmulPreference preference;
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
@@ -1560,7 +1554,7 @@ void scaled_gemm(
       result_ptr,
       Ddesc.descriptor(),
       &heuristicResult.algo,
-      workspace.mutable_get(),
+      workspace.mutable_data_ptr(),
       workspaceSize,
       at::cuda::getCurrentCUDAStream());
   TORCH_CHECK(
@@ -1631,8 +1625,8 @@ void int8_gemm(
   CuBlasLtMatmulPreference preference;
   size_t workspaceSize = _getWorkspaceSize();
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
-  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
-  auto workspace = allocator.allocate(workspaceSize);
+  auto workspace = at::empty(workspaceSize, at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
+
   cublasLtMatmulHeuristicResult_t heuristicResult = {};
   int returnedResult = 0;
   TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
@@ -1670,7 +1664,7 @@ void int8_gemm(
       nullptr, // Heuristics don't seem to work for int8
 #endif
 #ifdef USE_ROCM
-      workspace.mutable_get(),
+      workspace.mutable_data_ptr(),
 #else
       nullptr, // Non-zero workspace doesn't seem to work.
 #endif
```
