2
2
import torch ._prims_common as utils
3
3
import torch .utils ._pytree as pytree
4
4
from torch .library import Library , impl
5
-
5
+ from torchao . utils import fill_defaults
6
6
7
7
def down_size (size ):
8
8
assert size [- 1 ] % 2 == 0 , f"{ size } last dim not divisible by two"
@@ -13,31 +13,6 @@ def up_size(size):
13
13
return (* size [:- 1 ], size [- 1 ] * 2 )
14
14
15
15
16
def fill_defaults(args, n, defaults_tail):
    """Pad a ``__torch_dispatch__`` positional-argument list out to ``n`` entries.

    ``__torch_dispatch__`` doesn't guarantee the number of arguments you are
    passed (e.g., defaulted arguments are not passed); but usually it is
    convenient to pad out the arguments list with defaults. This function
    helps you do that.

    Args:
        args: the list of positional arguments passed to __torch_dispatch__
        n: the number of arguments you are expecting to get
        defaults_tail: default values for the arguments, starting from the
            end of the list

    Returns:
        A new list of length ``n``: ``args`` followed by the missing
        trailing defaults. ``args`` itself is not mutated.

    Raises:
        RuntimeError: if ``defaults_tail`` is too short to cover the
            arguments missing from ``args``.

    Example:
        >>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
        [1, 2, 3, 4, 5]
        >>> fill_defaults([1, 2, 3], 5, [None, None, None])
        [1, 2, 3, None, None]
    """
    if n - len(defaults_tail) > len(args):
        raise RuntimeError("not enough defaults to fill arguments")
    r = list(args)
    # defaults_tail aligns with the last len(defaults_tail) slots of the
    # n-slot argument list, so absolute position i maps to tail index
    # i - n + len(defaults_tail).
    for i in range(len(args), n):
        r.append(defaults_tail[i - n + len(defaults_tail)])
    return r
39
-
40
-
41
16
# from
42
17
# https://github.com/drisspg/transformer_nuggets/blob/9ad3a7fc552a954eb702ade0e276b8d8e09c3db6/transformer_nuggets/quant/qlora.py#L233
43
18
0 commit comments