
Commit e1c5fe1

vkuzo authored and facebook-github-bot committed
switch argument order to module_filter_fn (#328)
Summary:
Pull Request resolved: #328

before: `module_filter_fn(fqn, mod)`
after: `module_filter_fn(mod, fqn)`

This is to better match the format of similar functions in `torchao`.

Reviewed By: weifengpy
Differential Revision: D60195664
fbshipit-source-id: 271a7f8e52f30d69237a87d4043bf09682014a23
1 parent 8650582 commit e1c5fe1
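
To make the call-site impact concrete, here is a minimal sketch of the signature change; the filter body (skipping the module whose FQN is "output") mirrors the README example in this diff, and the old/new function names are labels for illustration only.

import torch.nn as nn

# before this commit: FQN first, then module instance
def module_filter_fn_old(fqn: str, mod: nn.Module) -> bool:
    return fqn != "output"

# after this commit: module instance first, then FQN, matching torchao's convention
def module_filter_fn(mod: nn.Module, fqn: str) -> bool:
    return fqn != "output"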

File tree: 4 files changed (+13, -15 lines)

4 files changed

+13
-15
lines changed

README.md (3 additions, 3 deletions)

@@ -44,7 +44,7 @@ from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_f
 m = Model(...)
 
 # optional: filter modules from being eligible for float8 conversion
-def module_filter_fn(fqn: str, mod: torch.nn.Module):
+def module_filter_fn(mod: torch.nn.Module, fqn: str):
     # don't convert the output module
     if fqn == "output":
         return False
@@ -91,9 +91,9 @@ from float8_experimental.float8_linear import TensorScalingType
 # create model
 m = Model(...)
 
-# optional: configure for compatibility with FSDP. Note that workarounds
+# optional: configure for compatibility with FSDP. Note that workarounds
 # gated with config.enable_amax_init and
-# config.enable_pre_and_post_forward are needed for
+# config.enable_pre_and_post_forward are needed for
 # autocast + compile + FSDP + float8 to work
 from float8_experimental import Float8LinearConfig, TensorScalingType, Float8TensorCastConfig
 config = Float8LinearConfig(
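
For context, a minimal usage sketch wiring the reordered filter into swap_linear_with_float8_linear (whose updated signature appears in the next file), assuming the default Float8LinearConfig is acceptable; the Model class and layer sizes here are hypothetical stand-ins, not code from the repository.

import torch.nn as nn
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

class Model(nn.Module):
    # hypothetical two-layer model; "output" is the FQN the filter excludes
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(512, 512)
        self.output = nn.Linear(512, 512)

    def forward(self, x):
        return self.output(self.proj(x))

def module_filter_fn(mod: nn.Module, fqn: str) -> bool:
    # new argument order: module instance first, then fully qualified name
    return fqn != "output"

m = Model()
m = swap_linear_with_float8_linear(m, module_filter_fn=module_filter_fn)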

float8_experimental/float8_linear_utils.py (6 additions, 8 deletions)

@@ -60,7 +60,7 @@ def swap_linear_layers(
     module: nn.Module,
     from_float_func: Callable[[nn.Linear], nn.Linear],
     *,
-    module_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
+    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
 ) -> Optional[nn.Module]:
     """
     Generic function to swap linear layers in a module with a new type of linear layer.
@@ -74,13 +74,13 @@ def swap_linear_layers(
         from_float_func: Function that accepts a linear layer and returns a new type of linear layer.
         module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that
             that pass the filter function will be swapped. The inputs to the
-            filter function are the FQN and module instance.
+            filter function are the module instance, and the FQN.
 
     Returns:
         nn.Module: The modified module with swapped linear layers.
     """
     if isinstance(module, nn.Linear) and (
-        module_filter_fn is None or module_filter_fn("", module)
+        module_filter_fn is None or module_filter_fn(module, "")
     ):
         if len(list(module.children())) > 0:
             raise AssertionError(
@@ -109,9 +109,7 @@ def post_order_traversal(
             post_order_traversal(child_module, new_fqn, module)
 
         if isinstance(module, nn.Linear) and (
-            # linear_layer_filter is None or linear_layer_filter(module)
-            module_filter_fn is None
-            or module_filter_fn(cur_fqn, module)
+            module_filter_fn is None or module_filter_fn(module, cur_fqn)
         ):
             assert (
                 parent_module is not None
@@ -127,7 +125,7 @@ def post_order_traversal(
 def swap_linear_with_float8_linear(
     module: nn.Module,
     *,
-    module_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
+    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
     config: Float8LinearConfig = None,
 ) -> Optional[nn.Module]:
     """
@@ -137,7 +135,7 @@ def swap_linear_with_float8_linear(
         module: Module to modify.
         module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that
             that pass the filter function will be swapped. The inputs to the
-            filter function are the FQN and module instance.
+            filter function are the module instance and the FQN.
         config (Float8LinearConfig): configuration for conversion to float8
 
     Returns:
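
One behavior worth noting in the hunk above: when the module handed to swap_linear_layers is itself an nn.Linear, the filter is now called as module_filter_fn(module, ""), i.e. with an empty FQN. A hedged sketch of a filter that accounts for that case, reusing swap_linear_with_float8_linear from the same file; the size threshold is purely illustrative.

import torch.nn as nn
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

def module_filter_fn(mod: nn.Module, fqn: str) -> bool:
    # fqn == "" means the root module itself is the nn.Linear under consideration
    if fqn == "":
        return True
    # for nested layers, convert only reasonably large ones (illustrative threshold)
    return mod.in_features >= 32 and mod.out_features >= 32

root_linear = nn.Linear(64, 64)
root_linear = swap_linear_with_float8_linear(
    root_linear, module_filter_fn=module_filter_fn
)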

float8_experimental/inference.py (2 additions, 2 deletions)

@@ -213,7 +213,7 @@ def quantize_to_float8(
     module: nn.Module,
     quant_config: QuantConfig,
     *,
-    module_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
+    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
     use_fast_accum: bool = True,
 ) -> Optional[nn.Module]:
     """
@@ -228,7 +228,7 @@ def quantize_to_float8(
         quant_config (QuantConfig): Quantization configuration for Float8 conversion.
         module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that
             that pass the filter function will be swapped. The inputs to the
-            filter function are the FQN and module instance.
+            filter function are the module instance and the FQN.
         use_fast_accum : Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.
 
     Returns:
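
Analogously for the inference path, a minimal sketch of quantize_to_float8 with the reordered filter. quantize_to_float8 and QuantConfig appear in the hunk above, but ActivationCasting.DYNAMIC and the toy model are assumptions made for the example, not details confirmed by this diff.

import torch.nn as nn
from float8_experimental.inference import (
    ActivationCasting,  # assumed enum used to build QuantConfig; not shown in this diff
    QuantConfig,
    quantize_to_float8,
)

def module_filter_fn(mod: nn.Module, fqn: str) -> bool:
    # module instance first, then FQN, per this commit; skip the last layer ("1")
    return fqn != "1"

model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))
model = quantize_to_float8(
    model,
    QuantConfig(ActivationCasting.DYNAMIC),
    module_filter_fn=module_filter_fn,
)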

test/test_base.py (2 additions, 2 deletions)

@@ -649,7 +649,7 @@ def __init__(self, dim: int):
 
         size_limit = 32
 
-        def module_filter_fn(fqn, mod):
+        def module_filter_fn(mod, fqn):
             return (
                 mod.in_features >= size_limit
                 and mod.out_features >= size_limit
@@ -682,7 +682,7 @@ def __init__(self, dim: int):
                 self.lin2 = nn.Linear(4 * dim, dim)
 
         model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
-        module_filter_fn = lambda fqn, mod: fqn not in [
+        module_filter_fn = lambda mod, fqn: fqn not in [
            "0.lin2",
            "2.lin1",
        ]
