Commit ca44e3e

reduce VRAM usage, instead of increasing main RAM usage
1 parent 56b4ea9 commit ca44e3e

File tree

2 files changed: +14 −4 lines changed

README.md (+7 lines)
```diff
@@ -137,6 +137,13 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
 
 ## Change History
 
+### Oct 27, 2024 / 2024-10-27:
+
+- `svd_merge_lora.py` VRAM usage has been reduced. However, main memory usage will increase (32GB is sufficient).
+  - This will be included in the next release.
+- `svd_merge_lora.py` のVRAM使用量を削減しました。ただし、メインメモリの使用量は増加します(32GBあれば十分です)。
+  - これは次回リリースに含まれます。
+
 ### Oct 26, 2024 / 2024-10-26:
 
 - Fixed a bug in `svd_merge_lora.py`, `sdxl_merge_lora.py`, and `resize_lora.py` where the hash value of LoRA metadata was not correctly calculated when the `save_precision` was different from the `precision` used in the calculation. See issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) for details. Thanks to JujoHotaru for raising the issue.
```
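The Oct 26 entry above describes hashes diverging when weights are saved at a different precision than the one used for calculation. A minimal sketch of why that happens, assuming a hypothetical `tensor_hash` helper (this is not the actual sd-scripts hashing code): round-tripping values through a lower precision changes the raw bytes, so any byte-level digest changes too.

```python
import hashlib

import torch


def tensor_hash(t: torch.Tensor) -> str:
    # Illustrative stand-in for a metadata hash: digest of the raw tensor bytes.
    return hashlib.sha256(t.numpy().tobytes()).hexdigest()


w = torch.full((4, 4), 1.1, dtype=torch.float32)  # 1.1 is not exactly representable in fp16
h_full = tensor_hash(w)
# Simulate save_precision != precision: round-trip through fp16 before hashing.
h_saved = tensor_hash(w.to(torch.float16).to(torch.float32))
print(h_full == h_saved)  # False: the rounded values hash differently
```

Hashing the tensors at a single, consistent precision (as the fix does) avoids the mismatch.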

networks/svd_merge_lora.py (+7 −4 lines)
```diff
@@ -301,10 +301,10 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer
             # make original weight if not exist
             if lora_module_name not in merged_sd:
                 weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
-                if device:
-                    weight = weight.to(device)
             else:
                 weight = merged_sd[lora_module_name]
+                if device:
+                    weight = weight.to(device)
 
             # merge to weight
             if device:
@@ -336,13 +336,16 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer
                 conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
                 weight = weight + ratio * conved * scale
 
-            merged_sd[lora_module_name] = weight
+            merged_sd[lora_module_name] = weight.to("cpu")
 
     # extract from merged weights
     logger.info("extract new lora...")
     merged_lora_sd = {}
     with torch.no_grad():
         for lora_module_name, mat in tqdm(list(merged_sd.items())):
+            if device:
+                mat = mat.to(device)
+
             conv2d = len(mat.size()) == 4
             kernel_size = None if not conv2d else mat.size()[2:4]
             conv2d_3x3 = conv2d and kernel_size != (1, 1)
@@ -381,7 +384,7 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer
 
             merged_lora_sd[lora_module_name + ".lora_up.weight"] = up_weight.to("cpu").contiguous()
             merged_lora_sd[lora_module_name + ".lora_down.weight"] = down_weight.to("cpu").contiguous()
-            merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank)
+            merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank, device="cpu")
 
     # build minimum metadata
     dims = f"{new_rank}"
```
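The diff above trades VRAM for main RAM with a simple offload pattern: do the merge math on the accelerator, park each finished tensor on CPU, and move a tensor back to the device only while it is being processed. A minimal sketch of that pattern, with hypothetical names (`merge_and_extract` is not the sd-scripts function; `device=None` defaults to CPU so the sketch runs anywhere):

```python
import torch


def merge_and_extract(state_dicts, device=None):
    """Accumulate per-module weights, keeping only the active tensor on `device`."""
    merged = {}
    for sd in state_dicts:
        for name, delta in sd.items():
            if name not in merged:
                weight = torch.zeros_like(delta)
            else:
                weight = merged[name]
                if device:
                    weight = weight.to(device)  # bring the accumulator back to the accelerator
            if device:
                delta = delta.to(device)
            weight = weight + delta             # stand-in for the actual merge math
            merged[name] = weight.to("cpu")     # offload: VRAM holds one tensor at a time

    out = {}
    for name, mat in merged.items():
        if device:
            mat = mat.to(device)                # second pass: again one tensor at a time
        out[name] = mat.to("cpu").contiguous()  # results always end up on CPU
    return out


result = merge_and_extract([{"w": torch.ones(2, 2)}, {"w": torch.ones(2, 2)}])
print(result["w"].device, result["w"].sum().item())  # cpu 8.0
```

Peak VRAM is bounded by the largest single module's working set instead of the whole merged state dict, which is why main memory usage grows instead (hence the "32GB is sufficient" note in the README).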
