Skip to content

Commit 7704a68

Browse files
committed
Fix VRAM usage estimate for linear layer spanning multiple shards
1 parent c5c90a8 commit 7704a68

File tree

2 files changed

+51
-60
lines changed

2 files changed

+51
-60
lines changed

exllamav2/linear.py

-57
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,6 @@
66
from safetensors import safe_open
77

88

9-
def _tsize(st, key):
10-
11-
tslice = st.get_slice(key)
12-
shape = tslice.get_shape()
13-
numel = 1
14-
for x in shape: numel *= x
15-
dtype = tslice.get_dtype()
16-
if dtype == "I32": return numel * 4
17-
elif dtype == "I16": return numel * 2
18-
elif dtype == "F16": return numel * 2
19-
elif dtype == "F32": return numel * 4
20-
else: raise ValueError("Unexpected datatype: " + key)
21-
22-
239
class ExLlamaV2Linear(ExLlamaV2Module):
2410

2511
in_features: int
@@ -29,7 +15,6 @@ class ExLlamaV2Linear(ExLlamaV2Module):
2915
linear: nn.Linear or None = None
3016
q_handle: int or None = None
3117
q_tensors: dict or None = None
32-
footprint: int
3318

3419
name: str = "Linear"
3520

@@ -74,48 +59,6 @@ def get_weight(self):
7459
return self.linear.weight.data
7560

7661

77-
def weight_footprint(self):
78-
79-
if self.footprint == -1:
80-
81-
# Torch linear layer
82-
83-
if self.key + ".weight" in self.model.config.tensor_file_map:
84-
filename = self.model.config.tensor_file_map[self.key + ".weight"]
85-
with safe_open(filename, framework="pt", device="cpu") as st:
86-
self.footprint = 0
87-
self.footprint += _tsize(st, self.key + ".weight")
88-
89-
# EXL2
90-
91-
elif self.key + ".q_weight" in self.model.config.tensor_file_map:
92-
filename = self.model.config.tensor_file_map[self.key + ".q_weight"]
93-
with safe_open(filename, framework="pt", device="cpu") as st:
94-
self.footprint = 0
95-
self.footprint += _tsize(st, self.key + ".q_weight") + 128
96-
self.footprint += _tsize(st, self.key + ".q_invperm") + 128
97-
self.footprint += _tsize(st, self.key + ".q_scale") + 128
98-
self.footprint += _tsize(st, self.key + ".q_scale_max") + 128
99-
self.footprint += _tsize(st, self.key + ".q_groups") + 128
100-
self.footprint += _tsize(st, self.key + ".q_invperm") + 128
101-
102-
# GPTQ
103-
104-
elif self.key + ".qweight" in self.model.config.tensor_file_map:
105-
filename = self.model.config.tensor_file_map[self.key + ".qweight"]
106-
with safe_open(filename, framework="pt", device="cpu") as st:
107-
self.footprint += _tsize(st, self.key + ".qweight") + 128
108-
self.footprint += _tsize(st, self.key + ".qzeros") + 128
109-
self.footprint += _tsize(st, self.key + ".scales") + 128
110-
if self.key + ".g_idx" in self.model.config.tensor_file_map:
111-
self.footprint += _tsize(st, self.key + ".g_idx") + 128
112-
113-
else:
114-
raise ValueError("Can't find tensors in model files.")
115-
116-
return self.footprint
117-
118-
11962
def scratch_space_fixed(self):
12063

12164
return self.temp_dq_size() + \

exllamav2/module.py

+51-3
Original file line numberDiff line numberDiff line change
@@ -3,33 +3,52 @@
33
from exllamav2.config import ExLlamaV2Config
44
from safetensors import safe_open
55

6+
67
def _torch_device(idx):
78
if idx == -1: return "cpu"
89
return f"cuda:{idx}"
910

11+
12+
def _tsize(st, key):
13+
14+
tslice = st.get_slice(key)
15+
shape = tslice.get_shape()
16+
numel = 1
17+
for x in shape: numel *= x
18+
dtype = tslice.get_dtype()
19+
if dtype == "I32": return numel * 4
20+
elif dtype == "I16": return numel * 2
21+
elif dtype == "F16": return numel * 2
22+
elif dtype == "F32": return numel * 4
23+
else: raise ValueError("Unexpected datatype: " + key)
24+
25+
1026
class ExLlamaV2Module:
1127

1228
model = None
1329
config: ExLlamaV2Config
1430
key: str
1531
device_idx: int
32+
footprint: int
1633

1734
def __init__(self, model, key):
    """Base-module constructor.

    model: owning model instance — presumably exposes .config.tensor_file_map,
           which subclasses and weight_footprint() read (TODO confirm against
           full class definition, not fully visible here)
    key: dot-separated tensor-name prefix identifying this module's weights
    """

    self.model = model
    self.key = key
    # Cached byte size of this module's weights; -1 = not yet measured
    # (filled in lazily by weight_footprint()).
    self.footprint = -1
2139

2240

2341
def device(self):
    # Torch device string for this module: "cpu" when device_idx is -1,
    # otherwise "cuda:<device_idx>" (delegates to _torch_device).
    return _torch_device(self.device_idx)
2644

2745

28-
def load_multi(self, keys):
46+
def load_multi(self, keys, measure = False):
2947

3048
tensors = {}
3149
submap = {}
3250
submap_i = {}
51+
size = 0
3352

3453
for k in keys:
3554
ck = self.key + "." + k
@@ -44,9 +63,12 @@ def load_multi(self, keys):
4463
for v, ks in submap_i.items():
4564
with safe_open(v, framework="pt", device="cpu") as st:
4665
for k in ks:
47-
tensors[k] = st.get_tensor(self.key + "." + k).to(self.device())
66+
if measure:
67+
size += _tsize(st, self.key + "." + k)
68+
else:
69+
tensors[k] = st.get_tensor(self.key + "." + k).to(self.device())
4870

49-
return tensors
71+
return size if measure else tensors
5072

5173

5274
def load_weight(self):
@@ -72,6 +94,32 @@ def load_weight(self):
7294
return nn.Parameter(tensor)
7395

7496

97+
def weight_footprint(self):
    """Return the storage size in bytes of this module's weights, summed
    across all shard files via load_multi(..., measure = True).

    The result is cached in self.footprint after the first call (-1 marks
    "not yet measured"). The quantization format is detected from which
    tensor keys exist in the model's tensor_file_map.

    Raises ValueError when no recognized weight tensor exists for self.key.
    """

    if self.footprint == -1:

        # EXL2
        # Fix: "q_perm" was listed twice, double-counting that tensor in
        # the VRAM estimate.
        if self.key + ".q_weight" in self.model.config.tensor_file_map:
            self.footprint = self.load_multi(["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups", "q_perm"], measure = True)

        # GPTQ

        elif self.key + ".qweight" in self.model.config.tensor_file_map:
            self.footprint = self.load_multi(["qweight", "qzeros", "scales", "g_idx"], measure = True)

        # Torch

        elif self.key + ".weight" in self.model.config.tensor_file_map:
            self.footprint = self.load_multi(["weight"], measure = True)

        # Error

        else: raise ValueError("Unknown tensor type: " + self.key)

    return self.footprint
121+
122+
75123
def set_device_idx(self, idx):
76124

77125
self.device_idx = idx

0 commit comments

Comments (0)