3
3
import torch .utils ._pytree as pytree
4
4
from torch .library import Library , impl
5
5
6
+ from torchao .utils import fill_defaults
7
+
6
8
7
9
def down_size (size ):
8
10
assert size [- 1 ] % 2 == 0 , f"{ size } last dim not divisible by two"
@@ -13,31 +15,6 @@ def up_size(size):
13
15
return (* size [:- 1 ], size [- 1 ] * 2 )
14
16
15
17
16
- def fill_defaults (args , n , defaults_tail ):
17
- """
18
- __torch_dispatch__ doesn't guarantee the number of arguments you are
19
- passed (e.g., defaulted arguments are not passed); but usually it is
20
- convenient to pad out the arguments list with defaults. This function
21
- helps you do that.
22
- Args:
23
- args: the list of positional arguments passed to __torch_dispatch__
24
- n: the number of arguments you are expecting to get
25
- defaults_tail: default values for the arguments, starting from the
26
- end of the list
27
- Example:
28
- >>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
29
- [1, 2, 3, 4, 5]
30
- >>> fill_defaults([1, 2, 3], 5, [None, None, None])
31
- [1, 2, 3, None, None]]
32
- """
33
- if n - len (defaults_tail ) > len (args ):
34
- raise RuntimeError ("not enough defaults to fill arguments" )
35
- r = list (args )
36
- for i in range (len (args ), n ):
37
- r .append (defaults_tail [i - n + len (defaults_tail )])
38
- return r
39
-
40
-
41
18
# from
42
19
# https://github.com/drisspg/transformer_nuggets/blob/9ad3a7fc552a954eb702ade0e276b8d8e09c3db6/transformer_nuggets/quant/qlora.py#L233
43
20
0 commit comments