Skip to content

Commit d3c56cd

Browse files
Update/Fix Pipeline Mixins and ORT Pipelines (#2021)
* created auto task mappings * added correct auto classes * created auto task mappings * added correct auto classes * added ort/auto diffusion classes * fix ORTPipeline detection * start test refactoring * dynamic dtype * support torch random numbers generator * compact diffusion testing suite * fix * test * test * test * use latent-consistency architecture name instead of lcm * fix * add ort diffusion pipeline tests * added dummy objects * remove duplicate code * update stable diffusion mixin * update latent consistency * update sd for img2img * update latent consistency * update model parts to use frozen dict * update tests and utils * updated all mixins, enabled all tests ; all are passing except some reproducibility and comparaison tests (7 failed, 35 passed) * fix sd xl hidden states * style * support testing without diffusers * remove unnecessary * revert * export vae encoder by returning its latent distribution parameters * fix the modeling to handle distributions * create vae class to minimize changes in pipeline mixins * remove unnecessary tests * style * style * update diffusion models export test * style * fall back for when block_out_channels is not in vae config * remove model parts from optimum.onnxruntime * added .to to model parts * remove custom mixins * style * Update optimum/exporters/onnx/model_configs.py Co-authored-by: Ella Charlaix <[email protected]> * Update optimum/exporters/onnx/model_configs.py * conversion to numpy always work * test adding two new pipelines * remove duplicated tests * match diffusers numpy input * simplify model saving * extend tests and only translate generators * cleanup * reduce parent model usage in model parts * fix * new tiny onnx diffusion model with configs * model_save_path * Update optimum/onnxruntime/modeling_diffusion.py Co-authored-by: Ella Charlaix <[email protected]> * migrate tiny-stable-diffusion-onnx * resolve breaking change and mandatory arguments * overwrite _get_add_time_ids * fix * remove inference calls from loading tests * misc * better compatibility between model parts and parent pipeline * remove subfolder * misc * update * support passing safety checker * dummies * remove the need for ORTPipeline --------- Co-authored-by: Ella Charlaix <[email protected]>
1 parent d9754ab commit d3c56cd

21 files changed

+914
-3410
lines changed

optimum/exporters/onnx/model_configs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,7 +1112,7 @@ def ordered_inputs(self, model) -> Dict[str, Dict[int, str]]:
11121112

11131113

11141114
class VaeEncoderOnnxConfig(VisionOnnxConfig):
1115-
ATOL_FOR_VALIDATION = 1e-2
1115+
ATOL_FOR_VALIDATION = 1e-4
11161116
# The ONNX export of a CLIPText architecture, an other Stable Diffusion component, needs the Trilu
11171117
# operator support, available since opset 14
11181118
DEFAULT_ONNX_OPSET = 14
@@ -1132,12 +1132,12 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
11321132
@property
11331133
def outputs(self) -> Dict[str, Dict[int, str]]:
11341134
return {
1135-
"latent_sample": {0: "batch_size", 2: "height_latent", 3: "width_latent"},
1135+
"latent_parameters": {0: "batch_size", 2: "height_latent", 3: "width_latent"},
11361136
}
11371137

11381138

11391139
class VaeDecoderOnnxConfig(VisionOnnxConfig):
1140-
ATOL_FOR_VALIDATION = 1e-3
1140+
ATOL_FOR_VALIDATION = 1e-4
11411141
# The ONNX export of a CLIPText architecture, an other Stable Diffusion component, needs the Trilu
11421142
# operator support, available since opset 14
11431143
DEFAULT_ONNX_OPSET = 14

optimum/exporters/utils.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,6 @@
4646

4747
from diffusers import (
4848
DiffusionPipeline,
49-
LatentConsistencyModelImg2ImgPipeline,
50-
LatentConsistencyModelPipeline,
51-
StableDiffusionImg2ImgPipeline,
52-
StableDiffusionInpaintPipeline,
53-
StableDiffusionPipeline,
5449
StableDiffusionXLImg2ImgPipeline,
5550
StableDiffusionXLInpaintPipeline,
5651
StableDiffusionXLPipeline,
@@ -92,27 +87,13 @@ def _get_submodels_for_export_diffusion(
9287
Returns the components of a Stable Diffusion model.
9388
"""
9489

95-
is_stable_diffusion = isinstance(
96-
pipeline, (StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline)
97-
)
9890
is_stable_diffusion_xl = isinstance(
9991
pipeline, (StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline)
10092
)
101-
is_latent_consistency_model = isinstance(
102-
pipeline, (LatentConsistencyModelPipeline, LatentConsistencyModelImg2ImgPipeline)
103-
)
104-
10593
if is_stable_diffusion_xl:
10694
projection_dim = pipeline.text_encoder_2.config.projection_dim
107-
elif is_stable_diffusion:
108-
projection_dim = pipeline.text_encoder.config.projection_dim
109-
elif is_latent_consistency_model:
110-
projection_dim = pipeline.text_encoder.config.projection_dim
11195
else:
112-
raise ValueError(
113-
f"The export of a DiffusionPipeline model with the class name {pipeline.__class__.__name__} is currently not supported in Optimum. "
114-
"Please open an issue or submit a PR to add the support."
115-
)
96+
projection_dim = pipeline.text_encoder.config.projection_dim
11697

11798
models_for_export = {}
11899

@@ -139,7 +120,8 @@ def _get_submodels_for_export_diffusion(
139120
vae_encoder = copy.deepcopy(pipeline.vae)
140121
if not is_torch_greater_or_equal_than_2_1:
141122
vae_encoder = override_diffusers_2_0_attn_processors(vae_encoder)
142-
vae_encoder.forward = lambda sample: {"latent_sample": vae_encoder.encode(x=sample)["latent_dist"].sample()}
123+
# we return the distribution parameters to be able to recreate it in the decoder
124+
vae_encoder.forward = lambda sample: {"latent_parameters": vae_encoder.encode(x=sample)["latent_dist"].parameters}
143125
models_for_export["vae_encoder"] = vae_encoder
144126

145127
# VAE Decoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L600

optimum/onnx/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,22 @@ def _get_external_data_paths(src_paths: List[Path], dst_paths: List[Path]) -> Tu
7171
return src_paths, dst_paths
7272

7373

74+
def _get_model_external_data_paths(model_path: Path) -> List[Path]:
75+
"""
76+
Gets external data paths from the model.
77+
"""
78+
79+
onnx_model = onnx.load(str(model_path), load_external_data=False)
80+
model_tensors = _get_initializer_tensors(onnx_model)
81+
# filter out tensors that are not external data
82+
model_tensors_ext = [
83+
ExternalDataInfo(tensor).location
84+
for tensor in model_tensors
85+
if tensor.HasField("data_location") and tensor.data_location == onnx.TensorProto.EXTERNAL
86+
]
87+
return [model_path.parent / tensor_name for tensor_name in model_tensors_ext]
88+
89+
7490
def check_model_uses_external_data(model: onnx.ModelProto) -> bool:
7591
"""
7692
Checks if the model uses external data.

optimum/onnxruntime/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,9 @@
7979
"ORTStableDiffusionInpaintPipeline",
8080
"ORTStableDiffusionXLPipeline",
8181
"ORTStableDiffusionXLImg2ImgPipeline",
82+
"ORTStableDiffusionXLInpaintPipeline",
8283
"ORTLatentConsistencyModelPipeline",
84+
"ORTLatentConsistencyModelImg2ImgPipeline",
8385
"ORTPipelineForImage2Image",
8486
"ORTPipelineForInpainting",
8587
"ORTPipelineForText2Image",
@@ -92,6 +94,8 @@
9294
"ORTStableDiffusionInpaintPipeline",
9395
"ORTStableDiffusionXLPipeline",
9496
"ORTStableDiffusionXLImg2ImgPipeline",
97+
"ORTStableDiffusionXLInpaintPipeline",
98+
"ORTLatentConsistencyModelImg2ImgPipeline",
9599
"ORTLatentConsistencyModelPipeline",
96100
"ORTPipelineForImage2Image",
97101
"ORTPipelineForInpainting",
@@ -148,6 +152,7 @@
148152
except OptionalDependencyNotAvailable:
149153
from ..utils.dummy_diffusers_objects import (
150154
ORTDiffusionPipeline,
155+
ORTLatentConsistencyModelImg2ImgPipeline,
151156
ORTLatentConsistencyModelPipeline,
152157
ORTPipelineForImage2Image,
153158
ORTPipelineForInpainting,
@@ -156,11 +161,13 @@
156161
ORTStableDiffusionInpaintPipeline,
157162
ORTStableDiffusionPipeline,
158163
ORTStableDiffusionXLImg2ImgPipeline,
164+
ORTStableDiffusionXLInpaintPipeline,
159165
ORTStableDiffusionXLPipeline,
160166
)
161167
else:
162168
from .modeling_diffusion import (
163169
ORTDiffusionPipeline,
170+
ORTLatentConsistencyModelImg2ImgPipeline,
164171
ORTLatentConsistencyModelPipeline,
165172
ORTPipelineForImage2Image,
166173
ORTPipelineForInpainting,
@@ -169,6 +176,7 @@
169176
ORTStableDiffusionInpaintPipeline,
170177
ORTStableDiffusionPipeline,
171178
ORTStableDiffusionXLImg2ImgPipeline,
179+
ORTStableDiffusionXLInpaintPipeline,
172180
ORTStableDiffusionXLPipeline,
173181
)
174182
else:

optimum/onnxruntime/base.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,25 @@ def dtype(self):
7171

7272
return None
7373

74+
def to(self, *args, device: Optional[Union[torch.device, str, int]] = None, dtype: Optional[torch.dtype] = None):
75+
for arg in args:
76+
if isinstance(arg, torch.device):
77+
device = arg
78+
elif isinstance(arg, torch.dtype):
79+
dtype = arg
80+
81+
if device is not None and device != self.device:
82+
raise ValueError(
83+
"Cannot change the device of a model part without changing the device of the parent model. "
84+
"Please use the `to` method of the parent model to change the device."
85+
)
86+
87+
if dtype is not None and dtype != self.dtype:
88+
raise NotImplementedError(
89+
f"Cannot change the dtype of the model from {self.dtype} to {dtype}. "
90+
f"Please export the model with the desired dtype."
91+
)
92+
7493
@abstractmethod
7594
def forward(self, *args, **kwargs):
7695
pass

0 commit comments

Comments
 (0)