Skip to content

Commit 7ef4dfc

Browse files
author
Marcin Kardas
committed
Bring back custom policy to support bias-less OPT-like models
1 parent 86bdbc9 commit 7ef4dfc

File tree

2 files changed

+66
-0
lines changed

2 files changed

+66
-0
lines changed

galai/model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,15 @@ def _parallelize(self) -> None:
129129

130130
self._master_port = 13000 + (id(self.model) % 32749)
131131

132+
custom_policies = None
133+
if self.model.config.model_type == "opt" and not self.model.config.enable_bias:
134+
from galai.parallel_policy import OPTDecoderLayerPolicyNoBias
135+
custom_policies = [OPTDecoderLayerPolicyNoBias]
136+
132137
parallelize(
133138
self.model, num_gpus=self.num_gpus, fp16=self.dtype == torch.float16,
134139
master_port=self._master_port,
140+
custom_policies=custom_policies,
135141
)
136142

137143
def _set_tokenizer(self, tokenizer_path: str):

galai/parallel_policy.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from parallelformers.policies.base import Layer, Policy
2+
from parallelformers.utils.dist_utils import AllReduceLinear
3+
4+
from transformers.models.opt.modeling_opt import OPTDecoderLayer
5+
6+
7+
__all__ = ["OPTDecoderLayerPolicyNoBias"]
8+
9+
10+
class OPTDecoderLayerPolicyNoBias(Policy):
    """Parallelformers tensor-parallel policy for OPT decoder layers whose
    linear projections have no bias term (``config.enable_bias`` is False).

    Unlike the stock OPT policy, every ``Layer`` here declares only a
    ``weight`` tensor, so checkpoints that lack ``bias`` parameters can be
    sharded without KeyErrors.
    """

    @staticmethod
    def replace_arguments(config, world_size):
        """Per-rank attribute overrides: attention width is split evenly
        across the ``world_size`` model-parallel ranks."""
        return {
            "self_attn.embed_dim": config.hidden_size // world_size,
            "self_attn.num_heads": config.num_attention_heads // world_size,
        }

    @staticmethod
    def attn_qkv():
        """Column-sliced attention input projections (q, k, v), weight only."""
        return [
            Layer(weight=f"self_attn.{name}_proj.weight")
            for name in ("q", "k", "v")
        ]

    @staticmethod
    def attn_out():
        """Row-sliced attention output projection; partial results are summed
        across ranks via AllReduceLinear."""
        return [
            Layer(weight="self_attn.out_proj.weight", replace=AllReduceLinear),
        ]

    @staticmethod
    def mlp_in():
        """Column-sliced feed-forward input projection, weight only."""
        return [Layer(weight="fc1.weight")]

    @staticmethod
    def mlp_out():
        """Row-sliced feed-forward output projection, all-reduced across
        ranks."""
        return [Layer(weight="fc2.weight", replace=AllReduceLinear)]

    @staticmethod
    def original_layer_class():
        """The transformers module class this policy applies to."""
        return OPTDecoderLayer

0 commit comments

Comments
 (0)