From c6611be254be9563d045f515d94c20c8c54be8ec Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Wed, 5 Feb 2025 16:01:48 -0800 Subject: [PATCH] Remove duplicate definitions of fill_defaults (#1674) --- torchao/dtypes/uintx/uint4_layout.py | 27 ++------------------------- torchao/prototype/dtypes/uint2.py | 11 ++--------- 2 files changed, 4 insertions(+), 34 deletions(-) diff --git a/torchao/dtypes/uintx/uint4_layout.py b/torchao/dtypes/uintx/uint4_layout.py index 204aefcf3c..0b6512640e 100644 --- a/torchao/dtypes/uintx/uint4_layout.py +++ b/torchao/dtypes/uintx/uint4_layout.py @@ -3,6 +3,8 @@ import torch.utils._pytree as pytree from torch.library import Library, impl +from torchao.utils import fill_defaults + def down_size(size): assert size[-1] % 2 == 0, f"{size} last dim not divisible by two" @@ -13,31 +15,6 @@ def up_size(size): return (*size[:-1], size[-1] * 2) -def fill_defaults(args, n, defaults_tail): - """ - __torch_dispatch__ doesn't guarantee the number of arguments you are - passed (e.g., defaulted arguments are not passed); but usually it is - convenient to pad out the arguments list with defaults. This function - helps you do that. - Args: - args: the list of positional arguments passed to __torch_dispatch__ - n: the number of arguments you are expecting to get - defaults_tail: default values for the arguments, starting from the - end of the list - Example: - >>> fill_defaults([1, 2, 3], 5, [3, 4, 5]) - [1, 2, 3, 4, 5] - >>> fill_defaults([1, 2, 3], 5, [None, None, None]) - [1, 2, 3, None, None]] - """ - if n - len(defaults_tail) > len(args): - raise RuntimeError("not enough defaults to fill arguments") - r = list(args) - for i in range(len(args), n): - r.append(defaults_tail[i - n + len(defaults_tail)]) - return r - - # from # https://github.com/drisspg/transformer_nuggets/blob/9ad3a7fc552a954eb702ade0e276b8d8e09c3db6/transformer_nuggets/quant/qlora.py#L233 diff --git a/torchao/prototype/dtypes/uint2.py b/torchao/prototype/dtypes/uint2.py index 9c14d8ae72..d54e541751 100644 --- a/torchao/prototype/dtypes/uint2.py +++ b/torchao/prototype/dtypes/uint2.py @@ -4,16 +4,9 @@ import torch import torch._prims_common as utils -UINT2_OPS_TABLE: Dict[Any, Any] = {} - +from torchao.utils import fill_defaults -def fill_defaults(args, n, defaults_tail): - if n - len(defaults_tail) > len(args): - raise RuntimeError("not enough defaults to fill arguments") - r = list(args) - for i in range(len(args), n): - r.append(defaults_tail[i - n + len(defaults_tail)]) - return r +UINT2_OPS_TABLE: Dict[Any, Any] = {} def implements(aten_ops):