Commit 1a6fcad

Improve TransformersModel UX (#12785)
1 parent 56534cd

1 file changed: +32, -21 lines

Diff for: vllm/model_executor/models/transformers.py

@@ -15,7 +15,7 @@
 # limitations under the License.
 """Wrapper around `transformers` models"""
 import re
-from typing import Iterable, Optional, Union
+from typing import Iterable, Literal, Optional, Union
 
 import torch
 from torch import nn
@@ -72,15 +72,24 @@ def vllm_flash_attention_forward(
 ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
 
 
+def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
+    logger.debug("%s: %s -> %s", name, old_module, new_module)
+
+
 def replace_linear_class(
         linear: nn.Linear,
-        style: str,
+        style: Literal["colwise", "rowwise"],
         quant_config=None) -> Union[ColumnParallelLinear, RowParallelLinear]:
     """
-    In model configurations, we use a neutral type (string) to specify parallel
-    styles, here we use it to translate nn.Linear into vllm-style tp Linear.
-
-    Quant config is not supported yet
+    Replace nn.Linear with one of vLLM's tensor parallel linear classes.
+
+    `quant_config` is not yet supported.
+    Args:
+        linear (nn.Linear): `nn.Linear` to be replaced.
+        style (str): Tensor parallel style of the new linear, e.g. "colwise".
+        quant_config (QuantConfig): Quantization config for the new linear.
+    Returns:
+        Union[ColumnParallelLinear, RowParallelLinear]: The new linear.
     """
 
     if not isinstance(style, str):
@@ -93,7 +102,10 @@ def replace_linear_class(
     }.get(style)
 
     if vllm_linear_cls is None:
-        raise ValueError(f"Unsupported parallel style value: {style}")
+        logger.warning(
+            "Unsupported parallel style value: %s. "
+            "This layer will not be tensor parallelized.", style)
+        return linear
 
     class HFCompatibleLinear(vllm_linear_cls):
         """
@@ -119,25 +131,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
         logger.info("Using Transformers backend.")
 
-        self.vllm_config = vllm_config
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
-        self.quant_config = quant_config
+
         self.config = config
+        self.quant_config = quant_config
         self.vocab_size = config.vocab_size
         self.unpadded_vocab_size = config.vocab_size
 
         self.model: PreTrainedModel = AutoModel.from_config(
             self.config,
             attn_implementation="vllm",
-            torch_dtype=vllm_config.model_config.dtype,
             trust_remote_code=vllm_config.model_config.trust_remote_code,
         )
         prefix = self.model.base_model_prefix
 
         # MLP modifications
-        self.tensor_parallelize(self.model)
+        self.apply_base_model_tp_plan(self.model)
 
         # Attention modifications (assumes 1 attention op per hidden layer)
         tp_size = get_tensor_model_parallel_world_size()
@@ -170,13 +181,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
                                                 config.vocab_size, logit_scale)
         self.sampler = get_sampler()
 
-    def log_replacement(self, name: str, old_module: nn.Module,
-                        new_module: nn.Module):
-        logger.debug("%s: %s -> %s", name, old_module, new_module)
-
-    def tensor_parallelize(self, module: nn.Module, prefix: str = ""):
+    def apply_base_model_tp_plan(self, module: nn.Module, prefix: str = ""):
+        """
+        Apply the base model tensor parallelization plan to a module.
+        Currently only supports linear layers.
+        """
         if (self.config.base_model_tp_plan is None
-                and self.vllm_config.parallel_config.tensor_parallel_size > 1):
+                and get_tensor_model_parallel_world_size() > 1):
             raise ValueError(
                 "Trying to run tensor parallelization but the model does not "
                 "support it yet!")
@@ -189,9 +200,9 @@ def tensor_parallelize(self, module: nn.Module, prefix: str = ""):
                 new_module = replace_linear_class(child_module, style,
                                                   self.quant_config)
                 setattr(module, child_name, new_module)
-                self.log_replacement(qual_name, child_module, new_module)
+                log_replacement(qual_name, child_module, new_module)
             else:
-                self.tensor_parallelize(child_module, prefix=qual_name)
+                self.apply_base_model_tp_plan(child_module, prefix=qual_name)
 
     def replace_vocab_embed_class(self, module: nn.Module):
         # Use native set input embeddings
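
Note: the renamed apply_base_model_tp_plan walks the module tree recursively, matching each child's qualified name against the regex keys of config.base_model_tp_plan and swapping matching linears. A runnable sketch of that traversal pattern; TP_PLAN and the print are illustrative stand-ins for the real plan and for the replace_linear_class/setattr calls shown in the hunk:

import re

from torch import nn

# Hypothetical plan in the spirit of config.base_model_tp_plan:
# regex over the qualified module name -> parallel style.
TP_PLAN = {
    r"layers\.\d+\.up_proj": "colwise",
    r"layers\.\d+\.down_proj": "rowwise",
}

def apply_tp_plan_sketch(module: nn.Module, prefix: str = "") -> None:
    # Visit direct children, mirroring apply_base_model_tp_plan.
    for child_name, child in module.named_children():
        qual_name = f"{prefix}.{child_name}" if prefix else child_name
        for pattern, style in TP_PLAN.items():
            if isinstance(child, nn.Linear) and re.match(pattern, qual_name):
                # The real code swaps in replace_linear_class(...) here
                # via setattr(module, child_name, new_module) and logs it.
                print(f"{qual_name} -> {style}")
                break
        else:
            # No pattern matched: recurse into the subtree.
            apply_tp_plan_sketch(child, prefix=qual_name)

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.up_proj = nn.Linear(4, 8)
        self.down_proj = nn.Linear(8, 4)

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([MLP(), MLP()])

apply_tp_plan_sketch(Toy())
# layers.0.up_proj -> colwise
# layers.0.down_proj -> rowwise
# layers.1.up_proj -> colwise
# layers.1.down_proj -> rowwise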
@@ -201,8 +212,8 @@ def replace_vocab_embed_class(self, module: nn.Module):
             org_num_embeddings=self.config.vocab_size,
             quant_config=None,
         )
-        self.log_replacement("input embedding",
-                             self.model.get_input_embeddings(), new_module)
+        log_replacement("input embedding", self.model.get_input_embeddings(),
+                        new_module)
         self.model.set_input_embeddings(new_module)
 
     def forward(
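
Note: log_replacement emits at DEBUG level, so the module swaps above are only visible with debug logging on. A hedged usage sketch, assuming vLLM's VLLM_LOGGING_LEVEL environment variable and the model_impl engine argument from contemporary vLLM releases; neither is part of this diff, and any Hugging Face causal LM id works in place of the example model:

import os

os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG"  # must be set before importing vllm

from vllm import LLM

# model_impl="transformers" forces the Transformers backend rather than
# waiting for fallback from an unsupported architecture.
llm = LLM(model="facebook/opt-125m", model_impl="transformers")
print(llm.generate("Hello, my name is")[0].outputs[0].text)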
