#include <iostream>

#include "base.h"
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
- #include "mem.h"
- #endif
+ #include "mem.h"

template <typename T>
inline std::string str(T x) {
@@ -41,8 +39,6 @@ inline std::string str(T x) {

namespace torchao {

- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-
using I4 = Vec<int, 4>;
// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
@@ -208,6 +204,8 @@ __global__ void Marlin_QQQ(
    int prob_k,  // reduction dimension k
    int* locks   // extra global storage for barrier synchronization
) {
+   // host code or device code with SM >= 80. Marlin only supports SM >= 80.
+ #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
  // Each threadblock processes one "stripe" of the B matrix with (roughly) the
  // same size, which might involve multiple column "slices" (of width 16 *
  // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
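For reference, `__CUDA_ARCH__` is defined only while NVCC compiles device code (and carries the target SM there), so the new guard keeps the kernel visible to the host compilation pass while compiling the real body only for SM 8.0+ device passes. Below is a minimal, self-contained sketch of that guard pattern, not the Marlin kernel itself (`guarded_kernel` is a made-up name):

```cuda
// Sketch of the "!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800" pattern.
// The host pass still sees the full kernel definition (so the <<<...>>>
// launch compiles), while device passes below SM 8.0 get an empty body.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void guarded_kernel(int* out) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
  // Real work exists only in the SM >= 8.0 device binary.
  out[threadIdx.x] = static_cast<int>(threadIdx.x);
#endif
}

int main() {
  int* d_out = nullptr;
  cudaMalloc(&d_out, 32 * sizeof(int));
  guarded_kernel<<<1, 32>>>(d_out);  // launch code is always host-compiled
  cudaDeviceSynchronize();
  cudaFree(d_out);
  std::printf("launched\n");
  return 0;
}
```

Compared with wrapping the whole function in `#if`/`#else`, this keeps a single kernel definition and moves the SM < 8.0 rejection to the host launcher, as the later hunks show.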
@@ -855,47 +853,8 @@ __global__ void Marlin_QQQ(
      }
    }
  }
- }
-
- #else
-
- template <const int threads,          // number of threads in a threadblock
-           const int thread_m_blocks,  // number of 16x16 blocks in the m
-                                       // dimension (batchsize) of the
-                                       // threadblock
-           const int thread_n_blocks,  // same for n dimension (output)
-           const int thread_k_blocks,  // same for k dimension (reduction)
-           const int stages,           // number of stages for the async global->shared
-                                       // fetch pipeline
-           const int group_blocks = -1 // number of consecutive 16x16 blocks
-                                       // with a separate quantization scale
-           >
- __global__ void Marlin_QQQ(
-     const int4* __restrict__ A,  // int8 input matrix of shape mxk
-     const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
-     int4* __restrict__ C,        // int32 global_reduce buffer of shape
-                                  // (max_par*16*4)xn, as int8 tensor core's output is
-                                  // int32 dtype
-     int4* __restrict__ D,        // fp16 output buffer of shape mxn
-     const float* __restrict__ s_tok,   // fp32 activation per-token quantization
-                                        // scales of shape mx1
-     const int4* __restrict__ s_ch,     // fp32 weight per-channel quantization
-                                        // scales of shape 1xn
-     const int4* __restrict__ s_group,  // fp16 weight per-group quantization
-                                        // scales of shape (k/groupsize)xn, when
-                                        // group_blocks=-1, it should be nullptr
-     int prob_m,  // batch dimension m
-     int prob_n,  // output dimension n
-     int prob_k,  // reduction dimension k
-     int* locks   // extra global storage for barrier synchronization
- ) {
-   // Marlin is not implemented yet for SM < 8.0
-   TORCH_CHECK_NOT_IMPLEMENTED(
-       false, "marlin_qqq_gemm(..) requires CUDA_ARCH >= 8.0");
-   return;
- }
-
#endif
+ }

// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per schedule allows some more latency hiding. At the same time,
@@ -1132,6 +1091,12 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
                              torch::Tensor const& s_group,
                              torch::Tensor& workspace, int64_t size_m,
                              int64_t size_n, int64_t size_k) {
+   const auto dprops = at::cuda::getCurrentDeviceProperties();
+   if (dprops->major < 8) {
+     TORCH_CHECK(false, __func__, " requires SM >= 8.0. Current device is SM",
+                 dprops->major, ".", dprops->minor);
+   }
+
  // Verify M
  TORCH_CHECK(size_m == a.size(0),
              "Shape mismatch: a.size(0) = " + str(a.size(0)) +