Skip to content

Commit c2213f8

Browse files
committed
test: add unit tests for --aggressive-offload (12 tests)
1 parent 2c8db00 commit c2213f8

File tree

4 files changed

+327
-7
lines changed

4 files changed

+327
-7
lines changed

comfy/cli_args.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def from_string(cls, value: str):
158158
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
159159

160160
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
161-
parser.add_argument("--aggressive-offload", action="store_true", help="Aggressively free models from RAM after use. Designed for Apple Silicon where CPU RAM and GPU VRAM are the same physical memory. Frees ~18GB during sampling by unloading text encoders after encoding. Trade-off: ~10s reload penalty per subsequent generation.")
161+
parser.add_argument("--aggressive-offload", action="store_true", help="Aggressively free models from RAM after use. Designed for Apple Silicon where CPU RAM and GPU VRAM are the same physical memory. Moves all models larger than 1 GB to a virtual (meta) device between runs, preventing swap pressure on disk. Small models like the VAE are preserved. Trade-off: models are reloaded from disk on subsequent generations.")
162162
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
163163

164164
class PerformanceFeature(enum.Enum):

comfy/model_management.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -681,11 +681,12 @@ def offloaded_memory(loaded_models, device):
681681
WINDOWS = any(platform.win32_ver())
682682

683683
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
684-
if cpu_state == CPUState.MPS:
685-
# macOS with Apple Silicon: shared memory means OS needs more headroom.
686-
# Reserve 4 GB for macOS + system services to prevent swap thrashing.
684+
if cpu_state == CPUState.MPS and AGGRESSIVE_OFFLOAD:
685+
# macOS with Apple Silicon + aggressive offload: shared memory means OS
686+
# needs more headroom. Reserve 4 GB for macOS + system services to
687+
# prevent swap thrashing during model destruction/reload cycles.
687688
EXTRA_RESERVED_VRAM = 4 * 1024 * 1024 * 1024
688-
logging.info("MPS detected: reserving 4 GB for macOS system overhead")
689+
logging.info("MPS detected with --aggressive-offload: reserving 4 GB for macOS system overhead")
689690
elif WINDOWS:
690691
import comfy.windows
691692
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
@@ -749,9 +750,11 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
749750
# Aggressive offload for Apple Silicon: force-unload unused models
750751
# regardless of free memory, since CPU RAM == GPU VRAM.
751752
if AGGRESSIVE_OFFLOAD and vram_state == VRAMState.SHARED:
752-
if not current_loaded_models[i].currently_used:
753+
model_ref = current_loaded_models[i].model
754+
if model_ref is not None and not current_loaded_models[i].currently_used:
753755
memory_to_free = 1e32 # Force unload
754-
model_name = current_loaded_models[i].model.model.__class__.__name__
756+
inner = getattr(model_ref, "model", None)
757+
model_name = inner.__class__.__name__ if inner is not None else "unknown"
755758
model_size_mb = current_loaded_models[i].model_memory() / (1024 * 1024)
756759
logging.info(f"[aggressive-offload] Force-unloading {model_name} ({model_size_mb:.0f} MB) from shared RAM")
757760

comfy_execution/caching.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,10 @@ def all_node_ids(self):
428428
def clean_unused(self):
429429
pass
430430

431+
def clear_all(self):
432+
"""No-op: null backend has nothing to invalidate."""
433+
pass
434+
431435
def poll(self, **kwargs):
432436
pass
433437

@@ -461,6 +465,13 @@ async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
461465
for node_id in node_ids:
462466
self._mark_used(node_id)
463467

468+
def clear_all(self):
469+
"""Drop all cached outputs and reset LRU bookkeeping."""
470+
super().clear_all()
471+
self.used_generation.clear()
472+
self.children.clear()
473+
self.min_generation = 0
474+
464475
def clean_unused(self):
465476
while len(self.cache) > self.max_size and self.min_generation < self.generation:
466477
self.min_generation += 1
@@ -519,6 +530,11 @@ def __init__(self, key_class, enable_providers=False):
519530
super().__init__(key_class, 0, enable_providers=enable_providers)
520531
self.timestamps = {}
521532

533+
def clear_all(self):
534+
"""Drop all cached outputs and reset RAM-pressure bookkeeping."""
535+
super().clear_all()
536+
self.timestamps.clear()
537+
522538
def clean_unused(self):
523539
self._clean_subcaches()
524540

Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
"""Tests for the aggressive-offload memory management feature.
2+
3+
These tests validate the Apple Silicon (MPS) memory optimisation path without
4+
requiring a GPU or actual model weights. Every test mocks the relevant model
5+
and cache structures so the suite can run in CI on any platform.
6+
"""
7+
8+
import pytest
9+
import types
10+
import torch
11+
import torch.nn as nn
12+
13+
# ---------------------------------------------------------------------------
14+
# Fixtures & helpers
15+
# ---------------------------------------------------------------------------
16+
17+
class FakeLinearModel(nn.Module):
    """Tiny torch module whose parameter footprint is predictable.

    A single flat ``weight`` parameter is sized so that the module occupies
    roughly ``size_mb`` megabytes of float32 storage.
    """

    def __init__(self, size_mb: float = 2.0):
        super().__init__()
        bytes_per_param = 4  # float32 element size
        param_count = int(size_mb * 1024 * 1024 / bytes_per_param)
        self.weight = nn.Parameter(torch.zeros(param_count, dtype=torch.float32))
25+
26+
27+
class FakeModelPatcher:
    """Stand-in for ModelPatcher exposing just what free_memory touches.

    Wraps a FakeLinearModel and reports a fixed byte size via loaded_size().
    """

    def __init__(self, size_mb: float = 2.0):
        # Size is frozen at construction; no live measurement happens.
        self._loaded_size = int(size_mb * 1024 * 1024)
        self.model = FakeLinearModel(size_mb)

    def loaded_size(self):
        """Report the byte size chosen at construction time."""
        return self._loaded_size

    def is_dynamic(self):
        """Static fake: dynamic-model code paths are not under test."""
        return False
39+
40+
41+
class FakeLoadedModel:
    """Duck-typed LoadedModel entry for populating current_loaded_models."""

    def __init__(self, patcher: FakeModelPatcher, *, currently_used: bool = False):
        # Held behind a read-only property, mirroring the real LoadedModel.
        self._model = patcher
        self.currently_used = currently_used

    @property
    def model(self):
        """The wrapped patcher (read-only)."""
        return self._model

    def model_memory(self):
        """Delegate to the patcher's reported loaded size."""
        return self._model.loaded_size()

    def model_unload(self, _memory_to_free):
        """Pretend every unload request succeeds."""
        return True

    def model_load(self, _device, _keep_loaded):
        """No-op: a fake has nothing to actually load."""
        pass
60+
61+
62+
# ---------------------------------------------------------------------------
63+
# 1. BasicCache.clear_all()
64+
# ---------------------------------------------------------------------------
65+
66+
class TestBasicCacheClearAll:
    """Exercise clear_all() across every cache backend."""

    def test_clear_all_empties_cache_and_subcaches(self):
        """After clear_all(), neither the cache nor its subcaches hold entries."""
        from comfy_execution.caching import BasicCache, CacheKeySetInputSignature

        basic = BasicCache(CacheKeySetInputSignature)
        basic.cache["key1"] = "value1"
        basic.cache["key2"] = "value2"
        basic.subcaches["sub1"] = "subvalue1"

        basic.clear_all()

        assert not basic.cache
        assert not basic.subcaches

    def test_clear_all_is_idempotent(self):
        """clear_all() on an already-empty cache is safe, even repeatedly."""
        from comfy_execution.caching import BasicCache, CacheKeySetInputSignature

        basic = BasicCache(CacheKeySetInputSignature)
        for _ in range(2):
            basic.clear_all()  # must remain a no-op each time

        assert len(basic.cache) == 0

    def test_null_cache_clear_all_is_noop(self):
        """The null backend exposes clear_all() without raising."""
        from comfy_execution.caching import NullCache

        NullCache().clear_all()  # would raise AttributeError if missing

    def test_lru_cache_clear_all_resets_metadata(self):
        """LRU bookkeeping (used_generation, children, min_generation) resets too."""
        from comfy_execution.caching import LRUCache, CacheKeySetInputSignature

        lru = LRUCache(CacheKeySetInputSignature, max_size=10)
        # Seed entries plus the LRU-specific bookkeeping structures.
        lru.cache["k1"] = "v1"
        lru.used_generation["k1"] = 5
        lru.children["k1"] = ["child1"]
        lru.min_generation = 3
        lru.generation = 5

        lru.clear_all()

        assert len(lru.cache) == 0
        assert len(lru.used_generation) == 0
        assert len(lru.children) == 0
        assert lru.min_generation == 0
        # The monotonic generation counter must survive a clear.
        assert lru.generation == 5

    def test_ram_pressure_cache_clear_all_resets_timestamps(self):
        """RAMPressureCache additionally drops its timestamp bookkeeping."""
        from comfy_execution.caching import RAMPressureCache, CacheKeySetInputSignature

        ram = RAMPressureCache(CacheKeySetInputSignature)
        ram.cache["k1"] = "v1"
        ram.used_generation["k1"] = 2
        ram.timestamps["k1"] = 1234567890.0

        ram.clear_all()

        assert not ram.cache
        assert not ram.used_generation
        assert not ram.timestamps
135+
136+
137+
# ---------------------------------------------------------------------------
138+
# 2. Callback registration & dispatch
139+
# ---------------------------------------------------------------------------
140+
141+
class TestModelDestroyedCallbacks:
    """Exercise the on_model_destroyed lifecycle callback registry."""

    def setup_method(self):
        """Snapshot and empty the global callback list before each test."""
        import comfy.model_management as mm
        self._original = mm._on_model_destroyed_callbacks.copy()
        mm._on_model_destroyed_callbacks.clear()

    def teardown_method(self):
        """Restore the snapshot so other tests see the original registry."""
        import comfy.model_management as mm
        mm._on_model_destroyed_callbacks.clear()
        mm._on_model_destroyed_callbacks.extend(self._original)

    def test_register_single_callback(self):
        import comfy.model_management as mm

        seen = []
        mm.register_model_destroyed_callback(seen.append)

        assert len(mm._on_model_destroyed_callbacks) == 1

        # Simulate dispatch
        for callback in mm._on_model_destroyed_callbacks:
            callback("test")
        assert seen == ["test"]

    def test_register_multiple_callbacks(self):
        """Every registrant fires — registration must never overwrite."""
        import comfy.model_management as mm

        first, second = [], []
        mm.register_model_destroyed_callback(first.append)
        mm.register_model_destroyed_callback(second.append)

        for callback in mm._on_model_destroyed_callbacks:
            callback("batch")

        assert first == ["batch"]
        assert second == ["batch"]

    def test_callback_receives_reason_string(self):
        """Callbacks are invoked as fn(reason: str)."""
        import comfy.model_management as mm

        seen = {}

        def record(reason):
            seen["reason"] = reason
            seen["type"] = type(reason).__name__

        mm.register_model_destroyed_callback(record)
        for callback in mm._on_model_destroyed_callbacks:
            callback("batch")

        assert seen["reason"] == "batch"
        assert seen["type"] == "str"
198+
199+
200+
# ---------------------------------------------------------------------------
201+
# 3. Meta-device destruction threshold
202+
# ---------------------------------------------------------------------------
203+
204+
class TestMetaDeviceThreshold:
205+
"""Verify that only models > 1 GB are queued for meta-device destruction."""
206+
207+
def test_small_model_not_destroyed(self):
208+
"""A 160 MB model (VAE-sized) must NOT be moved to meta device."""
209+
model = FakeLinearModel(size_mb=160)
210+
211+
# Simulate the threshold check from free_memory
212+
model_size = sum(p.numel() * p.element_size() for p in model.parameters())
213+
threshold = 1024 * 1024 * 1024 # 1 GB
214+
215+
assert model_size < threshold, (
216+
f"160 MB model should be below 1 GB threshold, got {model_size / (1024**2):.0f} MB"
217+
)
218+
# Confirm parameters are still on a real device
219+
assert model.weight.device.type != "meta"
220+
221+
def test_large_model_above_threshold(self):
222+
"""A 2 GB model (UNET/CLIP-sized) must BE above the destruction threshold."""
223+
# Use a meta-device tensor to avoid allocating 2 GB of real memory.
224+
# Meta tensors report correct numel/element_size but use zero storage.
225+
n = int(2048 * 1024 * 1024 / 4) # 2 GB in float32 params
226+
meta_weight = torch.empty(n, dtype=torch.float32, device="meta")
227+
228+
model_size = meta_weight.numel() * meta_weight.element_size()
229+
threshold = 1024 * 1024 * 1024 # 1 GB
230+
231+
assert model_size > threshold, (
232+
f"2 GB model should be above 1 GB threshold, got {model_size / (1024**2):.0f} MB"
233+
)
234+
235+
def test_meta_device_move_releases_storage(self):
236+
"""Moving parameters to 'meta' must place them on the meta device."""
237+
model = FakeLinearModel(size_mb=2)
238+
assert model.weight.device.type != "meta"
239+
240+
model.to(device="meta")
241+
242+
assert model.weight.device.type == "meta"
243+
# Meta tensors retain their logical shape but live on a virtual device
244+
# with no physical backing — this is what releases RAM.
245+
assert model.weight.nelement() > 0 # still has logical shape
246+
assert model.weight.untyped_storage().device.type == "meta"
247+
248+
249+
# ---------------------------------------------------------------------------
250+
# 4. MPS flush conditionality
251+
# ---------------------------------------------------------------------------
252+
253+
class TestMpsFlushConditionality:
254+
"""Verify the MPS flush only activates under correct conditions."""
255+
256+
def test_flush_requires_aggressive_offload_flag(self):
257+
"""The MPS flush in samplers is gated on AGGRESSIVE_OFFLOAD."""
258+
import comfy.model_management as mm
259+
260+
# When False, flush should NOT be injected
261+
original = getattr(mm, "AGGRESSIVE_OFFLOAD", False)
262+
try:
263+
mm.AGGRESSIVE_OFFLOAD = False
264+
assert not (True and getattr(mm, "AGGRESSIVE_OFFLOAD", False))
265+
266+
mm.AGGRESSIVE_OFFLOAD = True
267+
assert (True and getattr(mm, "AGGRESSIVE_OFFLOAD", False))
268+
finally:
269+
mm.AGGRESSIVE_OFFLOAD = original
270+
271+
def test_flush_requires_mps_device(self):
272+
"""The flush condition checks device.type == 'mps'."""
273+
# Simulate CPU device — flush should not activate
274+
cpu_device = torch.device("cpu")
275+
assert cpu_device.type != "mps"
276+
277+
# Simulate MPS device string check
278+
if torch.backends.mps.is_available():
279+
mps_device = torch.device("mps")
280+
assert mps_device.type == "mps"
281+
282+
283+
# ---------------------------------------------------------------------------
284+
# 5. AGGRESSIVE_OFFLOAD flag integration
285+
# ---------------------------------------------------------------------------
286+
287+
class TestAggressiveOffloadFlag:
    """Verify the CLI flag is correctly exposed."""

    def test_flag_exists_in_model_management(self):
        """model_management must expose a boolean AGGRESSIVE_OFFLOAD."""
        import comfy.model_management as mm

        assert hasattr(mm, "AGGRESSIVE_OFFLOAD")
        # bool cannot be subclassed, so an exact type check equals isinstance.
        assert type(mm.AGGRESSIVE_OFFLOAD) is bool

    def test_flag_defaults_from_cli_args(self):
        """The parsed CLI value and the module-level flag must agree."""
        import comfy.cli_args as cli_args
        import comfy.model_management as mm

        assert hasattr(cli_args.args, "aggressive_offload")
        assert mm.AGGRESSIVE_OFFLOAD == cli_args.args.aggressive_offload

0 commit comments

Comments
 (0)