@@ -15,7 +15,7 @@
 # limitations under the License.
 """Wrapper around `transformers` models"""
 import re
-from typing import Iterable, Optional, Union
+from typing import Iterable, Literal, Optional, Union
 
 import torch
 from torch import nn
@@ -72,15 +72,24 @@ def vllm_flash_attention_forward(
 ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
 
 
+def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
+    logger.debug("%s: %s -> %s", name, old_module, new_module)
+
+
 def replace_linear_class(
         linear: nn.Linear,
-        style: str,
+        style: Literal["colwise", "rowwise"],
         quant_config=None) -> Union[ColumnParallelLinear, RowParallelLinear]:
     """
-    In model configurations, we use a neutral type (string) to specify parallel
-    styles, here we use it to translate nn.Linear into vllm-style tp Linear.
-
-    Quant config is not supported yet
+    Replace nn.Linear with one of vLLM's tensor parallel linear classes.
+
+    `quant_config` is not yet supported.
+    Args:
+        linear (nn.Linear): `nn.Linear` to be replaced.
+        style (str): Tensor parallel style of the new linear, e.g. "colwise".
+        quant_config (QuantConfig): Quantization config for the new linear.
+    Returns:
+        Union[ColumnParallelLinear, RowParallelLinear]: The new linear.
     """
 
     if not isinstance(style, str):
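A minimal usage sketch of the new signature (illustrative only, not part of the diff; the layer shape and distributed setup are assumptions):

# Hypothetical example: replace an nn.Linear with a vLLM tensor-parallel
# linear using the "colwise" style. Requires vLLM's distributed state to
# be initialized so the output features can be sharded across TP ranks.
linear = nn.Linear(4096, 11008, bias=False)
tp_linear = replace_linear_class(linear, style="colwise", quant_config=None)
# tp_linear is now a ColumnParallelLinear (per the style mapping below).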
@@ -93,7 +102,10 @@ def replace_linear_class(
     }.get(style)
 
     if vllm_linear_cls is None:
-        raise ValueError(f"Unsupported parallel style value: {style}")
+        logger.warning(
+            "Unsupported parallel style value: %s. "
+            "This layer will not be tensor parallelized.", style)
+        return linear
 
     class HFCompatibleLinear(vllm_linear_cls):
         """
@@ -119,25 +131,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
         logger.info("Using Transformers backend.")
 
-        self.vllm_config = vllm_config
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
-        self.quant_config = quant_config
+
         self.config = config
+        self.quant_config = quant_config
         self.vocab_size = config.vocab_size
         self.unpadded_vocab_size = config.vocab_size
 
         self.model: PreTrainedModel = AutoModel.from_config(
             self.config,
             attn_implementation="vllm",
-            torch_dtype=vllm_config.model_config.dtype,
             trust_remote_code=vllm_config.model_config.trust_remote_code,
         )
         prefix = self.model.base_model_prefix
 
         # MLP modifications
-        self.tensor_parallelize(self.model)
+        self.apply_base_model_tp_plan(self.model)
 
         # Attention modifications (assumes 1 attention op per hidden layer)
         tp_size = get_tensor_model_parallel_world_size()
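For context, a sketch of the mechanism this hunk relies on (assumes a transformers version where attention dispatch goes through the ALL_ATTENTION_FUNCTIONS registry):

# Because "vllm" was registered in ALL_ATTENTION_FUNCTIONS above, every
# attention module in the loaded model routes its attention computation
# through vllm_flash_attention_forward instead of a native implementation.
model = AutoModel.from_config(config, attn_implementation="vllm")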
@@ -170,13 +181,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
                                                 config.vocab_size, logit_scale)
         self.sampler = get_sampler()
 
-    def log_replacement(self, name: str, old_module: nn.Module,
-                        new_module: nn.Module):
-        logger.debug("%s: %s -> %s", name, old_module, new_module)
-
-    def tensor_parallelize(self, module: nn.Module, prefix: str = ""):
+    def apply_base_model_tp_plan(self, module: nn.Module, prefix: str = ""):
+        """
+        Apply the base model tensor parallelization plan to a module.
+        Currently only supports linear layers.
+        """
         if (self.config.base_model_tp_plan is None
-                and self.vllm_config.parallel_config.tensor_parallel_size > 1):
+                and get_tensor_model_parallel_world_size() > 1):
             raise ValueError(
                 "Trying to run tensor parallelization but the model does not "
                 "support it yet!")
@@ -189,9 +200,9 @@ def tensor_parallelize(self, module: nn.Module, prefix: str = ""):
                 new_module = replace_linear_class(child_module, style,
                                                   self.quant_config)
                 setattr(module, child_name, new_module)
-                self.log_replacement(qual_name, child_module, new_module)
+                log_replacement(qual_name, child_module, new_module)
             else:
-                self.tensor_parallelize(child_module, prefix=qual_name)
+                self.apply_base_model_tp_plan(child_module, prefix=qual_name)
 
     def replace_vocab_embed_class(self, module: nn.Module):
         # Use native set input embeddings
@@ -201,8 +212,8 @@ def replace_vocab_embed_class(self, module: nn.Module):
             org_num_embeddings=self.config.vocab_size,
             quant_config=None,
         )
-        self.log_replacement("input embedding",
-                             self.model.get_input_embeddings(), new_module)
+        log_replacement("input embedding", self.model.get_input_embeddings(),
+                        new_module)
         self.model.set_input_embeddings(new_module)
 
     def forward(
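End to end, this backend is selected from vLLM's entry points; a hedged sketch (model_impl="transformers" is assumed to be available in the vLLM version carrying this change):

# Force the Transformers backend instead of a native vLLM model class.
from vllm import LLM

llm = LLM(model="meta-llama/Llama-3.2-1B", model_impl="transformers")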