1515# limitations under the License.
1616"""Wrapper around `transformers` models"""
1717import re
18- from typing import Iterable , Optional , Union
18+ from typing import Iterable , Literal , Optional , Union
1919
2020import torch
2121from torch import nn
@@ -72,15 +72,24 @@ def vllm_flash_attention_forward(
7272ALL_ATTENTION_FUNCTIONS ["vllm" ] = vllm_flash_attention_forward
7373
7474
75+ def log_replacement (name : str , old_module : nn .Module , new_module : nn .Module ):
76+ logger .debug ("%s: %s -> %s" , name , old_module , new_module )
77+
78+
7579def replace_linear_class (
7680 linear : nn .Linear ,
77- style : str ,
81+ style : Literal [ "colwise" , "rowwise" ] ,
7882 quant_config = None ) -> Union [ColumnParallelLinear , RowParallelLinear ]:
7983 """
80- In model configurations, we use a neutral type (string) to specify parallel
81- styles, here we use it to translate nn.Linear into vllm-style tp Linear.
82-
83- Quant config is not supported yet
84+ Replace nn.Linear with one of vLLM's tensor parallel linear classes.
85+
86+ `quant_config` is not yet supported.
87+ Args:
88+ linear (nn.Linear): `nn.Linear` to be replaced.
89+ style (str): Tensor parallel style of the new linear, e.g. "colwise".
90+ quant_config (QuantConfig): Quantization config for the new linear.
91+ Returns:
92+ Union[ColumnParallelLinear, RowParallelLinear]: The new linear.
8493 """
8594
8695 if not isinstance (style , str ):
@@ -93,7 +102,10 @@ def replace_linear_class(
93102 }.get (style )
94103
95104 if vllm_linear_cls is None :
96- raise ValueError (f"Unsupported parallel style value: { style } " )
105+ logger .warning (
106+ "Unsupported parallel style value: %s. "
107+ "This layer will not be tensor parallelized." , style )
108+ return linear
97109
98110 class HFCompatibleLinear (vllm_linear_cls ):
99111 """
@@ -119,25 +131,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
119131 super ().__init__ ()
120132 logger .info ("Using Transformers backend." )
121133
122- self .vllm_config = vllm_config
123134 config = vllm_config .model_config .hf_config
124135 cache_config = vllm_config .cache_config
125136 quant_config = vllm_config .quant_config
126- self . quant_config = quant_config
137+
127138 self .config = config
139+ self .quant_config = quant_config
128140 self .vocab_size = config .vocab_size
129141 self .unpadded_vocab_size = config .vocab_size
130142
131143 self .model : PreTrainedModel = AutoModel .from_config (
132144 self .config ,
133145 attn_implementation = "vllm" ,
134- torch_dtype = vllm_config .model_config .dtype ,
135146 trust_remote_code = vllm_config .model_config .trust_remote_code ,
136147 )
137148 prefix = self .model .base_model_prefix
138149
139150 # MLP modifications
140- self .tensor_parallelize (self .model )
151+ self .apply_base_model_tp_plan (self .model )
141152
142153 # Attention modifications (assumes 1 attention op per hidden layer)
143154 tp_size = get_tensor_model_parallel_world_size ()
@@ -170,13 +181,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
170181 config .vocab_size , logit_scale )
171182 self .sampler = get_sampler ()
172183
173- def log_replacement (self , name : str , old_module : nn .Module ,
174- new_module : nn . Module ):
175- logger . debug ( "%s: %s -> %s" , name , old_module , new_module )
176-
177- def tensor_parallelize ( self , module : nn . Module , prefix : str = "" ):
184+ def apply_base_model_tp_plan (self , module : nn .Module , prefix : str = "" ):
185+ """
186+ Apply the base model tensor parallelization plan to a module.
187+ Currently only supports linear layers.
188+ """
178189 if (self .config .base_model_tp_plan is None
179- and self . vllm_config . parallel_config . tensor_parallel_size > 1 ):
190+ and get_tensor_model_parallel_world_size () > 1 ):
180191 raise ValueError (
181192 "Trying to run tensor parallelization but the model does not "
182193 "support it yet!" )
@@ -189,9 +200,9 @@ def tensor_parallelize(self, module: nn.Module, prefix: str = ""):
189200 new_module = replace_linear_class (child_module , style ,
190201 self .quant_config )
191202 setattr (module , child_name , new_module )
192- self . log_replacement (qual_name , child_module , new_module )
203+ log_replacement (qual_name , child_module , new_module )
193204 else :
194- self .tensor_parallelize (child_module , prefix = qual_name )
205+ self .apply_base_model_tp_plan (child_module , prefix = qual_name )
195206
196207 def replace_vocab_embed_class (self , module : nn .Module ):
197208 # Use native set input embeddings
@@ -201,8 +212,8 @@ def replace_vocab_embed_class(self, module: nn.Module):
201212 org_num_embeddings = self .config .vocab_size ,
202213 quant_config = None ,
203214 )
204- self . log_replacement ("input embedding" ,
205- self . model . get_input_embeddings (), new_module )
215+ log_replacement ("input embedding" , self . model . get_input_embeddings () ,
216+ new_module )
206217 self .model .set_input_embeddings (new_module )
207218
208219 def forward (
0 commit comments