Skip to content

Commit

Permalink
Remove duplicate definitions of fill_defaults (#1674)
Browse files Browse the repository at this point in the history
  • Loading branch information
jainapurva authored Feb 6, 2025
1 parent bc1530b commit c6611be
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 34 deletions.
27 changes: 2 additions & 25 deletions torchao/dtypes/uintx/uint4_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import torch.utils._pytree as pytree
from torch.library import Library, impl

from torchao.utils import fill_defaults


def down_size(size):
assert size[-1] % 2 == 0, f"{size} last dim not divisible by two"
Expand All @@ -13,31 +15,6 @@ def up_size(size):
return (*size[:-1], size[-1] * 2)


def fill_defaults(args, n, defaults_tail):
"""
__torch_dispatch__ doesn't guarantee the number of arguments you are
passed (e.g., defaulted arguments are not passed); but usually it is
convenient to pad out the arguments list with defaults. This function
helps you do that.
Args:
args: the list of positional arguments passed to __torch_dispatch__
n: the number of arguments you are expecting to get
defaults_tail: default values for the arguments, starting from the
end of the list
Example:
>>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
[1, 2, 3, 4, 5]
>>> fill_defaults([1, 2, 3], 5, [None, None, None])
[1, 2, 3, None, None]]
"""
if n - len(defaults_tail) > len(args):
raise RuntimeError("not enough defaults to fill arguments")
r = list(args)
for i in range(len(args), n):
r.append(defaults_tail[i - n + len(defaults_tail)])
return r


# from
# https://github.com/drisspg/transformer_nuggets/blob/9ad3a7fc552a954eb702ade0e276b8d8e09c3db6/transformer_nuggets/quant/qlora.py#L233

Expand Down
11 changes: 2 additions & 9 deletions torchao/prototype/dtypes/uint2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,9 @@
import torch
import torch._prims_common as utils

UINT2_OPS_TABLE: Dict[Any, Any] = {}

from torchao.utils import fill_defaults

def fill_defaults(args, n, defaults_tail):
if n - len(defaults_tail) > len(args):
raise RuntimeError("not enough defaults to fill arguments")
r = list(args)
for i in range(len(args), n):
r.append(defaults_tail[i - n + len(defaults_tail)])
return r
UINT2_OPS_TABLE: Dict[Any, Any] = {}


def implements(aten_ops):
Expand Down

0 comments on commit c6611be

Please sign in to comment.