docs/source/sparsity.rst

Sparsity
========

Sparsity is the technique of removing parameters from a neural network in order to reduce its memory overhead or latency. By carefully choosing how the elements are pruned, one can achieve significant reduction in memory overhead and latency, while paying a reasonably low or no price in terms of model quality (accuracy / f1).
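
For example, the 2:4 ("semi-structured") pattern used throughout this page keeps two values in every contiguous group of four weights (commonly the two with the largest magnitude). A toy sketch of such pruning, for illustration only (this is not the torchao API):

.. code-block:: py

    import torch

    # Keep the 2 largest-magnitude entries in every group of 4 weights, zero the rest.
    w = torch.randn(128, 128)
    groups = w.reshape(-1, 4)
    keep = groups.abs().topk(2, dim=-1).indices
    mask = torch.zeros_like(groups).scatter_(-1, keep, 1.0)
    w_pruned = (groups * mask).reshape_as(w)  # 50% of the entries are now zero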

Benchmarks
==========

segment-anything-fast
^^^^^^^^^^^^^^^^^^^^^

We were able to provide a **1.16x (22.7 -> 26.5 img/s) speedup over our dense baseline, while maintaining 97.5% (0.581 -> 0.567) of the evaluation accuracy (mIOU)**.

Overall, we found that accelerating the MLP linear layers (``lin1``, ``lin2``) provided the most speedup, while mitigating accuracy loss.

Applying sparsity to the attention linear layers led to a slower model, likely due to two reasons:

* We cannot fuse into our semi-structured sparse matmul with torch.compile.
* The speedups we observe for sparse matmul depend on the matmul shapes, and the attention matmuls are smaller than the MLP ones.

We were also able to compose int8 dynamic quantization with 2:4 sparsity for further speedups.

We found that applying int8 dynamic quantization to the attention layers, int8 dynamic quantization + 2:4 sparsity to MLP layer 1, and 2:4 sparsity to MLP layer 2 yielded the best configuration.
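
A minimal sketch of how that per-layer configuration could be expressed via the ``filter_fn`` argument of ``quantize_`` and ``sparsify_`` (the name-based filters below are hypothetical and depend on the model's actual module names):

.. code-block:: py

    import torch
    from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
    from torchao.sparsity.sparse_api import sparsify_, semi_sparse_weight
    from torchao.dtypes import SemiSparseLayout

    model = model.cuda()

    # Hypothetical name-based filters; adjust to the module names of your model.
    is_attn = lambda mod, name: isinstance(mod, torch.nn.Linear) and "attn" in name
    is_lin1 = lambda mod, name: isinstance(mod, torch.nn.Linear) and name.endswith("lin1")
    is_lin2 = lambda mod, name: isinstance(mod, torch.nn.Linear) and name.endswith("lin2")

    # int8 dynamic quantization on the attention projections
    quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=is_attn)
    # int8 dynamic quantization composed with 2:4 sparsity on MLP layer 1
    quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), filter_fn=is_lin1)
    # 2:4 sparsity only on MLP layer 2
    sparsify_(model, semi_sparse_weight(), filter_fn=is_lin2)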

We ran the following benchmarks for SAM ViT-h on an NVIDIA A100-80GB, with ``batch_size=32``, ``bfloat16`` dtype, and ``torch.compile="max_autotune"``.

To reproduce our benchmarks please follow these `instructions </torchao/_models/sam/README.md>`_.

Llama3
^^^^^^

On Meta Llama3, we observe a 25% tok/s increase (180 -> 226) compared to our existing int4-wo implementation when using the sparse Marlin kernel added by @Diogo-V.

.. list-table::
   :header-rows: 1

   * - Model
     - Technique
     - Tokens/Second
     - Memory Bandwidth (GB/s)
     - Peak Memory (GB)
     - Model Size (GB)
   * - Llama-3-8B
     - Base (bfloat16)
     - 95.64
     - 1435.54
     - 16.43
     - 15.01
   * -
     - int8wo
     - 153.03
     - 1150.80
     - 10.42
     - 7.52
   * -
     - int4wo-64
     - 180.80
     - 763.33
     - 6.88
     - 4.22
   * -
     - int4wo-64-sparse-marlin
     - 226.02
     - 689.20
     - 5.32
     - 3.05

These benchmarks were also run on an NVIDIA A100-80GB.

Sparse-Marlin 2:4 is an optimized GPU kernel that extends the Mixed Auto-Regressive Linear (Marlin) dense kernel to support 4-bit quantized weights and 2:4 sparsity, improving performance in matrix multiplication and accumulation. Full documentation can be found `here <https://github.com/IST-DASLab/Sparse-Marlin>`_.

.. code-block:: py

    from torchao.quantization.quant_api import quantize_, int4_weight_only
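    # A minimal sketch of how this snippet plausibly continues, assuming the
    # MarlinSparseLayout from torchao.dtypes and its ``layout`` kwarg:
    from torchao.dtypes import MarlinSparseLayout

    model = model.cuda()
    quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))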

Note that the existing API results in an extremely high accuracy degradation and is intended to be used in concert with an already sparsified+finetuned checkpoint where possible until we develop the necessary supporting flows in torchao.

int8 dynamic quant + 2:4 sparsity
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We support composing int8 dynamic quantization with 2:4 sparsity. We fuse one of the scalar dequant multiplications into our cuSPARSELt sparse mm in order to remain performant.

.. code-block:: py

    from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
    from torchao.sparsity.sparse_api import sparsify_, semi_sparse_weight
    from torchao.dtypes import SemiSparseLayout

    model = model.cuda()
    sparsify_(model, semi_sparse_weight())
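
The ``sparsify_`` call above applies 2:4 sparsity on its own; to compose it with int8 dynamic quantization as described, a minimal sketch (assuming ``int8_dynamic_activation_int8_weight`` accepts the ``SemiSparseLayout`` imported above) would look like:

.. code-block:: py

    from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
    from torchao.dtypes import SemiSparseLayout

    model = model.cuda()
    # int8 dynamic activation / int8 weight quantization, with the weight stored in the
    # 2:4 semi-sparse layout so the matmul runs through the cuSPARSELt kernel.
    quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))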

Block sparsity (prototype)
^^^^^^^^^^^^^^^^^^^^^^^^^^

We offer prototype support for accelerating block sparsity with our Triton kernels for bfloat16/float16 workloads.

.. code-block:: py

    from torchao.sparsity.sparse_api import sparsify_
    from torchao.prototype.sparsity.superblock.blocksparse import block_sparse_weight

    model = model.cuda()
    sparsify_(model, block_sparse_weight())

Goal
====

The handoff point between these two pieces is sparse weights stored in a dense format.

This also allows users with existing sparse weights in a dense format to take advantage of our fast sparse kernels. We anticipate that many users will come up with their own custom frontend masking solution or use another third-party solution, as this is an active area of research.
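
As an illustration of that handoff, a weight that is already 2:4 sparse but stored densely (zeros in place of the pruned elements) can be converted directly to the accelerated representation. A minimal sketch using PyTorch's ``to_sparse_semi_structured`` (the module and shapes below are illustrative, and a CUDA GPU with 2:4 sparse kernel support is assumed):

.. code-block:: py

    import torch
    from torch.sparse import to_sparse_semi_structured

    # A dense-format weight whose zeros already follow the 2:4 pattern
    # (two non-zero values in every group of four), e.g. produced by a pruning frontend.
    linear = torch.nn.Linear(128, 128, bias=False).half().cuda()
    mask = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).cuda().bool()
    linear.weight = torch.nn.Parameter(
        to_sparse_semi_structured(linear.weight.masked_fill(~mask, 0))
    )

    x = torch.rand(64, 128).half().cuda()
    y = linear(x)  # dispatches to the accelerated 2:4 sparse matmul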