64 changes: 57 additions & 7 deletions optimum/exporters/openvino/convert.py
@@ -353,13 +353,7 @@ def export_pytorch(
model.config.return_dict = True
model.eval()

# Check if we need to override certain configuration item
if config.values_override is not None:
logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
for override_config_key, override_config_value in config.values_override.items():
logger.info(f"\t- {override_config_key} -> {override_config_value}")
setattr(model.config, override_config_key, override_config_value)

# Check if we need to override certain configuration items
if input_shapes is None:
input_shapes = {} # will use the defaults from DEFAULT_DUMMY_SHAPES

@@ -977,6 +971,7 @@ def get_diffusion_models_for_export_ext(
is_flux = pipeline.__class__.__name__.startswith("Flux")
is_sana = pipeline.__class__.__name__.startswith("Sana")
is_ltx_video = pipeline.__class__.__name__.startswith("LTX")
is_qwen_image = pipeline.__class__.__name__.startswith("QwenImage")
is_sd = pipeline.__class__.__name__.startswith("StableDiffusion") and not is_sd3
is_lcm = pipeline.__class__.__name__.startswith("LatentConsistencyModel")

@@ -1002,6 +997,8 @@
models_for_export = get_sana_models_for_export(pipeline, exporter, int_dtype, float_dtype)
elif is_ltx_video:
models_for_export = get_ltx_video_models_for_export(pipeline, exporter, int_dtype, float_dtype)
elif is_qwen_image:
models_for_export = get_qwen_image_models_for_export(pipeline, exporter, int_dtype, float_dtype)
else:
raise ValueError(f"Unsupported pipeline type `{pipeline.__class__.__name__}` provided")
return None, models_for_export
@@ -1414,3 +1411,56 @@ def _get_speecht5_tss_model_for_export(
stateful_per_model = [False, True, False, False]

return export_config, models_for_export, stateful_per_model

def get_qwen_image_models_for_export(pipeline, exporter, int_dtype, float_dtype):
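# Build the {submodel name -> (model, export config)} mapping for a Qwen-Image diffusers
# pipeline: the Qwen2.5-VL text encoder, the image transformer, and the VAE decoder.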
models_for_export = {}

# Text encoder
text_encoder = getattr(pipeline, "text_encoder", None)
if text_encoder is not None:
text_encoder.config.output_hidden_states = True
text_encoder_config_constructor = TasksManager.get_exporter_config_constructor(
model=text_encoder,
exporter=exporter,
library_name="diffusers",
task="feature-extraction",
model_type="qwen2_5_vl_text",
)
text_encoder_export_config = text_encoder_config_constructor(
pipeline.text_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype
)
models_for_export["text_encoder"] = (text_encoder, text_encoder_export_config)
transformer = pipeline.transformer
transformer.config.text_encoder_projection_dim = transformer.config.joint_attention_dim
transformer.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False)
transformer.config.time_cond_proj_dim = None
export_config_constructor = TasksManager.get_exporter_config_constructor(
model=transformer,
exporter=exporter,
library_name="diffusers",
task="semantic-segmentation",
model_type="qwen-image-transformer-2d",
)
transformer_export_config = export_config_constructor(
pipeline.transformer.config, int_dtype=int_dtype, float_dtype=float_dtype
)
transformer_export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
models_for_export["transformer"] = (transformer, transformer_export_config)

# VAE Decoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L600
vae_decoder = copy.deepcopy(pipeline.vae)
vae_decoder.forward = lambda latent_sample: vae_decoder.decode(z=latent_sample)
vae_config_constructor = TasksManager.get_exporter_config_constructor(
model=vae_decoder,
exporter=exporter,
library_name="diffusers",
task="semantic-segmentation",
model_type="qwen-image-decoder",
)
vae_decoder_export_config = vae_config_constructor(
vae_decoder.config, int_dtype=int_dtype, float_dtype=float_dtype
)
vae_decoder_export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
models_for_export["vae_decoder"] = (vae_decoder, vae_decoder_export_config)

return models_for_export
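
For reviewers, a minimal usage sketch of the helper above (assumptions: a diffusers build that ships QwenImagePipeline, a Qwen/Qwen-Image checkpoint available locally or on the Hub, and this function being importable from optimum.exporters.openvino.convert once the PR is merged; none of this is asserted by the diff itself):

# Sketch only: collect the three Qwen-Image sub-models and their export configs.
from diffusers import QwenImagePipeline  # assumed to exist in the installed diffusers version

from optimum.exporters.openvino.convert import get_qwen_image_models_for_export

pipeline = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image")
models_for_export = get_qwen_image_models_for_export(
    pipeline, exporter="openvino", int_dtype="int64", float_dtype="fp32"
)

# Expected keys: "text_encoder", "transformer", "vae_decoder", each mapped to (model, export_config).
for name, (sub_model, export_config) in models_for_export.items():
    print(name, type(sub_model).__name__, type(export_config).__name__)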
129 changes: 128 additions & 1 deletion optimum/exporters/openvino/model_configs.py
@@ -138,6 +138,8 @@
Qwen2VLVisionEmbMergerPatcher,
Qwen3MoeModelPatcher,
QwenModelPatcher,
QwenTransfromerModelPatcher,
QwenVAEPatcher,
SanaTextEncoderModelPatcher,
XverseModelPatcher,
Zamba2ModelPatcher,
@@ -2127,6 +2129,7 @@ class T5EncoderOpenVINOConfig(CLIPTextOpenVINOConfig):


@register_in_tasks_manager("gemma2-text-encoder", *["feature-extraction"], library_name="diffusers")
@register_in_tasks_manager("qwen2_5_vl_text", *["feature-extraction"], library_name="diffusers")
class Gemma2TextEncoderOpenVINOConfig(CLIPTextOpenVINOConfig):
_MODEL_PATCHER = SanaTextEncoderModelPatcher

@@ -3460,7 +3463,6 @@ def patch_model_for_export(self, model: PreTrainedModel, model_kwargs: Optional[
return Qwen2_5_VLVisionEmbMergerPatcher(self, model, model_kwargs)
return super().patch_model_for_export(model, model_kwargs)


@register_in_tasks_manager(
"glm",
*[
@@ -4269,6 +4271,131 @@ class GPT2OpenVINOConfig(GPT2OnnxConfig):
_MODEL_PATCHER = OVDecoderModelPatcher


class DummyQwenTransformerInputGenerator(DummyVisionInputGenerator):
SUPPORTED_INPUT_NAMES = (
"hidden_states",
"img_shapes",
)

def __init__(
self,
task: str,
normalized_config: NormalizedVisionConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
width: int = DEFAULT_DUMMY_SHAPES["width"] // 4,
height: int = DEFAULT_DUMMY_SHAPES["height"] // 4,
# Reduce the dummy image shape by 4 (as is done for FLUX) to lower memory usage during conversion
**kwargs,
):
super().__init__(task, normalized_config, batch_size, num_channels, width, height, **kwargs)
if getattr(normalized_config, "in_channels", None):
self.num_channels = normalized_config.in_channels // 4

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
if input_name == "hidden_states":
shape = [self.batch_size, (self.height // 2) * (self.width // 2), self.num_channels * 4]
return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)
if input_name == "img_shapes":
import torch
return torch.tensor((1, self.height // 2, self.width // 2), dtype=DTYPE_MAPPER.pt(int_dtype))

return super().generate(input_name, framework, int_dtype, float_dtype)
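
A self-contained illustration of the dummy shapes this generator produces (assumptions: DEFAULT_DUMMY_SHAPES gives batch_size=2 and width=height=64, and the transformer config reports in_channels=64; verify both against the actual defaults and the Qwen-Image config):

# Standalone mirror of DummyQwenTransformerInputGenerator.generate with assumed values.
batch_size = 2
width = 64 // 4    # reduced in __init__ above to keep conversion memory low
height = 64 // 4
in_channels = 64   # hypothetical value read from the transformer config
num_channels = in_channels // 4

hidden_states_shape = [batch_size, (height // 2) * (width // 2), num_channels * 4]
img_shapes = (1, height // 2, width // 2)

print(hidden_states_shape)  # [2, 64, 64] -> 2x2 latent patches packed into a token sequence
print(img_shapes)           # (1, 8, 8)   -> (frames, packed height, packed width)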

class DummyQwenTextInputGenerator(DummySeq2SeqDecoderTextInputGenerator):
SUPPORTED_INPUT_NAMES = (
"decoder_input_ids",
"decoder_attention_mask",
"encoder_outputs",
"encoder_hidden_states",
"encoder_hidden_states_mask",
"txt_seq_lens",
)

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
if input_name == "txt_seq_lens":
import torch

shape = [self.batch_size]
dtype = DTYPE_MAPPER.pt(int_dtype)
return torch.full(shape, self.sequence_length, dtype=dtype)
if input_name == "encoder_hidden_states_mask":
import torch

shape = [self.batch_size, self.sequence_length]
dtype = DTYPE_MAPPER.pt(float_dtype)
return torch.full(shape, 1, dtype=dtype)
return super().generate(input_name, framework, int_dtype, float_dtype)

@register_in_tasks_manager("qwen-image-transformer-2d", *["semantic-segmentation"], library_name="diffusers")
class QwenTransformerOpenVINOConfig(SD3TransformerOpenVINOConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
DummyTransformerTimestpsInputGenerator,
DummyQwenTransformerInputGenerator,
DummyQwenTextInputGenerator,
)

@property
def inputs(self):
common_inputs = super().inputs
common_inputs.pop("sample", None)
common_inputs.pop("pooled_projections", None)
common_inputs["hidden_states"] = {0: "batch_size", 1: "image_sequence_length"}
common_inputs["encoder_hidden_states_mask"] = {0: "batch_size", 1: "text_sequence_length"}
common_inputs["img_shapes"] = {0: "batch_size"}
common_inputs["txt_seq_lens"] = {0: "batch_size"}
if getattr(self._normalized_config, "guidance_embeds", False):
common_inputs["guidance"] = {0: "batch_size"}
return common_inputs

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> ModelPatcher:
# OpenVINO cannot handle the complex-valued data used in this model
return QwenTransfromerModelPatcher(self, model, model_kwargs=model_kwargs)

class QwenVaeDummyInputGenerator(DummyVisionInputGenerator):
SUPPORTED_INPUT_NAMES = ("sample", "latent_sample")

def __init__(
self,
task: str,
normalized_config: NormalizedVisionConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
width: int = DEFAULT_DUMMY_SHAPES["width"],
height: int = DEFAULT_DUMMY_SHAPES["height"],
num_frames: int = 1,
**kwargs,
):
super().__init__(task, normalized_config, batch_size, num_channels, width, height, **kwargs)
self.num_frames = num_frames
self.num_channels = normalized_config.z_dim

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
if input_name in ["sample", "latent_sample"]:
return self.random_float_tensor(
[self.batch_size, self.num_channels, self.num_frames, self.height, self.width],
framework=framework,
dtype=float_dtype,
)
return super().generate(input_name, framework, int_dtype, float_dtype)

@register_in_tasks_manager("qwen-image-decoder", *["semantic-segmentation"], library_name="diffusers")
class QwenDecoderOpenVINOConfig(LTXVaeDecoderOpenVINOConfig):
_MODEL_PATCHER = QwenVAEPatcher

# OpenVINO does not support torch.nn.Upsample with nearest-exact interpolation
DUMMY_INPUT_GENERATOR_CLASSES = (QwenVaeDummyInputGenerator,)

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
base_input = {
"latent_sample": {0: "batch_size", 2: "num_frames", 3: "latent_height", 4: "latent_width"},
}
return base_input
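
And an analogous shape sketch for the VAE decoder input (assumptions: default dummy batch_size=2, width=height=64, num_frames=1, and z_dim=16 on the Qwen-Image VAE config; the z_dim value is an assumption, not something this diff states):

# Standalone mirror of QwenVaeDummyInputGenerator.generate for "latent_sample" with assumed values.
batch_size = 2
z_dim = 16        # hypothetical latent channel count from the VAE config
num_frames = 1    # Qwen-Image generates single images, so one frame suffices
height, width = 64, 64

latent_sample_shape = [batch_size, z_dim, num_frames, height, width]
print(latent_sample_shape)  # [2, 16, 1, 64, 64] -> matches the 5D layout declared in `inputs` above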

@register_in_tasks_manager(
"vision-encoder-decoder",
*[