 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Model specific ONNX configurations."""
+
 import math
 import random
 import warnings
@@ -1337,45 +1338,51 @@ class VaeEncoderOnnxConfig(VisionOnnxConfig):
     DEFAULT_ONNX_OPSET = 14

     NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
-        num_channels="in_channels",
-        image_size="sample_size",
-        allow_new=True,
+        num_channels="in_channels", image_size="sample_size", allow_new=True
     )

     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
         return {
-            "sample": {0: "batch_size", 2: "height", 3: "width"},
+            "sample": {0: "batch_size", 2: "sample_height", 3: "sample_width"},
         }

     @property
     def outputs(self) -> Dict[str, Dict[int, str]]:
+        down_sampling_factor = 2 ** (len(self._normalized_config.down_block_types) - 1)
         return {
-            "latent_parameters": {0: "batch_size", 2: "height_latent", 3: "width_latent"},
+            "latent_parameters": {
+                0: "batch_size",
+                2: f"sample_height / {down_sampling_factor}",
+                3: f"sample_width / {down_sampling_factor}",
+            },
         }


 class VaeDecoderOnnxConfig(VisionOnnxConfig):
-    ATOL_FOR_VALIDATION = 1e-4
+    ATOL_FOR_VALIDATION = 3e-4
     # The ONNX export of a CLIPText architecture, an other Stable Diffusion component, needs the Trilu
     # operator support, available since opset 14
     DEFAULT_ONNX_OPSET = 14

-    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
-        num_channels="latent_channels",
-        allow_new=True,
-    )
+    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(num_channels="latent_channels", allow_new=True)

     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
         return {
-            "latent_sample": {0: "batch_size", 2: "height_latent", 3: "width_latent"},
+            "latent_sample": {0: "batch_size", 2: "latent_height", 3: "latent_width"},
         }

     @property
     def outputs(self) -> Dict[str, Dict[int, str]]:
+        upsampling_factor = 2 ** (len(self._normalized_config.up_block_types) - 1)
+
         return {
-            "sample": {0: "batch_size", 2: "height", 3: "width"},
+            "sample": {
+                0: "batch_size",
+                2: f"latent_height * {upsampling_factor}",
+                3: f"latent_width * {upsampling_factor}",
+            },
         }

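Editor's note: the new axis names are only symbolic labels for ONNX dynamic axes; the arithmetic they encode is the factor computed from the number of down/up blocks. A minimal sketch of that arithmetic, assuming a typical Stable Diffusion AutoencoderKL configuration (four down/up block types, 512x512 sample size); none of these values or helper names come from the diff above:

```python
# Illustrative sketch (not part of the exporter): how the scaling factors above
# relate pixel-space and latent-space dynamic axes, assuming a typical Stable
# Diffusion VAE config with 4 down/up blocks and a 512x512 sample size.
down_block_types = ["DownEncoderBlock2D"] * 4  # hypothetical config values
up_block_types = ["UpDecoderBlock2D"] * 4

down_sampling_factor = 2 ** (len(down_block_types) - 1)  # 2**3 = 8
upsampling_factor = 2 ** (len(up_block_types) - 1)       # 2**3 = 8

sample_height, sample_width = 512, 512
latent_height = sample_height // down_sampling_factor    # 512 / 8 = 64
latent_width = sample_width // down_sampling_factor      # 512 / 8 = 64

# The decoder reverses the mapping: 64 * 8 = 512.
assert latent_height * upsampling_factor == sample_height
print(down_sampling_factor, (latent_height, latent_width))  # 8 (64, 64)
```

ONNX Runtime treats these axis strings as opaque names, so the expressions serve as documentation of the expected shape relationship rather than as enforced constraints.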
@@ -1815,9 +1822,17 @@ class MusicgenOnnxConfig(OnnxSeq2SeqConfigWithPast):
     DEFAULT_ONNX_OPSET = 14

     VARIANTS = {
-        "text-conditional-with-past": "Exports Musicgen to ONNX to generate audio samples conditioned on a text prompt (Reference: https://huggingface.co/docs/transformers/model_doc/musicgen#text-conditional-generation). This uses the decoder KV cache. The following subcomponents are exported:\n\t\t* text_encoder.onnx: corresponds to the text encoder part in https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/musicgen/modeling_musicgen.py#L1457.\n\t\t* encodec_decode.onnx: corresponds to the Encodec audio encoder part in https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/musicgen/modeling_musicgen.py#L2472-L2480.\n\t\t* decoder_model.onnx: The Musicgen decoder, without past key values input, and computing cross attention. Not required at inference (use decoder_model_merged.onnx instead).\n\t\t* decoder_with_past_model.onnx: The Musicgen decoder, with past_key_values input (KV cache filled), not computing cross attention. Not required at inference (use decoder_model_merged.onnx instead).\n\t\t* decoder_model_merged.onnx: The two previous models fused in one, to avoid duplicating weights. A boolean input `use_cache_branch` allows to select the branch to use. In the first forward pass where the KV cache is empty, dummy past key values inputs need to be passed and are ignored with use_cache_branch=False.\n\t\t* build_delay_pattern_mask.onnx: A model taking as input `input_ids`, `pad_token_id`, `max_length`, and building a delayed pattern mask to the input_ids. Implements https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/musicgen/modeling_musicgen.py#L1054.",
+        "text-conditional-with-past": """Exports Musicgen to ONNX to generate audio samples conditioned on a text prompt (Reference: https://huggingface.co/docs/transformers/model_doc/musicgen#text-conditional-generation).
+    This uses the decoder KV cache. The following subcomponents are exported:
+        * text_encoder.onnx: corresponds to the text encoder part in https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/musicgen/modeling_musicgen.py#L1457.
+        * encodec_decode.onnx: corresponds to the Encodec audio encoder part in https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/musicgen/modeling_musicgen.py#L2472-L2480.
+        * decoder_model.onnx: The Musicgen decoder, without past key values input, and computing cross attention. Not required at inference (use decoder_model_merged.onnx instead).
+        * decoder_with_past_model.onnx: The Musicgen decoder, with past_key_values input (KV cache filled), not computing cross attention. Not required at inference (use decoder_model_merged.onnx instead).
+        * decoder_model_merged.onnx: The two previous models fused in one, to avoid duplicating weights. A boolean input `use_cache_branch` allows to select the branch to use. In the first forward pass where the KV cache is empty, dummy past key values inputs need to be passed and are ignored with use_cache_branch=False.
+        * build_delay_pattern_mask.onnx: A model taking as input `input_ids`, `pad_token_id`, `max_length`, and building a delayed pattern mask to the input_ids. Implements https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/musicgen/modeling_musicgen.py#L1054.""",
     }
-    # TODO: support audio-prompted generation (- audio_encoder_encode.onnx: corresponds to the audio encoder part in https://github.com/huggingface/transformers/blob/f01e1609bf4dba146d1347c1368c8c49df8636f6/src/transformers/models/musicgen/modeling_musicgen.py#L2087.\n\t)
+    # TODO: support audio-prompted generation (audio_encoder_encode.onnx: corresponds to the audio encoder part
+    # in https://github.com/huggingface/transformers/blob/f01e1609bf4dba146d1347c1368c8c49df8636f6/src/transformers/models/musicgen/modeling_musicgen.py#L2087.)
     # With that, we have full Encodec support.
     DEFAULT_VARIANT = "text-conditional-with-past"

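Editor's note: the variant docstring above describes the subcomponents produced by the export. For context, a minimal sketch of driving that export from Python; the `main_export` entry point, its parameter names, the `text-to-audio` task string, and the `facebook/musicgen-small` checkpoint are assumptions to verify against the installed optimum version, not something this diff defines:

```python
# Hedged sketch: exporting Musicgen with the "text-conditional-with-past" variant.
# Function and parameter names are assumed from the optimum exporter API; check
# them against the optimum documentation before relying on this.
from optimum.exporters.onnx import main_export

main_export(
    model_name_or_path="facebook/musicgen-small",  # assumed example checkpoint
    output="musicgen_onnx",                        # output directory
    task="text-to-audio",                          # assumed task name for Musicgen
)
```

Under those assumptions, the output directory would contain the subcomponents listed in the variant docstring (text_encoder.onnx, encodec_decode.onnx, decoder_model_merged.onnx, build_delay_pattern_mask.onnx, and the unmerged decoder models).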