Commit 1521b07

Sparsity developer_notes
1 parent 65f24a6 commit 1521b07

File tree: 1 file changed

docs/source/sparsity.rst (+16 -204 lines)
@@ -3,182 +3,6 @@ Sparsity

Sparsity is the technique of removing parameters from a neural network in order to reduce its memory overhead or latency. By carefully choosing how the elements are pruned, one can achieve a significant reduction in memory overhead and latency while paying a reasonably low or no price in model quality (accuracy / F1).
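
For orientation, the pattern accelerated most heavily below is 2:4 semi-structured sparsity, where at most two of every four contiguous weights are non-zero. A minimal, purely illustrative sketch of magnitude-based 2:4 masking in plain PyTorch (the ``mask_2_to_4`` helper is hypothetical, not a torchao API):

.. code-block:: py

   import torch

   def mask_2_to_4(weight: torch.Tensor) -> torch.Tensor:
       """Keep the 2 largest-magnitude values in every block of 4 weights, zero the rest."""
       rows, cols = weight.shape                      # cols is assumed to be divisible by 4
       blocks = weight.reshape(-1, 4)
       keep = blocks.abs().topk(2, dim=-1).indices    # indices of the 2 largest |w| per block
       mask = torch.zeros_like(blocks).scatter_(-1, keep, 1.0)
       return (blocks * mask).reshape(rows, cols)

   w = torch.randn(128, 256)
   w_24 = mask_2_to_4(w)
   assert (w_24.reshape(-1, 4) != 0).sum(dim=-1).max() <= 2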

-Benchmarks
-==========
-
-segment-anything-fast
-^^^^^^^^^^^^^^^^^^^^^
-
-We were able to provide a **1.16x (22.7 -> 26.5 img/s) speedup over our dense baseline, while maintaining 97.5% (0.581 -> 0.567) of the evaluation accuracy (mIoU)**.
-
-Overall, we found that accelerating the MLP linear layers (``lin1``, ``lin2``) provided the most speedup while mitigating accuracy loss.
-
-Applying sparsity to the attention linear layers led to a slower model, likely for two reasons:
-
-* We cannot fuse into our semi-structured sparse matmul with torch.compile.
-* The speedups we observe for sparse matmul depend on the matmul shapes, and the attention matmuls are smaller than the MLP ones.
-
-We were also able to compose int8 dynamic quantization with 2:4 sparsity for further speedups.
-
-We found that applying int8 dynamic quantization to the attention layers, int8 dynamic quantization + 2:4 sparsity to MLP layer 1, and 2:4 sparsity to MLP layer 2 yielded the best configuration.
-
-We ran the following benchmarks for SAM ViT-h on an NVIDIA A100-80GB, with batch_size=32, ``bfloat16`` dtype, and ``torch.compile="max_autotune"``:
-
-.. list-table::
-   :header-rows: 1
-
-   * - Model Type
-     - Technique
-     - img/s
-     - memory (MiB)
-     - mIoU (coco2017 val)
-     - relative speedup
-     - relative accuracy
-   * - ViT-h
-     - baseline (bfloat16, max-autotune)
-     - 22.75
-     - 15172
-     - 0.5811
-     -
-     -
-   * -
-     - int8 dynamic quant (attn + mlp)
-     - 24.91
-     - 15154
-     - 0.5822
-     - **1.09x**
-     - **100.19%**
-   * -
-     - 2:4 sparsity (mlp only)
-     - 24.81
-     - 15632
-     - 0.5672
-     - **1.10x**
-     - **97.61%**
-   * -
-     - 2:4 sparsity (attn + mlp)
-     - 24.30
-     - 13429
-     - 0.5306
-     - **1.07x**
-     - **91.31%**
-   * -
-     - int8 dynamic quant (attn), int8 dynamic quant + 2:4 sparsity (mlp lin1), 2:4 sparsity (mlp lin2)
-     - 26.46
-     - 14865
-     - 0.5668
-     - **1.16x**
-     - **97.54%**
-
-To reproduce our benchmarks please follow these `instructions </torchao/_models/sam/README.md>`_.
-
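
The mixed configuration above assigns different techniques to different submodules. One way to express such a per-layer recipe is the ``filter_fn`` argument accepted by ``quantize_`` and ``sparsify_``; a rough sketch follows, in which the name-based filters are hypothetical and would need to match the real SAM module names:

.. code-block:: py

   import torch
   from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
   from torchao.sparsity.sparse_api import sparsify_, semi_sparse_weight
   from torchao.dtypes import SemiSparseLayout

   # Hypothetical name-based filters; the real SAM layer names may differ.
   def is_attn_linear(mod, fqn):
       return isinstance(mod, torch.nn.Linear) and "attn" in fqn

   def is_mlp_lin1(mod, fqn):
       return isinstance(mod, torch.nn.Linear) and fqn.endswith("mlp.lin1")

   def is_mlp_lin2(mod, fqn):
       return isinstance(mod, torch.nn.Linear) and fqn.endswith("mlp.lin2")

   # `model` is assumed to be the SAM ViT-h image encoder, as in the other snippets in this file.
   model = model.cuda().to(torch.bfloat16)
   # attention: int8 dynamic quantization only
   quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=is_attn_linear)
   # mlp lin1: int8 dynamic quantization composed with 2:4 sparsity
   quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), filter_fn=is_mlp_lin1)
   # mlp lin2: 2:4 sparsity only
   sparsify_(model, semi_sparse_weight(), filter_fn=is_mlp_lin2)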
-Llama3
-^^^^^^
-
-On Meta Llama3, we observe a 25% tok/s increase (180 -> 226) over our existing int4-wo implementation when using the sparse Marlin kernel added by @Diogo-V.
-
-.. list-table::
-   :header-rows: 1
-
-   * - Model
-     - Technique
-     - Tokens/Second
-     - Memory Bandwidth (GB/s)
-     - Peak Memory (GB)
-     - Model Size (GB)
-   * - Llama-3-8B
-     - Base (bfloat16)
-     - 95.64
-     - 1435.54
-     - 16.43
-     - 15.01
-   * -
-     - int8wo
-     - 153.03
-     - 1150.80
-     - 10.42
-     - 7.52
-   * -
-     - int4wo-64
-     - 180.80
-     - 763.33
-     - 6.88
-     - 4.22
-   * -
-     - int4wo-64-sparse-marlin
-     - 226.02
-     - 689.20
-     - 5.32
-     - 3.05
-
-These benchmarks were also run on an NVIDIA A100-80GB.
-
-Supported APIs
-==============
-
-.. image:: /docs/static/supported_sparsity_patterns.png
-   :target: /docs/static/supported_sparsity_patterns.png
-   :alt: support_matrix
-
-Sparse Marlin 2:4
-^^^^^^^^^^^^^^^^^
-
-Sparse-Marlin 2:4 is an optimized GPU kernel that extends the Mixed Auto-Regressive Linear (Marlin) dense kernel to support 4-bit quantized weights and 2:4 sparsity, improving performance in matrix multiplication and accumulation. Full documentation can be found `here <https://github.com/IST-DASLab/Sparse-Marlin>`_.
-
-.. code-block:: py
-
-   from torchao.quantization.quant_api import quantize_, int4_weight_only
-   from torchao.dtypes import MarlinSparseLayout
-
-   # Your FP16 model
-   model = model.cuda().half()
-   quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
-
-Note that the existing API results in an extremely high accuracy degradation; until we develop the necessary supporting flows in torchao, it is intended to be used with an already sparsified and fine-tuned checkpoint where possible.
-
-int8 dynamic quant + 2:4 sparsity
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-We support composing int8 dynamic quantization with 2:4 sparsity. We fuse one of the scalar dequant multiplications into our cuSPARSELt sparse mm in order to remain performant.
-
-.. code-block:: py
-
-   from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
-   from torchao.dtypes import SemiSparseLayout
-
-   model = model.cuda()
-   quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))
-
-2:4 sparsity
-^^^^^^^^^^^^
-
-.. code-block:: py
-
-   from torchao.sparsity.sparse_api import sparsify_, semi_sparse_weight
-   from torchao.dtypes import SemiSparseLayout
-
-   model = model.cuda()
-   sparsify_(model, semi_sparse_weight())
-
-Block sparsity (prototype)
-^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-We offer prototype support for accelerating block sparsity with our Triton kernels for bfloat16/float16 workloads.
-
-.. code-block:: py
-
-   from torchao.sparsity.sparse_api import sparsify_
-   from torchao.prototype.sparsity.superblock.blocksparse import block_sparse_weight
-
-   model = model.cuda()
-   sparsify_(model, block_sparse_weight())
-
Goal
====

@@ -229,8 +53,7 @@ The handoff point between these two pieces are sparse weights stored in a dense
This also allows users with existing sparse weights in a dense format to take advantage of our fast sparse kernels. We anticipate that many users will come up with their own custom frontend masking solution or use another third-party solution, as this is an active area of research.


-.. image:: /docs/static/pruning_ecosystem_diagram.png
-   :target: /docs/static/pruning_ecosystem_diagram.png
+.. image:: ../static/pruning_ecosystem_diagram.png
   :alt: pruning_flow

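
Because the handoff format is just a dense tensor whose zeros already follow a supported pattern, weights pruned by any external tool can be handed directly to the accelerated kernels. A minimal sketch using the PyTorch core ``to_sparse_semi_structured`` API rather than ``sparsify_`` (the shapes, dtype, and manually zeroed columns are illustrative assumptions):

.. code-block:: py

   import torch
   from torch.sparse import to_sparse_semi_structured

   # A dense-format weight whose zeros already satisfy the 2:4 pattern,
   # e.g. produced by an external pruning + fine-tuning flow.
   dense_w = torch.randn(128, 256, dtype=torch.float16, device="cuda")
   dense_w[:, 0::4] = 0   # zero two out of every four elements ...
   dense_w[:, 2::4] = 0   # ... giving a valid 2:4 weight stored densely

   sparse_w = to_sparse_semi_structured(dense_w)   # compressed 2:4 representation
   x = torch.randn(64, 256, dtype=torch.float16, device="cuda")
   y = torch.nn.functional.linear(x, sparse_w)     # runs on the sparse matmul kernel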

@@ -286,8 +109,7 @@ Note that this section focuses on **pruning**\ , instead of **sparse training**.
Roughly, the flow for achieving a more performant pruned model looks like this:


-.. image:: /docs/static/pruning_flow.png
-   :target: /docs/static/pruning_flow.png
+.. image:: ../static/pruning_flow.png
   :alt: flow

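
As a rough end-to-end sketch of that flow (frontend mask search followed by backend acceleration), one could pair the ``torch.ao.pruning`` ``WeightNormSparsifier`` with ``sparsify_``; the toy model, module names, and hyperparameters below are illustrative assumptions:

.. code-block:: py

   import torch
   from torch.ao.pruning import WeightNormSparsifier
   from torchao.sparsity.sparse_api import sparsify_, semi_sparse_weight

   # Toy model; a real flow would prune and then fine-tune the full network.
   model = torch.nn.Sequential(torch.nn.Linear(256, 256)).cuda().half()

   # 1. Frontend (mask search): find a 2:4 mask and bake the zeros into the dense weight.
   sparsifier = WeightNormSparsifier(
       sparsity_level=1.0, sparse_block_shape=(1, 4), zeros_per_block=2
   )
   sparsifier.prepare(model, config=[{"tensor_fqn": "0.weight"}])
   sparsifier.step()          # compute the mask (normally interleaved with fine-tuning)
   sparsifier.squash_mask()   # write the zeros back into a plain dense tensor

   # 2. Backend (acceleration): swap the dense-but-sparse weight for a fast representation.
   sparsify_(model, semi_sparse_weight())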

@@ -643,19 +465,16 @@ The specific backend hardware and its corresponding sparsity pattern, as well as
</tr>
</table>

-
-*Fig 2.3: unstructured sparsity*
+<i>Fig 2.3: unstructured sparsity</i>

</td>
</tr>
-:raw-html-m2r:`<tr>`
-:raw-html-m2r:`<td>`\ 2:4 Semi-Structured
+<tr>
+<td> 2:4 Semi-Structured

</td>
-:raw-html-m2r:`<td>`
-
+<td>

-.. raw:: html

<table>
<tr>

@@ -733,18 +552,15 @@ The specific backend hardware and its corresponding sparsity pattern, as well as
</table>


-*Fig 2.4: 2:4 semi-structured sparsity*
+<i>Fig 2.4: 2:4 semi-structured sparsity</i>

</td>
</tr>
-:raw-html-m2r:`<tr>`
-:raw-html-m2r:`<td>`\ Block Sparsity
+<tr>
+<td> Block Sparsity

</td>
-:raw-html-m2r:`<td>`
-
-
-.. raw:: html
+<td>

<table>
<tr>

@@ -822,19 +638,17 @@ The specific backend hardware and its corresponding sparsity pattern, as well as
</table>


-*Fig 2.5: 4x4 block-wise structured sparsity*
+<i>Fig 2.5: 4x4 block-wise structured sparsity</i>

</td>
</tr>
-:raw-html-m2r:`<tr>`
-:raw-html-m2r:`<td>`\ Structured Sparsity
+<tr>
+<td> Structured Sparsity

</td>
-:raw-html-m2r:`<td>`
+<td>


-.. raw:: html
-
<table>
<tr>
<td>1

@@ -909,12 +723,10 @@ The specific backend hardware and its corresponding sparsity pattern, as well as
</td>
</tr>
</table>
-
-
-*Fig 2.6: row-wise structured sparsity*
+<i>Fig 2.6: row-wise structured sparsity</i>

</td>
</tr>
-</table>
+</table>

*Table 4.4: Description of some common sparsity patterns.*