6
6
from safetensors import safe_open
7
7
8
8
9
- def _tsize (st , key ):
10
-
11
- tslice = st .get_slice (key )
12
- shape = tslice .get_shape ()
13
- numel = 1
14
- for x in shape : numel *= x
15
- dtype = tslice .get_dtype ()
16
- if dtype == "I32" : return numel * 4
17
- elif dtype == "I16" : return numel * 2
18
- elif dtype == "F16" : return numel * 2
19
- elif dtype == "F32" : return numel * 4
20
- else : raise ValueError ("Unexpected datatype: " + key )
21
-
22
-
23
9
class ExLlamaV2Linear (ExLlamaV2Module ):
24
10
25
11
in_features : int
@@ -29,7 +15,6 @@ class ExLlamaV2Linear(ExLlamaV2Module):
29
15
linear : nn .Linear or None = None
30
16
q_handle : int or None = None
31
17
q_tensors : dict or None = None
32
- footprint : int
33
18
34
19
name : str = "Linear"
35
20
@@ -74,48 +59,6 @@ def get_weight(self):
74
59
return self .linear .weight .data
75
60
76
61
77
- def weight_footprint (self ):
78
-
79
- if self .footprint == - 1 :
80
-
81
- # Torch linear layer
82
-
83
- if self .key + ".weight" in self .model .config .tensor_file_map :
84
- filename = self .model .config .tensor_file_map [self .key + ".weight" ]
85
- with safe_open (filename , framework = "pt" , device = "cpu" ) as st :
86
- self .footprint = 0
87
- self .footprint += _tsize (st , self .key + ".weight" )
88
-
89
- # EXL2
90
-
91
- elif self .key + ".q_weight" in self .model .config .tensor_file_map :
92
- filename = self .model .config .tensor_file_map [self .key + ".q_weight" ]
93
- with safe_open (filename , framework = "pt" , device = "cpu" ) as st :
94
- self .footprint = 0
95
- self .footprint += _tsize (st , self .key + ".q_weight" ) + 128
96
- self .footprint += _tsize (st , self .key + ".q_invperm" ) + 128
97
- self .footprint += _tsize (st , self .key + ".q_scale" ) + 128
98
- self .footprint += _tsize (st , self .key + ".q_scale_max" ) + 128
99
- self .footprint += _tsize (st , self .key + ".q_groups" ) + 128
100
- self .footprint += _tsize (st , self .key + ".q_invperm" ) + 128
101
-
102
- # GPTQ
103
-
104
- elif self .key + ".qweight" in self .model .config .tensor_file_map :
105
- filename = self .model .config .tensor_file_map [self .key + ".qweight" ]
106
- with safe_open (filename , framework = "pt" , device = "cpu" ) as st :
107
- self .footprint += _tsize (st , self .key + ".qweight" ) + 128
108
- self .footprint += _tsize (st , self .key + ".qzeros" ) + 128
109
- self .footprint += _tsize (st , self .key + ".scales" ) + 128
110
- if self .key + ".g_idx" in self .model .config .tensor_file_map :
111
- self .footprint += _tsize (st , self .key + ".g_idx" ) + 128
112
-
113
- else :
114
- raise ValueError ("Can't find tensors in model files." )
115
-
116
- return self .footprint
117
-
118
-
119
62
def scratch_space_fixed (self ):
120
63
121
64
return self .temp_dq_size () + \
0 commit comments