
Commit be650e9

Sparsity api_ref

1 parent 32d9b0b

File tree: 2 files changed, +19 -19 lines changed

    docs/source/api_ref_sparsity.rst
    torchao/sparsity/sparse_api.py


docs/source/api_ref_sparsity.rst (+3 -3)

@@ -12,7 +12,7 @@ torchao.sparsity
 
     WandaSparsifier
     PerChannelNormObserver
-    apply_sparse_semi_structured
     apply_fake_sparsity
-
-
+    sparsify_
+    semi_sparse_weight
+    int8_dynamic_activation_int8_semi_sparse_weight
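
The three entries added above sit under the page's `currentmodule`, torchao.sparsity, so they should be importable from that module directly. A minimal, hypothetical smoke test, assuming the names are re-exported from torchao.sparsity as the autosummary implies:

    # Hypothetical smoke test: the autosummary above implies these names are
    # exposed under torchao.sparsity (an assumption, not verified here).
    from torchao.sparsity import (
        int8_dynamic_activation_int8_semi_sparse_weight,
        semi_sparse_weight,
        sparsify_,
    )

    # Sanity-check that the docstring edited below is attached.
    print(sparsify_.__doc__.splitlines()[0])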

torchao/sparsity/sparse_api.py (+16 -16)

@@ -43,7 +43,7 @@ def sparsify_(
     apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor],
     filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
 ) -> torch.nn.Module:
-    """Convert the weight of linear modules in the model with `apply_tensor_subclass`
+    """Convert the weight of linear modules in the model with `apply_tensor_subclass`.
     This function is essentially the same as quantize, put for sparsity subclasses.
 
     Currently, we support three options for sparsity:
@@ -54,26 +54,26 @@ def sparsify_(
     Args:
         model (torch.nn.Module): input model
         apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that convert a floating point Tensor to a (sparsified) tensor subclass instance (e.g. affine quantized tensor instance)
-        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on
-            the weight of the module
+        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on the weight of the module
 
-    Example::
-        import torch
-        import torch.nn as nn
-        from torchao.sparsity import sparsify_
+    **Example:**
+    ::
+            import torch
+            import torch.nn as nn
+            from torchao.sparsity import sparsify_
 
-        def filter_fn(module: nn.Module, fqn: str) -> bool:
-            return isinstance(module, nn.Linear)
+            def filter_fn(module: nn.Module, fqn: str) -> bool:
+                return isinstance(module, nn.Linear)
 
-        m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
+            m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
 
-        # for 2:4 sparsity
-        from torchao.sparse_api import semi_sparse_weight
-        m = sparsify_(m, semi_sparse_weight(), filter_fn)
+            # for 2:4 sparsity
+            from torchao.sparse_api import semi_sparse_weight
+            m = sparsify_(m, semi_sparse_weight(), filter_fn)
 
-        # for int8 dynamic quantization + 2:4 sparsity
-        from torchao.dtypes import SemiSparseLayout
-        m = quantize_(m, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout), filter_fn)
+            # for int8 dynamic quantization + 2:4 sparsity
+            from torchao.dtypes import SemiSparseLayout
+            m = quantize_(m, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout), filter_fn)
     """
     _replace_with_custom_fn_if_matches_filter(
         model,
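
The docstring example above is fragmentary (note it imports semi_sparse_weight from torchao.sparse_api, while the api_ref page added in this same commit lists it under torchao.sparsity). A self-contained sketch of the 2:4 path, under the assumptions that the torchao.sparsity imports are the correct ones and that a CUDA device is available, since the semi-structured sparse kernels expect half-precision weights on GPU:

    import torch
    import torch.nn as nn

    # Assumption: these names are importable from torchao.sparsity, per the
    # api_ref page in this commit (the docstring instead says torchao.sparse_api).
    from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_


    def filter_fn(module: nn.Module, fqn: str) -> bool:
        # fqn is the module's fully qualified name; restrict to nn.Linear weights.
        return isinstance(module, nn.Linear)


    # Semi-structured (2:4) sparse kernels expect fp16/bf16 weights on CUDA.
    m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)).half().cuda()

    # Prune weights to a 2:4 pattern first; apply_fake_sparsity (listed in the
    # api_ref above) magnitude-prunes linear weights to that pattern.
    apply_fake_sparsity(m)

    # Swap each matching Linear weight for a semi-structured sparse subclass.
    m = sparsify_(m, semi_sparse_weight(), filter_fn)

    x = torch.randn(64, 32, dtype=torch.float16, device="cuda")
    y = m(x)  # forward now dispatches to the 2:4 sparse kernels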
