Commit 3dc041d

Cleanup models after conversion (#2558)
CVS-157654
1 parent e242c65 commit 3dc041d
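
This commit frees each PyTorch module as soon as its OpenVINO IR copy is saved, so the source and converted models never need to coexist in memory. A minimal sketch of the pattern applied throughout (the convert helper and paths here are illustrative, not the repo's exact code):

import gc
from pathlib import Path

import openvino as ov
import torch


def convert(model: torch.nn.Module, xml_path: Path, example_input) -> None:
    # Trace the module and save it as OpenVINO IR; skip the work if the IR
    # already exists (simplified stand-in for the notebook's helper).
    if not xml_path.exists():
        xml_path.parent.mkdir(parents=True, exist_ok=True)
        ov_model = ov.convert_model(model, example_input=example_input)
        ov.save_model(ov_model, xml_path)


encoder = torch.nn.Conv2d(3, 4, kernel_size=3)  # toy model for illustration
convert(encoder, Path("model/encoder.xml"), torch.zeros(1, 3, 64, 64))

# The commit's pattern: drop the PyTorch copy once the IR is on disk.
del encoder
gc.collect()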

3 files changed: +90 -66 lines changed

notebooks/catvton/catvton.ipynb
Lines changed: 36 additions & 33 deletions

@@ -137,6 +137,7 @@
     "from ov_catvton_helper import download_models, convert_pipeline_models, convert_automasker_models\n",
     "\n",
     "pipeline, mask_processor, automasker = download_models()\n",
+    "vae_scaling_factor = pipeline.vae.config.scaling_factor\n",
     "convert_pipeline_models(pipeline)\n",
     "convert_automasker_models(automasker)"
    ]
@@ -181,7 +182,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": null,
    "id": "8612d4be-e0cf-4249-881e-5270cc33ef28",
    "metadata": {},
    "outputs": [],
@@ -197,7 +198,7 @@
     "    SCHP_PROCESSOR_LIP,\n",
     ")\n",
     "\n",
-    "pipeline = get_compiled_pipeline(pipeline, core, device, VAE_ENCODER_PATH, VAE_DECODER_PATH, UNET_PATH)\n",
+    "pipeline = get_compiled_pipeline(pipeline, core, device, VAE_ENCODER_PATH, VAE_DECODER_PATH, UNET_PATH, vae_scaling_factor)\n",
     "automasker = get_compiled_automasker(automasker, core, device, DENSEPOSE_PROCESSOR_PATH, SCHP_PROCESSOR_ATR, SCHP_PROCESSOR_LIP)"
    ]
   },
@@ -239,13 +240,12 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": null,
    "id": "1b307bdd",
    "metadata": {},
    "outputs": [],
    "source": [
-    "optimized_pipe = None\n",
-    "optimized_automasker = None\n",
+    "is_optimized_pipe_available = False\n",
     "\n",
     "# Fetch skip_kernel_extension module\n",
     "r = requests.get(\n",
@@ -309,16 +309,23 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": null,
    "id": "f64b96e4",
    "metadata": {},
    "outputs": [],
    "source": [
     "%%skip not $to_quantize.value\n",
     "\n",
+    "import gc\n",
     "import nncf\n",
     "from ov_catvton_helper import UNET_PATH\n",
     "\n",
+    "# cleanup before quantization to free memory\n",
+    "del pipeline\n",
+    "del automasker\n",
+    "gc.collect()\n",
+    "\n",
+    "\n",
     "if not UNET_INT8_PATH.exists():\n",
     "    unet = core.read_model(UNET_PATH)\n",
     "    quantized_model = nncf.quantize(\n",
@@ -327,7 +334,9 @@
     "        subset_size=subset_size,\n",
     "        model_type=nncf.ModelType.TRANSFORMER,\n",
     "    )\n",
-    "    ov.save_model(quantized_model, UNET_INT8_PATH)"
+    "    ov.save_model(quantized_model, UNET_INT8_PATH)\n",
+    "    del quantized_model\n",
+    "    gc.collect()"
    ]
   },
   {
@@ -352,29 +361,9 @@
     "\n",
     "from catvton_quantization_helper import compress_models\n",
     "\n",
-    "compress_models(core)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 12,
-   "id": "e9c41725",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "%%skip not $to_quantize.value\n",
-    "\n",
-    "from catvton_quantization_helper import (\n",
-    "    VAE_ENCODER_INT4_PATH,\n",
-    "    VAE_DECODER_INT4_PATH,\n",
-    "    DENSEPOSE_PROCESSOR_INT4_PATH,\n",
-    "    SCHP_PROCESSOR_ATR_INT4,\n",
-    "    SCHP_PROCESSOR_LIP_INT4,\n",
-    ")\n",
+    "compress_models(core)\n",
     "\n",
-    "optimized_pipe, _, optimized_automasker = download_models()\n",
-    "optimized_pipe = get_compiled_pipeline(optimized_pipe, core, device, VAE_ENCODER_INT4_PATH, VAE_DECODER_INT4_PATH, UNET_INT8_PATH)\n",
-    "optimized_automasker = get_compiled_automasker(optimized_automasker, core, device, DENSEPOSE_PROCESSOR_INT4_PATH, SCHP_PROCESSOR_ATR_INT4, SCHP_PROCESSOR_LIP_INT4)"
+    "is_optimized_pipe_available = True"
    ]
   },
   {
@@ -432,7 +421,7 @@
    "source": [
     "from ov_catvton_helper import get_pipeline_selection_option\n",
     "\n",
-    "use_quantized_models = get_pipeline_selection_option(optimized_pipe)\n",
+    "use_quantized_models = get_pipeline_selection_option(is_optimized_pipe_available)\n",
     "\n",
     "use_quantized_models"
    ]
@@ -448,11 +437,25 @@
    "source": [
     "from gradio_helper import make_demo\n",
     "\n",
-    "pipe = optimized_pipe if use_quantized_models.value else pipeline\n",
-    "masker = optimized_automasker if use_quantized_models.value else automasker\n",
+    "from catvton_quantization_helper import (\n",
+    "    VAE_ENCODER_INT4_PATH,\n",
+    "    VAE_DECODER_INT4_PATH,\n",
+    "    DENSEPOSE_PROCESSOR_INT4_PATH,\n",
+    "    SCHP_PROCESSOR_ATR_INT4,\n",
+    "    SCHP_PROCESSOR_LIP_INT4,\n",
+    "    UNET_INT8_PATH,\n",
+    ")\n",
+    "\n",
+    "pipeline, mask_processor, automasker = download_models()\n",
+    "if use_quantized_models.value:\n",
+    "    pipeline = get_compiled_pipeline(pipeline, core, device, VAE_ENCODER_INT4_PATH, VAE_DECODER_INT4_PATH, UNET_INT8_PATH, vae_scaling_factor)\n",
+    "    automasker = get_compiled_automasker(automasker, core, device, DENSEPOSE_PROCESSOR_INT4_PATH, SCHP_PROCESSOR_ATR_INT4, SCHP_PROCESSOR_LIP_INT4)\n",
+    "else:\n",
+    "    pipeline = get_compiled_pipeline(pipeline, core, device, VAE_ENCODER_PATH, VAE_DECODER_PATH, UNET_PATH, vae_scaling_factor)\n",
+    "    automasker = get_compiled_automasker(automasker, core, device, DENSEPOSE_PROCESSOR_PATH, SCHP_PROCESSOR_ATR, SCHP_PROCESSOR_LIP)\n",
     "\n",
     "output_dir = \"output\"\n",
-    "demo = make_demo(pipe, mask_processor, masker, output_dir)\n",
+    "demo = make_demo(pipeline, mask_processor, automasker, output_dir)\n",
     "try:\n",
     "    demo.launch(debug=True)\n",
     "except Exception:\n",

notebooks/catvton/catvton_quantization_helper.py
Lines changed: 42 additions & 29 deletions

@@ -1,12 +1,14 @@
 from typing import Any, List
-import torch
-import nncf
+from pathlib import Path
+import pickle
 
 from tqdm.notebook import tqdm
 from transformers import set_seed
 import numpy as np
 import openvino as ov
 from PIL import Image
+import torch
+import nncf
 
 from ov_catvton_helper import (
     MODEL_DIR,
@@ -49,34 +51,45 @@ def __call__(self, *args, **kwargs):
 
 
 def collect_calibration_data(pipeline, automasker, mask_processor, dataset, subset_size):
-    original_unet = pipeline.unet.unet
-    pipeline.unet.unet = CompiledModelDecorator(original_unet)
-
-    calibration_dataset = []
-    pbar = tqdm(total=subset_size, desc="Collecting calibration dataset")
-    for data in dataset:
-        person_image_path, cloth_image_path = data
-        person_image = Image.open(person_image_path)
-        cloth_image = Image.open(cloth_image_path)
-        cloth_type = "upper" if "upper" in person_image_path.as_posix() else "overall"
-        mask = automasker(person_image, cloth_type)["mask"]
-        mask = mask_processor.blur(mask, blur_factor=9)
-
-        pipeline(
-            image=person_image,
-            condition_image=cloth_image,
-            mask=mask,
-            num_inference_steps=NUM_INFERENCE_STEPS,
-            guidance_scale=GUIDANCE_SCALE,
-            generator=GENERATOR,
-        )
-        collected_subset_size = len(pipeline.unet.unet.data_cache)
-        pbar.update(NUM_INFERENCE_STEPS)
-        if collected_subset_size >= subset_size:
-            break
+    calibration_dataset_filepath = Path("calibration_data") / f"{subset_size}.pkl"
+    calibration_dataset_filepath.parent.mkdir(exist_ok=True, parents=True)
+
+    if not calibration_dataset_filepath.exists():
+        original_unet = pipeline.unet.unet
+        pipeline.unet.unet = CompiledModelDecorator(original_unet)
+
+        calibration_dataset = []
+        pbar = tqdm(total=subset_size, desc="Collecting calibration dataset")
+        for data in dataset:
+            person_image_path, cloth_image_path = data
+            person_image = Image.open(person_image_path)
+            cloth_image = Image.open(cloth_image_path)
+            cloth_type = "upper" if "upper" in person_image_path.as_posix() else "overall"
+            mask = automasker(person_image, cloth_type)["mask"]
+            mask = mask_processor.blur(mask, blur_factor=9)
+
+            pipeline(
+                image=person_image,
+                condition_image=cloth_image,
+                mask=mask,
+                num_inference_steps=NUM_INFERENCE_STEPS,
+                guidance_scale=GUIDANCE_SCALE,
+                generator=GENERATOR,
+            )
+            collected_subset_size = len(pipeline.unet.unet.data_cache)
+            pbar.update(NUM_INFERENCE_STEPS)
+            if collected_subset_size >= subset_size:
+                break
+
+        calibration_dataset = pipeline.unet.unet.data_cache
+        pipeline.unet.unet = original_unet
+
+        with open(calibration_dataset_filepath, "wb") as f:
+            pickle.dump(calibration_dataset, f)
+    else:
+        with open(calibration_dataset_filepath, "rb") as f:
+            calibration_dataset = pickle.load(f)
 
-    calibration_dataset = pipeline.unet.unet.data_cache
-    pipeline.unet.unet = original_unet
 
     return calibration_dataset
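
The helper now wraps calibration collection in a pickle-backed cache keyed by subset size, so a rerun skips the expensive diffusion passes. The same cache-or-compute shape in isolation (cached and compute_fn are illustrative names, not from the repo):

import pickle
from pathlib import Path
from typing import Any, Callable


def cached(filepath: Path, compute_fn: Callable[[], Any]) -> Any:
    # Reuse a previously pickled result when present; otherwise compute it,
    # persist it, and return it (mirrors collect_calibration_data's new shape).
    if filepath.exists():
        with open(filepath, "rb") as f:
            return pickle.load(f)
    result = compute_fn()
    filepath.parent.mkdir(exist_ok=True, parents=True)
    with open(filepath, "wb") as f:
        pickle.dump(result, f)
    return result


data = cached(Path("calibration_data/300.pkl"), lambda: list(range(300)))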

notebooks/catvton/ov_catvton_helper.py
Lines changed: 12 additions & 4 deletions

@@ -1,3 +1,4 @@
+import gc
 import os
 from collections import namedtuple
 from pathlib import Path
@@ -93,13 +94,16 @@ def download_models():
 def convert_pipeline_models(pipeline):
     convert(VaeEncoder(pipeline.vae), VAE_ENCODER_PATH, torch.zeros(1, 3, 1024, 768))
     convert(VaeDecoder(pipeline.vae), VAE_DECODER_PATH, torch.zeros(1, 4, 128, 96))
+    del pipeline.vae
 
     inpainting_latent_model_input = torch.zeros(2, 9, 256, 96)
     timestep = torch.tensor(0)
     encoder_hidden_states = torch.zeros(2, 1, 768)
     example_input = (inpainting_latent_model_input, timestep, encoder_hidden_states)
 
     convert(UNetWrapper(pipeline.unet), UNET_PATH, example_input)
+    del pipeline.unet
+    gc.collect()
 
 
 def convert_automasker_models(automasker):
@@ -115,19 +119,23 @@ def inference(model, inputs):
     traceable_model = TracingAdapter(automasker.densepose_processor.predictor.model, tracing_input, inference)
 
     convert(traceable_model, DENSEPOSE_PROCESSOR_PATH, tracing_input[0]["image"])
+    del automasker.densepose_processor.predictor.model
 
     convert(automasker.schp_processor_atr.model, SCHP_PROCESSOR_ATR, torch.rand([1, 3, 512, 512], dtype=torch.float32))
     convert(automasker.schp_processor_lip.model, SCHP_PROCESSOR_LIP, torch.rand([1, 3, 473, 473], dtype=torch.float32))
+    del automasker.schp_processor_atr.model
+    del automasker.schp_processor_lip.model
+    gc.collect()
 
 
 class VAEWrapper(torch.nn.Module):
-    def __init__(self, vae_encoder, vae_decoder, config):
+    def __init__(self, vae_encoder, vae_decoder, scaling_factor):
         super().__init__()
         self.vae_enocder = vae_encoder
         self.vae_decoder = vae_decoder
         self.device = "cpu"
         self.dtype = torch.float32
-        self.config = config
+        self.config = namedtuple("VAEConfig", ["scaling_factor"])(scaling_factor)
 
     def encode(self, pixel_values):
         ov_outputs = self.vae_enocder(pixel_values).to_dict()
@@ -202,12 +210,12 @@ def forward(self, image):
         return torch.from_numpy(outputs[0])
 
 
-def get_compiled_pipeline(pipeline, core, device, vae_encoder_path, vae_decoder_path, unet_path):
+def get_compiled_pipeline(pipeline, core, device, vae_encoder_path, vae_decoder_path, unet_path, vae_scaling_factor):
     compiled_unet = core.compile_model(unet_path, device.value)
     compiled_vae_encoder = core.compile_model(vae_encoder_path, device.value)
     compiled_vae_decoder = core.compile_model(vae_decoder_path, device.value)
 
-    pipeline.vae = VAEWrapper(compiled_vae_encoder, compiled_vae_decoder, pipeline.vae.config)
+    pipeline.vae = VAEWrapper(compiled_vae_encoder, compiled_vae_decoder, vae_scaling_factor)
     pipeline.unet = ConvUnetWrapper(compiled_unet)
 
     return pipeline
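
The VAEWrapper change works because, once the models are compiled, the pipeline appears to read only scaling_factor from the VAE config; a one-field namedtuple can stand in for the full Diffusers config object, which is what lets convert_pipeline_models delete pipeline.vae early. The shim in isolation (0.18215 is the common Stable Diffusion default, shown purely for illustration):

from collections import namedtuple

# Minimal stand-in for the Diffusers VAE config: just attribute access.
VAEConfig = namedtuple("VAEConfig", ["scaling_factor"])

config = VAEConfig(0.18215)  # illustrative value, not read from a real model
print(config.scaling_factor)  # the only field the pipeline still needs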
