Skip to content

Commit c6611be

Browse files
authored
Remove duplicate definitions of fill_defaults (#1674)
1 parent bc1530b commit c6611be

File tree

2 files changed

+4
-34
lines changed

2 files changed

+4
-34
lines changed

torchao/dtypes/uintx/uint4_layout.py

Lines changed: 2 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import torch.utils._pytree as pytree
44
from torch.library import Library, impl
55

6+
from torchao.utils import fill_defaults
7+
68

79
def down_size(size):
810
assert size[-1] % 2 == 0, f"{size} last dim not divisible by two"
@@ -13,31 +15,6 @@ def up_size(size):
1315
return (*size[:-1], size[-1] * 2)
1416

1517

16-
def fill_defaults(args, n, defaults_tail):
17-
"""
18-
__torch_dispatch__ doesn't guarantee the number of arguments you are
19-
passed (e.g., defaulted arguments are not passed); but usually it is
20-
convenient to pad out the arguments list with defaults. This function
21-
helps you do that.
22-
Args:
23-
args: the list of positional arguments passed to __torch_dispatch__
24-
n: the number of arguments you are expecting to get
25-
defaults_tail: default values for the arguments, starting from the
26-
end of the list
27-
Example:
28-
>>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
29-
[1, 2, 3, 4, 5]
30-
>>> fill_defaults([1, 2, 3], 5, [None, None, None])
31-
[1, 2, 3, None, None]]
32-
"""
33-
if n - len(defaults_tail) > len(args):
34-
raise RuntimeError("not enough defaults to fill arguments")
35-
r = list(args)
36-
for i in range(len(args), n):
37-
r.append(defaults_tail[i - n + len(defaults_tail)])
38-
return r
39-
40-
4118
# from
4219
# https://github.com/drisspg/transformer_nuggets/blob/9ad3a7fc552a954eb702ade0e276b8d8e09c3db6/transformer_nuggets/quant/qlora.py#L233
4320

torchao/prototype/dtypes/uint2.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,9 @@
44
import torch
55
import torch._prims_common as utils
66

7-
UINT2_OPS_TABLE: Dict[Any, Any] = {}
8-
7+
from torchao.utils import fill_defaults
98

10-
def fill_defaults(args, n, defaults_tail):
11-
if n - len(defaults_tail) > len(args):
12-
raise RuntimeError("not enough defaults to fill arguments")
13-
r = list(args)
14-
for i in range(len(args), n):
15-
r.append(defaults_tail[i - n + len(defaults_tail)])
16-
return r
9+
UINT2_OPS_TABLE: Dict[Any, Any] = {}
1710

1811

1912
def implements(aten_ops):

0 commit comments

Comments
 (0)