Skip to content

Commit e536afc

Browse files
committed
Remove duplicate defenitions of fill_defaults
1 parent 8d14f0e commit e536afc

File tree

2 files changed

+2
-35
lines changed

2 files changed

+2
-35
lines changed

torchao/dtypes/uintx/uint4_layout.py

+1-26
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch._prims_common as utils
33
import torch.utils._pytree as pytree
44
from torch.library import Library, impl
5-
5+
from torchao.utils import fill_defaults
66

77
def down_size(size):
88
assert size[-1] % 2 == 0, f"{size} last dim not divisible by two"
@@ -13,31 +13,6 @@ def up_size(size):
1313
return (*size[:-1], size[-1] * 2)
1414

1515

16-
def fill_defaults(args, n, defaults_tail):
17-
"""
18-
__torch_dispatch__ doesn't guarantee the number of arguments you are
19-
passed (e.g., defaulted arguments are not passed); but usually it is
20-
convenient to pad out the arguments list with defaults. This function
21-
helps you do that.
22-
Args:
23-
args: the list of positional arguments passed to __torch_dispatch__
24-
n: the number of arguments you are expecting to get
25-
defaults_tail: default values for the arguments, starting from the
26-
end of the list
27-
Example:
28-
>>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
29-
[1, 2, 3, 4, 5]
30-
>>> fill_defaults([1, 2, 3], 5, [None, None, None])
31-
[1, 2, 3, None, None]]
32-
"""
33-
if n - len(defaults_tail) > len(args):
34-
raise RuntimeError("not enough defaults to fill arguments")
35-
r = list(args)
36-
for i in range(len(args), n):
37-
r.append(defaults_tail[i - n + len(defaults_tail)])
38-
return r
39-
40-
4116
# from
4217
# https://github.com/drisspg/transformer_nuggets/blob/9ad3a7fc552a954eb702ade0e276b8d8e09c3db6/transformer_nuggets/quant/qlora.py#L233
4318

torchao/prototype/dtypes/uint2.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,11 @@
33

44
import torch
55
import torch._prims_common as utils
6+
from torchao.utils import fill_defaults
67

78
UINT2_OPS_TABLE: Dict[Any, Any] = {}
89

910

10-
def fill_defaults(args, n, defaults_tail):
11-
if n - len(defaults_tail) > len(args):
12-
raise RuntimeError("not enough defaults to fill arguments")
13-
r = list(args)
14-
for i in range(len(args), n):
15-
r.append(defaults_tail[i - n + len(defaults_tail)])
16-
return r
17-
18-
1911
def implements(aten_ops):
2012
def decorator(fn):
2113
for op in aten_ops:

0 commit comments

Comments
 (0)