|
137 | 137 | "from ov_catvton_helper import download_models, convert_pipeline_models, convert_automasker_models\n",
|
138 | 138 | "\n",
|
139 | 139 | "pipeline, mask_processor, automasker = download_models()\n",
|
| 140 | + "vae_scaling_factor = pipeline.vae.config.scaling_factor\n", |
140 | 141 | "convert_pipeline_models(pipeline)\n",
|
141 | 142 | "convert_automasker_models(automasker)"
|
142 | 143 | ]
|
|
181 | 182 | },
|
182 | 183 | {
|
183 | 184 | "cell_type": "code",
|
184 |
| - "execution_count": 22, |
| 185 | + "execution_count": null, |
185 | 186 | "id": "8612d4be-e0cf-4249-881e-5270cc33ef28",
|
186 | 187 | "metadata": {},
|
187 | 188 | "outputs": [],
|
|
197 | 198 | " SCHP_PROCESSOR_LIP,\n",
|
198 | 199 | ")\n",
|
199 | 200 | "\n",
|
200 |
| - "pipeline = get_compiled_pipeline(pipeline, core, device, VAE_ENCODER_PATH, VAE_DECODER_PATH, UNET_PATH)\n", |
| 201 | + "pipeline = get_compiled_pipeline(pipeline, core, device, VAE_ENCODER_PATH, VAE_DECODER_PATH, UNET_PATH, vae_scaling_factor)\n", |
201 | 202 | "automasker = get_compiled_automasker(automasker, core, device, DENSEPOSE_PROCESSOR_PATH, SCHP_PROCESSOR_ATR, SCHP_PROCESSOR_LIP)"
|
202 | 203 | ]
|
203 | 204 | },
|
|
239 | 240 | },
|
240 | 241 | {
|
241 | 242 | "cell_type": "code",
|
242 |
| - "execution_count": 8, |
| 243 | + "execution_count": null, |
243 | 244 | "id": "1b307bdd",
|
244 | 245 | "metadata": {},
|
245 | 246 | "outputs": [],
|
246 | 247 | "source": [
|
247 |
| - "optimized_pipe = None\n", |
248 |
| - "optimized_automasker = None\n", |
| 248 | + "is_optimized_pipe_available = False\n", |
249 | 249 | "\n",
|
250 | 250 | "# Fetch skip_kernel_extension module\n",
|
251 | 251 | "r = requests.get(\n",
|
|
309 | 309 | },
|
310 | 310 | {
|
311 | 311 | "cell_type": "code",
|
312 |
| - "execution_count": 10, |
| 312 | + "execution_count": null, |
313 | 313 | "id": "f64b96e4",
|
314 | 314 | "metadata": {},
|
315 | 315 | "outputs": [],
|
316 | 316 | "source": [
|
317 | 317 | "%%skip not $to_quantize.value\n",
|
318 | 318 | "\n",
|
| 319 | + "import gc\n", |
319 | 320 | "import nncf\n",
|
320 | 321 | "from ov_catvton_helper import UNET_PATH\n",
|
321 | 322 | "\n",
|
| 323 | + "# cleanup before quantization to free memory\n", |
| 324 | + "del pipeline\n", |
| 325 | + "del automasker\n", |
| 326 | + "gc.collect()\n", |
| 327 | + "\n", |
| 328 | + "\n", |
322 | 329 | "if not UNET_INT8_PATH.exists():\n",
|
323 | 330 | " unet = core.read_model(UNET_PATH)\n",
|
324 | 331 | " quantized_model = nncf.quantize(\n",
|
|
327 | 334 | " subset_size=subset_size,\n",
|
328 | 335 | " model_type=nncf.ModelType.TRANSFORMER,\n",
|
329 | 336 | " )\n",
|
330 |
| - " ov.save_model(quantized_model, UNET_INT8_PATH)" |
| 337 | + " ov.save_model(quantized_model, UNET_INT8_PATH)\n", |
| 338 | + " del quantized_model\n", |
| 339 | + " gc.collect()" |
331 | 340 | ]
|
332 | 341 | },
|
333 | 342 | {
|
|
352 | 361 | "\n",
|
353 | 362 | "from catvton_quantization_helper import compress_models\n",
|
354 | 363 | "\n",
|
355 |
| - "compress_models(core)" |
356 |
| - ] |
357 |
| - }, |
358 |
| - { |
359 |
| - "cell_type": "code", |
360 |
| - "execution_count": 12, |
361 |
| - "id": "e9c41725", |
362 |
| - "metadata": {}, |
363 |
| - "outputs": [], |
364 |
| - "source": [ |
365 |
| - "%%skip not $to_quantize.value\n", |
366 |
| - "\n", |
367 |
| - "from catvton_quantization_helper import (\n", |
368 |
| - " VAE_ENCODER_INT4_PATH,\n", |
369 |
| - " VAE_DECODER_INT4_PATH,\n", |
370 |
| - " DENSEPOSE_PROCESSOR_INT4_PATH,\n", |
371 |
| - " SCHP_PROCESSOR_ATR_INT4,\n", |
372 |
| - " SCHP_PROCESSOR_LIP_INT4,\n", |
373 |
| - ")\n", |
| 364 | + "compress_models(core)\n", |
374 | 365 | "\n",
|
375 |
| - "optimized_pipe, _, optimized_automasker = download_models()\n", |
376 |
| - "optimized_pipe = get_compiled_pipeline(optimized_pipe, core, device, VAE_ENCODER_INT4_PATH, VAE_DECODER_INT4_PATH, UNET_INT8_PATH)\n", |
377 |
| - "optimized_automasker = get_compiled_automasker(optimized_automasker, core, device, DENSEPOSE_PROCESSOR_INT4_PATH, SCHP_PROCESSOR_ATR_INT4, SCHP_PROCESSOR_LIP_INT4)" |
| 366 | + "is_optimized_pipe_available = True" |
378 | 367 | ]
|
379 | 368 | },
|
380 | 369 | {
|
|
432 | 421 | "source": [
|
433 | 422 | "from ov_catvton_helper import get_pipeline_selection_option\n",
|
434 | 423 | "\n",
|
435 |
| - "use_quantized_models = get_pipeline_selection_option(optimized_pipe)\n", |
| 424 | + "use_quantized_models = get_pipeline_selection_option(is_optimized_pipe_available)\n", |
436 | 425 | "\n",
|
437 | 426 | "use_quantized_models"
|
438 | 427 | ]
|
|
448 | 437 | "source": [
|
449 | 438 | "from gradio_helper import make_demo\n",
|
450 | 439 | "\n",
|
451 |
| - "pipe = optimized_pipe if use_quantized_models.value else pipeline\n", |
452 |
| - "masker = optimized_automasker if use_quantized_models.value else automasker\n", |
| 440 | + "from catvton_quantization_helper import (\n", |
| 441 | + " VAE_ENCODER_INT4_PATH,\n", |
| 442 | + " VAE_DECODER_INT4_PATH,\n", |
| 443 | + " DENSEPOSE_PROCESSOR_INT4_PATH,\n", |
| 444 | + " SCHP_PROCESSOR_ATR_INT4,\n", |
| 445 | + " SCHP_PROCESSOR_LIP_INT4,\n", |
| 446 | + " UNET_INT8_PATH,\n", |
| 447 | + ")\n", |
| 448 | + "\n", |
| 449 | + "pipeline, mask_processor, automasker = download_models()\n", |
| 450 | + "if use_quantized_models.value:\n", |
| 451 | + " pipeline = get_compiled_pipeline(pipeline, core, device, VAE_ENCODER_INT4_PATH, VAE_DECODER_INT4_PATH, UNET_INT8_PATH, vae_scaling_factor)\n", |
| 452 | + " automasker = get_compiled_automasker(automasker, core, device, DENSEPOSE_PROCESSOR_INT4_PATH, SCHP_PROCESSOR_ATR_INT4, SCHP_PROCESSOR_LIP_INT4)\n", |
| 453 | + "else:\n", |
| 454 | + " pipeline = get_compiled_pipeline(pipeline, core, device, VAE_ENCODER_PATH, VAE_DECODER_PATH, UNET_PATH, vae_scaling_factor)\n", |
| 455 | + " automasker = get_compiled_automasker(automasker, core, device, DENSEPOSE_PROCESSOR_PATH, SCHP_PROCESSOR_ATR, SCHP_PROCESSOR_LIP)\n", |
453 | 456 | "\n",
|
454 | 457 | "output_dir = \"output\"\n",
|
455 |
| - "demo = make_demo(pipe, mask_processor, masker, output_dir)\n", |
| 458 | + "demo = make_demo(pipeline, mask_processor, automasker, output_dir)\n", |
456 | 459 | "try:\n",
|
457 | 460 | " demo.launch(debug=True)\n",
|
458 | 461 | "except Exception:\n",
|
|
0 commit comments