Skip to content

Commit d9a4089

Browse files
committed
test: add unit tests for --aggressive-offload (12 tests)
1 parent 2c8db00 commit d9a4089

File tree

2 files changed

+265
-0
lines changed

2 files changed

+265
-0
lines changed

comfy_execution/caching.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,10 @@ def all_node_ids(self):
428428
def clean_unused(self):
    """Drop stale entries; the null backend stores nothing, so no-op."""
    return None
430430

431+
def clear_all(self):
    """No-op: null backend has nothing to invalidate."""
    return None
434+
431435
def poll(self, **kwargs):
    """Ignore poll requests; the null backend tracks no state to refresh."""
    return None
433437

Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
"""Tests for the aggressive-offload memory management feature.
2+
3+
These tests validate the Apple Silicon (MPS) memory optimisation path without
4+
requiring a GPU or actual model weights. Every test mocks the relevant model
5+
and cache structures so the suite can run in CI on any platform.
6+
"""
7+
8+
import pytest
9+
import types
10+
import torch
11+
import torch.nn as nn
12+
13+
# ---------------------------------------------------------------------------
14+
# Fixtures & helpers
15+
# ---------------------------------------------------------------------------
16+
17+
class FakeLinearModel(nn.Module):
    """Tiny nn.Module whose single parameter occupies a predictable amount of memory."""

    def __init__(self, size_mb: float = 2.0):
        super().__init__()
        # A float32 parameter is 4 bytes, so size_mb MiB requires
        # size_mb * 1024**2 / 4 scalar parameters.
        param_count = int(size_mb * 1024 * 1024 / 4)
        self.weight = nn.Parameter(torch.zeros(param_count, dtype=torch.float32))
25+
26+
27+
class FakeModelPatcher:
    """Stand-in exposing the slice of the ModelPatcher API that
    model_management.free_memory interacts with."""

    def __init__(self, size_mb: float = 2.0):
        self.model = FakeLinearModel(size_mb)
        self._loaded_size = int(size_mb * 1024 * 1024)

    def loaded_size(self):
        """Return the byte count fixed at construction time."""
        return self._loaded_size

    def is_dynamic(self):
        """The fake patcher never resizes its model."""
        return False
39+
40+
41+
class FakeLoadedModel:
    """Stand-in for the LoadedModel entries kept in current_loaded_models."""

    def __init__(self, patcher: FakeModelPatcher, *, currently_used: bool = False):
        self._model = patcher
        self.currently_used = currently_used

    @property
    def model(self):
        """Expose the wrapped patcher, mirroring LoadedModel.model."""
        return self._model

    def model_memory(self):
        """Delegate to the patcher's reported loaded size."""
        return self._model.loaded_size()

    def model_unload(self, _memory_to_free):
        """Pretend unloading always succeeds."""
        return True

    def model_load(self, _device, _keep_loaded):
        """Loading is a no-op for the fake."""
        return None
60+
61+
62+
# ---------------------------------------------------------------------------
63+
# 1. BasicCache.clear_all()
64+
# ---------------------------------------------------------------------------
65+
66+
class TestBasicCacheClearAll:
    """Verify that BasicCache.clear_all() is a proper public API."""

    def test_clear_all_empties_cache_and_subcaches(self):
        """clear_all() must remove every entry in both dicts."""
        from comfy_execution.caching import BasicCache, CacheKeySetInputSignature

        basic = BasicCache(CacheKeySetInputSignature)
        basic.cache["key1"] = "value1"
        basic.cache["key2"] = "value2"
        basic.subcaches["sub1"] = "subvalue1"

        basic.clear_all()

        assert len(basic.cache) == 0
        assert len(basic.subcaches) == 0

    def test_clear_all_is_idempotent(self):
        """Calling clear_all() on an already-empty cache must not raise."""
        from comfy_execution.caching import BasicCache, CacheKeySetInputSignature

        basic = BasicCache(CacheKeySetInputSignature)
        for _ in range(2):  # repeated calls must remain a no-op
            basic.clear_all()

        assert len(basic.cache) == 0

    def test_null_cache_clear_all_is_noop(self):
        """NullCache.clear_all() must not raise — it's the null backend."""
        from comfy_execution.caching import NullCache

        NullCache().clear_all()  # must not raise AttributeError
99+
100+
101+
# ---------------------------------------------------------------------------
102+
# 2. Callback registration & dispatch
103+
# ---------------------------------------------------------------------------
104+
105+
class TestModelDestroyedCallbacks:
    """Validate the on_model_destroyed lifecycle callback system."""

    def setup_method(self):
        """Snapshot and empty the global callback registry before each test."""
        import comfy.model_management as mm
        self._saved_callbacks = mm._on_model_destroyed_callbacks.copy()
        mm._on_model_destroyed_callbacks.clear()

    def teardown_method(self):
        """Restore the snapshot so other tests see the original registry."""
        import comfy.model_management as mm
        mm._on_model_destroyed_callbacks.clear()
        mm._on_model_destroyed_callbacks.extend(self._saved_callbacks)

    def test_register_single_callback(self):
        import comfy.model_management as mm

        seen = []
        mm.register_model_destroyed_callback(seen.append)

        assert len(mm._on_model_destroyed_callbacks) == 1

        # Simulate dispatch.
        for callback in mm._on_model_destroyed_callbacks:
            callback("test")
        assert seen == ["test"]

    def test_register_multiple_callbacks(self):
        """Multiple registrants must all fire — no silent overwrites."""
        import comfy.model_management as mm

        first_hits, second_hits = [], []
        mm.register_model_destroyed_callback(first_hits.append)
        mm.register_model_destroyed_callback(second_hits.append)

        for callback in mm._on_model_destroyed_callbacks:
            callback("batch")

        assert first_hits == ["batch"]
        assert second_hits == ["batch"]

    def test_callback_receives_reason_string(self):
        """The callback signature is (reason: str) -> None."""
        import comfy.model_management as mm

        captured = {}

        def record(reason):
            captured["reason"] = reason
            captured["type"] = type(reason).__name__

        mm.register_model_destroyed_callback(record)
        for callback in mm._on_model_destroyed_callbacks:
            callback("batch")

        assert captured["reason"] == "batch"
        assert captured["type"] == "str"
162+
163+
164+
# ---------------------------------------------------------------------------
165+
# 3. Meta-device destruction threshold
166+
# ---------------------------------------------------------------------------
167+
168+
class TestMetaDeviceThreshold:
    """Verify that only models > 1 GB are queued for meta-device destruction."""

    # Mirrors the 1 GB cutoff used by the free_memory destruction path —
    # hoisted so both threshold tests agree on one constant.
    _THRESHOLD_BYTES = 1024 * 1024 * 1024

    def test_small_model_not_destroyed(self):
        """A 160 MB model (VAE-sized) must NOT be moved to meta device."""
        model = FakeLinearModel(size_mb=160)

        # Simulate the threshold check from free_memory.
        model_size = sum(p.numel() * p.element_size() for p in model.parameters())

        assert model_size < self._THRESHOLD_BYTES, (
            f"160 MB model should be below 1 GB threshold, got {model_size / (1024**2):.0f} MB"
        )
        # Confirm parameters are still on a real device.
        assert model.weight.device.type != "meta"

    def test_large_model_above_threshold(self):
        """A 2 GB model (UNET/CLIP-sized) must BE above the destruction threshold.

        The model is constructed on the meta device so only tensor metadata is
        allocated — numel()/element_size() remain valid — instead of 2 GiB of
        real zeros, which made this test slow and OOM-prone in CI.
        Requires torch >= 2.0 for torch.device as a context manager.
        """
        with torch.device("meta"):
            model = FakeLinearModel(size_mb=2048)

        model_size = sum(p.numel() * p.element_size() for p in model.parameters())

        assert model_size > self._THRESHOLD_BYTES, (
            f"2 GB model should be above 1 GB threshold, got {model_size / (1024**2):.0f} MB"
        )

    def test_meta_device_move_releases_storage(self):
        """Moving parameters to 'meta' must place them on the meta device."""
        model = FakeLinearModel(size_mb=2)
        assert model.weight.device.type != "meta"

        model.to(device="meta")

        assert model.weight.device.type == "meta"
        # Meta tensors retain their logical shape but live on a virtual device
        # with no physical backing — this is what releases RAM.
        assert model.weight.nelement() > 0  # still has logical shape
        assert model.weight.untyped_storage().device.type == "meta"
208+
209+
210+
# ---------------------------------------------------------------------------
211+
# 4. MPS flush conditionality
212+
# ---------------------------------------------------------------------------
213+
214+
class TestMpsFlushConditionality:
215+
"""Verify the MPS flush only activates under correct conditions."""
216+
217+
def test_flush_requires_aggressive_offload_flag(self):
218+
"""The MPS flush in samplers is gated on AGGRESSIVE_OFFLOAD."""
219+
import comfy.model_management as mm
220+
221+
# When False, flush should NOT be injected
222+
original = getattr(mm, "AGGRESSIVE_OFFLOAD", False)
223+
try:
224+
mm.AGGRESSIVE_OFFLOAD = False
225+
assert not (True and getattr(mm, "AGGRESSIVE_OFFLOAD", False))
226+
227+
mm.AGGRESSIVE_OFFLOAD = True
228+
assert (True and getattr(mm, "AGGRESSIVE_OFFLOAD", False))
229+
finally:
230+
mm.AGGRESSIVE_OFFLOAD = original
231+
232+
def test_flush_requires_mps_device(self):
233+
"""The flush condition checks device.type == 'mps'."""
234+
# Simulate CPU device — flush should not activate
235+
cpu_device = torch.device("cpu")
236+
assert cpu_device.type != "mps"
237+
238+
# Simulate MPS device string check
239+
if torch.backends.mps.is_available():
240+
mps_device = torch.device("mps")
241+
assert mps_device.type == "mps"
242+
243+
244+
# ---------------------------------------------------------------------------
245+
# 5. AGGRESSIVE_OFFLOAD flag integration
246+
# ---------------------------------------------------------------------------
247+
248+
class TestAggressiveOffloadFlag:
    """Verify the CLI flag is correctly exposed."""

    def test_flag_exists_in_model_management(self):
        """AGGRESSIVE_OFFLOAD must be importable from model_management."""
        import comfy.model_management as mm

        flag_present = hasattr(mm, "AGGRESSIVE_OFFLOAD")
        assert flag_present
        assert isinstance(mm.AGGRESSIVE_OFFLOAD, bool)

    def test_flag_defaults_from_cli_args(self):
        """The flag should be sourced from cli_args."""
        import comfy.cli_args as cli_args

        assert hasattr(cli_args, "args")
        assert hasattr(cli_args.args, "aggressive_offload")

0 commit comments

Comments
 (0)