-
Notifications
You must be signed in to change notification settings - Fork 249
/
Copy pathComputeSparseTile.cuh
111 lines (95 loc) · 3.66 KB
/
ComputeSparseTile.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#pragma once
#include "SparseSemiStructuredPack.cuh"
#include "StaticSort.h"
#include <cutlass/bfloat16.h>
#include <cutlass/half.h>
#include <cutlass/platform/platform.h>
#include <cutlass/version.h>
// Basic FP8 type definitions
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
// // For FP8 E4M3 format (4 exponent bits, 3 mantissa bits)
// #include <cutlass/float8_e4m3.h>
// // For FP8 E5M2 format (5 exponent bits, 2 mantissa bits)
// #include <cutlass/float8_e5m2.h>
// Given a tile of values (for LargestValuesRowwise below: one 1x16 row, viewed
// as four 1x4 strips), an algorithm computes the indices that will remain
// after 2:4 sparsification, as a bitmask.
// NOTE: Some algorithms might select FEWER than 8 values in total.
namespace torchao {

// Sort key used to rank the entries of a tile: the element value plus its
// position. `parts.inner_index` is the column inside its 1x4 strip and
// `parts.outer_index` is the strip number, so the element sits at position
// `inner_index + 4 * outer_index` of the row.
// Ordering compares `Pointwise::apply(value)`, so callers can rank by a
// transformed value instead of the raw one.
// NOTE(review): `raw` is only 32 bits wide; if sizeof(parts) exceeds 4 bytes
// (likely when `Element` is a 32-bit type) `raw` does not alias the whole
// entry — confirm how `raw` is used elsewhere before relying on it.
template <typename Element, typename Pointwise> struct TileValueOrderedT {
  union {
    struct {
      Element value;
      uint2b_t inner_index;
      uint2b_t outer_index;
    } parts;
    uint32_t raw;
  };
  CUTLASS_DEVICE bool
  operator<(TileValueOrderedT<Element, Pointwise> const &other) const {
    return Pointwise::apply(parts.value) < Pointwise::apply(other.parts.value);
  }
  // Leaves the storage uninitialized on purpose; LargestValuesRowwise writes
  // every field of `parts` before reading it.
  CUTLASS_DEVICE TileValueOrderedT() {}
};

// Operations that we can apply to rank the values (the `Pointwise` parameter
// of TileValueOrderedT). IdentityOp ranks values by themselves.
struct IdentityOp {
  template <typename T> static T CUTLASS_HOST_DEVICE apply(T const &x) {
    return x;
  }
};

// Given one row of 16 values, viewed as four consecutive 1x4 strips, computes
// which entries survive 2:4 sparsification and returns them as a bitmask:
// bit (inner + 4 * strip) is set iff column `inner` of strip `strip` is kept.
// Constraint: exactly 2 values are kept in every 1x4 strip (8 in total).
// Algorithm: sort all 16 values by `Op::apply(value)`, then walk them from the
// largest down and keep a value whenever its strip has fewer than 2 kept
// values. This keeps the 2 largest values of each strip (ties are broken by
// the sorting network's order).
// NOTE: registers are not indexable. Arrays here are only indexed inside
// CUTLASS_PRAGMA_UNROLL loops so that indices can resolve to compile-time
// constants; dynamic indexing would force the data into local memory.
template <typename Op = IdentityOp> struct LargestValuesRowwise {
  // Fill value for out-of-bounds entries (presumably used by callers to pad a
  // partial row): -infinity ranks below every finite value under IdentityOp.
  template <typename T> static CUTLASS_DEVICE T outOfBoundsFillValue() {
    return -cutlass::platform::numeric_limits<T>::infinity();
  }

  template <typename Tile1x16Accessor>
  CUTLASS_DEVICE Indices1x16 operator()(Tile1x16Accessor values) {
    using TileValueOrdered =
        TileValueOrderedT<typename Tile1x16Accessor::Element, Op>;
    using TileValuesFragment = cutlass::Array<TileValueOrdered, 4 * 4>;
    constexpr int kNumValues = int(TileValuesFragment::kElements);

    // Tag each value with its (strip, column) position before sorting.
    TileValuesFragment values_ordered;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 4; ++i) {
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < 4; ++j) {
        TileValueOrdered &v = values_ordered[i * 4 + j];
        v.parts.value = values.at(0, i * 4 + j).get();
        v.parts.inner_index = uint2b_t(j);
        v.parts.outer_index = uint2b_t(i);
      }
    }

    // Use a sorting network (no data-dependent branches) to avoid warp
    // divergence. Afterwards the largest values are at the end of the array.
    StaticSort<TileValuesFragment::kElements> sorter;
    sorter(values_ordered);

    // numPer1x4Strip packs one 2-bit counter per strip, at bits
    // [2*strip, 2*strip + 1]:
    //   00 (0) -> nothing kept yet
    //   01 (1) -> one value kept
    //   11 (3) -> two values kept (strip is full)
    // Counters are only ever OR-ed into, so 10 (2) never occurs.
    uint32_t numPer1x4Strip = 0;
    Indices1x16 indices = 0;

    // Take as many as we can, starting with the largest values.
    CUTLASS_PRAGMA_UNROLL
    for (int i = kNumValues - 1; i >= 0; i--) {
      auto &e = values_ordered[i];
      uint32_t rcount = uint2b_t(numPer1x4Strip >> 2 * e.parts.outer_index);
      // The counter is only ever 0, 1 or 3, so `rcount <= 2` is a cheaper
      // way to test "strip not full yet" (equivalent to `rcount != 3`).
      bool selected = rcount <= 2;
      indices |= selected << (e.parts.inner_index + 4 * e.parts.outer_index);
      // When selected: 0 -> 0|1 = 1 and 1 -> 1|2 = 3. A full strip ORs 3
      // into 3, which leaves it unchanged.
      numPer1x4Strip |= (rcount + selected) << 2 * e.parts.outer_index;
    }
    return indices;
  }
};

// Invokes `callback(algorithm, name)` once per available selection algorithm.
// Currently only the default one is registered:
// LargestValuesRowwise<IdentityOp>, with an empty name.
template <typename T> void named_algorithms(T callback) {
  // default one
  callback(LargestValuesRowwise<IdentityOp>(), "");
}

} // namespace torchao