# helper3.py
import torch
import torch.nn as nn
import torch.nn.functional as F
############# From the previous lesson(s) of "Building your own Quantizer"
def w8_a16_forward(weight, input, scales, bias=None):
    # W8A16: int8 weights, higher-precision activations. Cast the weights up
    # to the activation dtype, run the matmul there, then rescale each
    # output channel by its per-channel scale.
    casted_weights = weight.to(input.dtype)
    output = F.linear(input, casted_weights) * scales

    if bias is not None:
        output = output + bias

    return output

class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features,
                 bias=True, dtype=torch.float32):
        super().__init__()

        # Register as buffers, not Parameters: the int8 weights and their
        # per-channel scales come from quantization and are not trained.
        # Note: torch.randint's upper bound is exclusive, so 128 is needed
        # to cover the full int8 range [-128, 127].
        self.register_buffer(
            "int8_weights",
            torch.randint(
                -128, 128, (out_features, in_features), dtype=torch.int8
            )
        )

        self.register_buffer("scales",
                             torch.randn((out_features), dtype=dtype))

        if bias:
            self.register_buffer("bias",
                                 torch.randn((1, out_features),
                                             dtype=dtype))
        else:
            self.bias = None

    def quantize(self, weights):
        # Symmetric per-output-channel quantization: each row of the weight
        # matrix gets one scale, chosen so that its largest-magnitude value
        # maps to 127 (the int8 maximum).
        w_fp32 = weights.clone().to(torch.float32)

        scales = w_fp32.abs().max(dim=-1).values / 127
        scales = scales.to(weights.dtype)

        int8_weights = torch.round(weights
                                   / scales.unsqueeze(1)).to(torch.int8)

        self.int8_weights = int8_weights
        self.scales = scales

    def forward(self, input):
        return w8_a16_forward(self.int8_weights,
                              input, self.scales, self.bias)
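
# Usage sketch (illustrative only; the shapes and layer below are made up,
# not part of the lesson code): quantize the weights of an existing fp32
# nn.Linear into this layer, then run a forward pass through it.
#
#   fp32_layer = nn.Linear(16, 32)
#   q_layer = W8A16LinearLayer(16, 32)
#   q_layer.quantize(fp32_layer.weight)   # fills int8_weights and scales
#   out = q_layer(torch.randn(4, 16))     # shape: (4, 32)
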
def replace_linear_with_target_and_quantize(module,
                                            target_class, module_name_to_exclude):
    # Walk the module tree and swap every nn.Linear (except excluded names)
    # for target_class, quantizing the original weights into the new layer.
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and not \
                any(x == name for x in module_name_to_exclude):
            old_bias = child.bias
            old_weight = child.weight

            new_module = target_class(child.in_features,
                                      child.out_features,
                                      old_bias is not None,
                                      child.weight.dtype)
            setattr(module, name, new_module)

            getattr(module, name).quantize(old_weight)

            if old_bias is not None:
                getattr(module, name).bias = old_bias
        else:
            # Recursively call the function for nested modules
            replace_linear_with_target_and_quantize(child,
                                                    target_class, module_name_to_exclude)

def replace_linear_with_target(module,
                               target_class, module_name_to_exclude):
    # Same traversal as above, but only swaps the layer in place; it does
    # not quantize the original weights.
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and not \
                any(x == name for x in module_name_to_exclude):
            old_bias = child.bias

            new_module = target_class(child.in_features,
                                      child.out_features,
                                      old_bias is not None,
                                      child.weight.dtype)
            setattr(module, name, new_module)

            if old_bias is not None:
                getattr(module, name).bias = old_bias
        else:
            # Recursively call the function for nested modules
            replace_linear_with_target(
                child, target_class, module_name_to_exclude)
###################################
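
# Minimal end-to-end sketch (ToyModel is a made-up example model for
# illustration, not part of the lesson code): replace every nn.Linear
# except "lm_head" with a quantized W8A16LinearLayer, then check that the
# model still runs.
if __name__ == "__main__":
    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 8)
            self.lm_head = nn.Linear(8, 2)

        def forward(self, x):
            return self.lm_head(self.proj(x))

    model = ToyModel()
    replace_linear_with_target_and_quantize(model, W8A16LinearLayer,
                                            ["lm_head"])
    print(model)                              # proj -> W8A16LinearLayer
    print(model(torch.randn(2, 8)).shape)     # torch.Size([2, 2])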