
Commit 8afd10e

Fix compile issue for Marlin qqq on sm<8.0 (#1651)
* fix compile guard
* remove guard on header file
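
The change is easier to read with the CUDA compilation model in mind: `__CUDA_ARCH__` is defined only during device compilation passes, never during the host pass. The old code wrapped the real `Marlin_QQQ` kernel in `#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800` and supplied a stub under `#else`; the new code keeps a single kernel definition, always includes "mem.h", and compiles the kernel body whenever the pass is host code or a device pass targeting SM >= 8.0. Below is a minimal standalone sketch of that guard pattern; the names (demo_kernel, launch_demo) are hypothetical and not part of the actual Marlin code.

    // guard_sketch.cu -- illustration only; nothing here is torchao code.
    #include <cuda_runtime.h>

    __global__ void demo_kernel(int* out) {
    // __CUDA_ARCH__ is only set in device compilation passes, so this guard
    // keeps the body for the host pass and for device passes targeting
    // SM >= 8.0, and compiles it away for older device targets.
    #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
      out[threadIdx.x] = static_cast<int>(threadIdx.x);  // stand-in for real work
    #endif
    }

    void launch_demo(int* d_out) {
      // The host pass always sees a complete kernel definition, so this
      // launch compiles no matter which device architectures are requested.
      demo_kernel<<<1, 32>>>(d_out);
    }

Built with something like `-gencode arch=compute_75,code=sm_75`, the sketch still compiles; the kernel body is simply empty in the SM 7.5 device pass, which is why the runtime check added at the bottom of the diff matters.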
1 parent 1a4c8f9 · commit 8afd10e

1 file changed (+10, -45 lines)
torchao/csrc/cuda/marlin_qqq/marlin_qqq_kernel.cu

@@ -30,9 +30,7 @@
 #include <iostream>
 
 #include "base.h"
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-#include "mem.h"
-#endif
+#include "mem.h"
 
 template <typename T>
 inline std::string str(T x) {
@@ -41,8 +39,6 @@ inline std::string str(T x) {
 
 namespace torchao {
 
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-
 using I4 = Vec<int, 4>;
 // Matrix fragments for tensor core instructions; their precise layout is
 // documented here:
@@ -208,6 +204,8 @@ __global__ void Marlin_QQQ(
     int prob_k,  // reduction dimension k
     int* locks   // extra global storage for barrier synchronization
 ) {
+  // host code or device code with SM >= 80. Marlin only supports SM >= 80.
+#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
   // Each threadblock processes one "stripe" of the B matrix with (roughly) the
   // same size, which might involve multiple column "slices" (of width 16 *
   // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
@@ -855,47 +853,8 @@ __global__ void Marlin_QQQ(
       }
     }
   }
-}
-
-#else
-
-template <const int threads,          // number of threads in a threadblock
-          const int thread_m_blocks,  // number of 16x16 blocks in the m
-                                      // dimension (batchsize) of the
-                                      // threadblock
-          const int thread_n_blocks,  // same for n dimension (output)
-          const int thread_k_blocks,  // same for k dimension (reduction)
-          const int stages,  // number of stages for the async global->shared
-                             // fetch pipeline
-          const int group_blocks = -1  // number of consecutive 16x16 blocks
-                                       // with a separate quantization scale
-          >
-__global__ void Marlin_QQQ(
-    const int4* __restrict__ A,  // int8 input matrix of shape mxk
-    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
-    int4* __restrict__ C,  // int32 global_reduce buffer of shape
-                           // (max_par*16*4)xn, as int8 tensor core's output is
-                           // int32 dtype
-    int4* __restrict__ D,  // fp16 output buffer of shape mxn
-    const float* __restrict__ s_tok,  // fp32 activation per-token quantization
-                                      // scales of shape mx1
-    const int4* __restrict__ s_ch,    // fp32 weight per-channel quantization
-                                      // scales of shape 1xn
-    const int4* __restrict__ s_group,  // fp16 weight per-group quantization
-                                       // scales of shape (k/groupsize)xn, when
-                                       // group_blocks=-1, it should be nullptr
-    int prob_m,  // batch dimension m
-    int prob_n,  // output dimension n
-    int prob_k,  // reduction dimension k
-    int* locks   // extra global storage for barrier synchronization
-) {
-  // Marlin is not implemented yet for SM < 8.0
-  TORCH_CHECK_NOT_IMPLEMENTED(
-      false, "marlin_qqq_gemm(..) requires CUDA_ARCH >= 8.0");
-  return;
-}
-
 #endif
+}
 
 // 8 warps are a good choice since every SM has 4 schedulers and having more
 // than 1 warp per schedule allows some more latency hiding. At the same time,
@@ -1132,6 +1091,12 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
                               torch::Tensor const& s_group,
                               torch::Tensor& workspace, int64_t size_m,
                               int64_t size_n, int64_t size_k) {
+  const auto dprops = at::cuda::getCurrentDeviceProperties();
+  if (dprops->major < 8) {
+    TORCH_CHECK(false, __func__, "requires SM >= 8.0. Current device is SM",
+                dprops->major, ".", dprops->minor);
+  }
+
   // Verify M
   TORCH_CHECK(size_m == a.size(0),
               "Shape mismatch: a.size(0) = " + str(a.size(0)) +
