Skip to content

Commit f6c6048

Browse files
cpuhrsch authored and pytorchmergebot committed
Use CUTLASS GEMM for NT bmm (pytorch#85894)
Copy of pytorch#85710 Pull Request resolved: pytorch#85894 Approved by: https://github.com/drisspg
1 parent 80790ec commit f6c6048

10 files changed

+400
-40
lines changed

BUILD.bazel

+2
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,7 @@ cu_library(
429429
"@cuda//:cublas",
430430
"@cuda//:cufft",
431431
"@cuda//:cusparse",
432+
"@cutlass",
432433
],
433434
alwayslink = True,
434435
)
@@ -1673,6 +1674,7 @@ cc_library(
16731674
] + if_cuda([
16741675
":torch_distributed_cuda",
16751676
"@cuda//:nvToolsExt",
1677+
"@cutlass",
16761678
]),
16771679
alwayslink = True,
16781680
)

WORKSPACE

+6
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,12 @@ new_local_repository(
8484
path = "third_party/eigen",
8585
)
8686

87+
new_local_repository(
88+
name = "cutlass",
89+
build_file = "//third_party:cutlass.BUILD",
90+
path = "third_party/cutlass",
91+
)
92+
8793
new_local_repository(
8894
name = "fbgemm",
8995
build_file = "//third_party:fbgemm/BUILD.bazel",

aten/src/ATen/CMakeLists.txt

+1-3
Original file line numberDiff line numberDiff line change
@@ -433,9 +433,7 @@ if(NOT MSVC AND NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
433433
endif()
434434

435435
if(USE_CUDA AND NOT USE_ROCM)
436-
if(USE_FLASH_ATTENTION)
437-
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
438-
endif()
436+
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
439437
if($ENV{ATEN_STATIC_CUDA})
440438
list(APPEND ATen_CUDA_DEPENDENCY_LIBS
441439
${CUDA_LIBRARIES}

aten/src/ATen/native/native_functions.yaml

+2-1
Original file line numberDiff line numberDiff line change
@@ -1174,7 +1174,8 @@
11741174
dispatch:
11751175
SparseCPU: bmm_sparse_cpu
11761176
SparseCUDA: bmm_sparse_cuda
1177-
NestedTensorCPU, NestedTensorCUDA: bmm_nested
1177+
NestedTensorCPU: bmm_nested
1178+
NestedTensorCUDA: bmm_nested_cuda
11781179
tags: canonical
11791180

11801181
- func: bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)

aten/src/ATen/native/nested/NestedTensorUtils.cpp

-22
Original file line numberDiff line numberDiff line change
@@ -108,27 +108,5 @@ std::vector<Tensor> chunk_nested_tensor(const Tensor& self, int64_t chunks, int6
108108
return splits;
109109
}
110110

111-
std::vector<IntArrayRef> NestedTensor_get_sizes(
112-
const NestedTensorImpl* self_ptr) {
113-
int64_t ntensors = self_ptr->size(0);
114-
std::vector<IntArrayRef> sizes(ntensors);
115-
if (ntensors == 0) {
116-
return sizes;
117-
}
118-
const Tensor& sizemat = self_ptr->get_nested_size_tensor();
119-
int64_t orig_dim = sizemat.size(1);
120-
// nesting scalars has empty sizes
121-
if (orig_dim == 0) {
122-
return sizes;
123-
}
124-
const int64_t* sizemat_ptr = sizemat.data_ptr<int64_t>();
125-
126-
for (const auto i : c10::irange(ntensors)) {
127-
sizes[i] = IntArrayRef(sizemat_ptr, sizemat_ptr + orig_dim);
128-
sizemat_ptr += orig_dim;
129-
}
130-
return sizes;
131-
}
132-
133111
} // namespace native
134112
} // namespace at

aten/src/ATen/native/nested/NestedTensorUtils.h

+22-2
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,28 @@ inline at::Tensor create_nested_view_tensor(
9797
int64_t get_consistent_last_dim_of_nested_tensor(const NestedTensorImpl& nt);
9898

9999
// The sizes of the underlying tensors
100-
std::vector<IntArrayRef> NestedTensor_get_sizes(
101-
const NestedTensorImpl* self_ptr);
100+
inline std::vector<IntArrayRef> NestedTensor_get_sizes(
101+
const NestedTensorImpl* self_ptr) {
102+
int64_t ntensors = self_ptr->size(0);
103+
std::vector<IntArrayRef> sizes(ntensors);
104+
if (ntensors == 0) {
105+
return sizes;
106+
}
107+
const Tensor& sizemat = self_ptr->get_nested_size_tensor();
108+
int64_t orig_dim = sizemat.size(1);
109+
// nesting scalars has empty sizes
110+
if (orig_dim == 0) {
111+
return sizes;
112+
}
113+
const int64_t* sizemat_ptr = sizemat.data_ptr<int64_t>();
114+
115+
for (const auto i : c10::irange(ntensors)) {
116+
sizes[i] = IntArrayRef(sizemat_ptr, sizemat_ptr + orig_dim);
117+
sizemat_ptr += orig_dim;
118+
}
119+
return sizes;
120+
}
121+
102122

103123
TORCH_API std::vector<int64_t> NestedTensor_get_max_size(
104124
const NestedTensorImpl& nt);

0 commit comments

Comments
 (0)