Skip to content

Commit ec37e20

Browse files
faaany and hlky authored
[tests] make tests device-agnostic (part 3) (#10437)
* initial comit * fix empty cache * fix one more * fix style * update device functions * update * update * Update src/diffusers/utils/testing_utils.py Co-authored-by: hlky <[email protected]> * Update src/diffusers/utils/testing_utils.py Co-authored-by: hlky <[email protected]> * Update src/diffusers/utils/testing_utils.py Co-authored-by: hlky <[email protected]> * Update tests/pipelines/controlnet/test_controlnet.py Co-authored-by: hlky <[email protected]> * Update src/diffusers/utils/testing_utils.py Co-authored-by: hlky <[email protected]> * Update src/diffusers/utils/testing_utils.py Co-authored-by: hlky <[email protected]> * Update tests/pipelines/controlnet/test_controlnet.py Co-authored-by: hlky <[email protected]> * with gc.collect * update * make style * check_torch_dependencies * add mps empty cache * bug fix * Apply suggestions from code review --------- Co-authored-by: hlky <[email protected]>
1 parent 158a5a8 commit ec37e20

26 files changed

+275
-170
lines changed

src/diffusers/utils/testing_utils.py

Lines changed: 64 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -86,7 +86,12 @@
8686
) from e
8787
logger.info(f"torch_device overrode to {torch_device}")
8888
else:
89-
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
89+
if torch.cuda.is_available():
90+
torch_device = "cuda"
91+
elif torch.xpu.is_available():
92+
torch_device = "xpu"
93+
else:
94+
torch_device = "cpu"
9095
is_torch_higher_equal_than_1_12 = version.parse(
9196
version.parse(torch.__version__).base_version
9297
) >= version.parse("1.12")
@@ -1067,12 +1072,51 @@ def _is_torch_fp64_available(device):
10671072
# Guard these lookups for when Torch is not used - alternative accelerator support is for PyTorch
10681073
if is_torch_available():
10691074
# Behaviour flags
1070-
BACKEND_SUPPORTS_TRAINING = {"cuda": True, "cpu": True, "mps": False, "default": True}
1075+
BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True}
10711076

10721077
# Function definitions
1073-
BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "mps": None, "default": None}
1074-
BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "mps": lambda: 0, "default": 0}
1075-
BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed}
1078+
BACKEND_EMPTY_CACHE = {
1079+
"cuda": torch.cuda.empty_cache,
1080+
"xpu": torch.xpu.empty_cache,
1081+
"cpu": None,
1082+
"mps": torch.mps.empty_cache,
1083+
"default": None,
1084+
}
1085+
BACKEND_DEVICE_COUNT = {
1086+
"cuda": torch.cuda.device_count,
1087+
"xpu": torch.xpu.device_count,
1088+
"cpu": lambda: 0,
1089+
"mps": lambda: 0,
1090+
"default": 0,
1091+
}
1092+
BACKEND_MANUAL_SEED = {
1093+
"cuda": torch.cuda.manual_seed,
1094+
"xpu": torch.xpu.manual_seed,
1095+
"cpu": torch.manual_seed,
1096+
"mps": torch.mps.manual_seed,
1097+
"default": torch.manual_seed,
1098+
}
1099+
BACKEND_RESET_PEAK_MEMORY_STATS = {
1100+
"cuda": torch.cuda.reset_peak_memory_stats,
1101+
"xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
1102+
"cpu": None,
1103+
"mps": None,
1104+
"default": None,
1105+
}
1106+
BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
1107+
"cuda": torch.cuda.reset_max_memory_allocated,
1108+
"xpu": None,
1109+
"cpu": None,
1110+
"mps": None,
1111+
"default": None,
1112+
}
1113+
BACKEND_MAX_MEMORY_ALLOCATED = {
1114+
"cuda": torch.cuda.max_memory_allocated,
1115+
"xpu": getattr(torch.xpu, "max_memory_allocated", None),
1116+
"cpu": 0,
1117+
"mps": 0,
1118+
"default": 0,
1119+
}
10761120

10771121

10781122
# This dispatches a defined function according to the accelerator from the function definitions.
@@ -1103,6 +1147,18 @@ def backend_device_count(device: str):
11031147
return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
11041148

11051149

1150+
def backend_reset_peak_memory_stats(device: str):
1151+
return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS)
1152+
1153+
1154+
def backend_reset_max_memory_allocated(device: str):
1155+
return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)
1156+
1157+
1158+
def backend_max_memory_allocated(device: str):
1159+
return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)
1160+
1161+
11061162
# These are callables which return boolean behaviour flags and can be used to specify some
11071163
# device agnostic alternative where the feature is unsupported.
11081164
def backend_supports_training(device: str):
@@ -1159,3 +1215,6 @@ def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name
11591215
update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
11601216
update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
11611217
update_mapping_from_spec(BACKEND_SUPPORTS_TRAINING, "SUPPORTS_TRAINING")
1218+
update_mapping_from_spec(BACKEND_RESET_PEAK_MEMORY_STATS, "RESET_PEAK_MEMORY_STATS_FN")
1219+
update_mapping_from_spec(BACKEND_RESET_MAX_MEMORY_ALLOCATED, "RESET_MAX_MEMORY_ALLOCATED_FN")
1220+
update_mapping_from_spec(BACKEND_MAX_MEMORY_ALLOCATED, "MAX_MEMORY_ALLOCATED_FN")

tests/models/test_modeling_common.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -57,8 +57,8 @@
5757
get_python_version,
5858
is_torch_compile,
5959
require_torch_2,
60+
require_torch_accelerator,
6061
require_torch_accelerator_with_training,
61-
require_torch_gpu,
6262
require_torch_multi_gpu,
6363
run_test_in_subprocess,
6464
torch_all_close,
@@ -543,7 +543,7 @@ def test_set_xformers_attn_processor_for_determinism(self):
543543
assert torch.allclose(output, output_3, atol=self.base_precision)
544544
assert torch.allclose(output_2, output_3, atol=self.base_precision)
545545

546-
@require_torch_gpu
546+
@require_torch_accelerator
547547
def test_set_attn_processor_for_determinism(self):
548548
if self.uses_custom_attn_processor:
549549
return
@@ -1068,7 +1068,7 @@ def test_wrong_adapter_name_raises_error(self):
10681068

10691069
self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))
10701070

1071-
@require_torch_gpu
1071+
@require_torch_accelerator
10721072
def test_cpu_offload(self):
10731073
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
10741074
model = self.model_class(**config).eval()
@@ -1098,7 +1098,7 @@ def test_cpu_offload(self):
10981098

10991099
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
11001100

1101-
@require_torch_gpu
1101+
@require_torch_accelerator
11021102
def test_disk_offload_without_safetensors(self):
11031103
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
11041104
model = self.model_class(**config).eval()
@@ -1132,7 +1132,7 @@ def test_disk_offload_without_safetensors(self):
11321132

11331133
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
11341134

1135-
@require_torch_gpu
1135+
@require_torch_accelerator
11361136
def test_disk_offload_with_safetensors(self):
11371137
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
11381138
model = self.model_class(**config).eval()
@@ -1191,7 +1191,7 @@ def test_model_parallelism(self):
11911191

11921192
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
11931193

1194-
@require_torch_gpu
1194+
@require_torch_accelerator
11951195
def test_sharded_checkpoints(self):
11961196
torch.manual_seed(0)
11971197
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1223,7 +1223,7 @@ def test_sharded_checkpoints(self):
12231223

12241224
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
12251225

1226-
@require_torch_gpu
1226+
@require_torch_accelerator
12271227
def test_sharded_checkpoints_with_variant(self):
12281228
torch.manual_seed(0)
12291229
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1261,7 +1261,7 @@ def test_sharded_checkpoints_with_variant(self):
12611261

12621262
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
12631263

1264-
@require_torch_gpu
1264+
@require_torch_accelerator
12651265
def test_sharded_checkpoints_device_map(self):
12661266
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
12671267
model = self.model_class(**config).eval()

tests/pipelines/allegro/test_allegro.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
2727
enable_full_determinism,
2828
numpy_cosine_similarity_distance,
2929
require_hf_hub_version_greater,
30-
require_torch_gpu,
30+
require_torch_accelerator,
3131
require_transformers_version_greater,
3232
slow,
3333
torch_device,
@@ -332,7 +332,7 @@ def test_save_load_dduf(self):
332332

333333

334334
@slow
335-
@require_torch_gpu
335+
@require_torch_accelerator
336336
class AllegroPipelineIntegrationTests(unittest.TestCase):
337337
prompt = "A painting of a squirrel eating a burger."
338338

@@ -350,7 +350,7 @@ def test_allegro(self):
350350
generator = torch.Generator("cpu").manual_seed(0)
351351

352352
pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.float16)
353-
pipe.enable_model_cpu_offload()
353+
pipe.enable_model_cpu_offload(device=torch_device)
354354
prompt = self.prompt
355355

356356
videos = pipe(

tests/pipelines/animatediff/test_animatediff.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -20,9 +20,10 @@
2020
from diffusers.models.attention import FreeNoiseTransformerBlock
2121
from diffusers.utils import is_xformers_available, logging
2222
from diffusers.utils.testing_utils import (
23+
backend_empty_cache,
2324
numpy_cosine_similarity_distance,
2425
require_accelerator,
25-
require_torch_gpu,
26+
require_torch_accelerator,
2627
slow,
2728
torch_device,
2829
)
@@ -547,19 +548,19 @@ def test_vae_slicing(self):
547548

548549

549550
@slow
550-
@require_torch_gpu
551+
@require_torch_accelerator
551552
class AnimateDiffPipelineSlowTests(unittest.TestCase):
552553
def setUp(self):
553554
# clean up the VRAM before each test
554555
super().setUp()
555556
gc.collect()
556-
torch.cuda.empty_cache()
557+
backend_empty_cache(torch_device)
557558

558559
def tearDown(self):
559560
# clean up the VRAM after each test
560561
super().tearDown()
561562
gc.collect()
562-
torch.cuda.empty_cache()
563+
backend_empty_cache(torch_device)
563564

564565
def test_animatediff(self):
565566
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
@@ -573,7 +574,7 @@ def test_animatediff(self):
573574
clip_sample=False,
574575
)
575576
pipe.enable_vae_slicing()
576-
pipe.enable_model_cpu_offload()
577+
pipe.enable_model_cpu_offload(device=torch_device)
577578
pipe.set_progress_bar_config(disable=None)
578579

579580
prompt = "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain"

tests/pipelines/cogvideo/test_cogvideox.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424
from diffusers.utils.testing_utils import (
2525
enable_full_determinism,
2626
numpy_cosine_similarity_distance,
27-
require_torch_gpu,
27+
require_torch_accelerator,
2828
slow,
2929
torch_device,
3030
)
@@ -321,7 +321,7 @@ def test_fused_qkv_projections(self):
321321

322322

323323
@slow
324-
@require_torch_gpu
324+
@require_torch_accelerator
325325
class CogVideoXPipelineIntegrationTests(unittest.TestCase):
326326
prompt = "A painting of a squirrel eating a burger."
327327

@@ -339,7 +339,7 @@ def test_cogvideox(self):
339339
generator = torch.Generator("cpu").manual_seed(0)
340340

341341
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
342-
pipe.enable_model_cpu_offload()
342+
pipe.enable_model_cpu_offload(device=torch_device)
343343
prompt = self.prompt
344344

345345
videos = pipe(

tests/pipelines/cogvideo/test_cogvideox_image2video.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -24,9 +24,10 @@
2424
from diffusers import AutoencoderKLCogVideoX, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel, DDIMScheduler
2525
from diffusers.utils import load_image
2626
from diffusers.utils.testing_utils import (
27+
backend_empty_cache,
2728
enable_full_determinism,
2829
numpy_cosine_similarity_distance,
29-
require_torch_gpu,
30+
require_torch_accelerator,
3031
slow,
3132
torch_device,
3233
)
@@ -344,25 +345,25 @@ def test_fused_qkv_projections(self):
344345

345346

346347
@slow
347-
@require_torch_gpu
348+
@require_torch_accelerator
348349
class CogVideoXImageToVideoPipelineIntegrationTests(unittest.TestCase):
349350
prompt = "A painting of a squirrel eating a burger."
350351

351352
def setUp(self):
352353
super().setUp()
353354
gc.collect()
354-
torch.cuda.empty_cache()
355+
backend_empty_cache(torch_device)
355356

356357
def tearDown(self):
357358
super().tearDown()
358359
gc.collect()
359-
torch.cuda.empty_cache()
360+
backend_empty_cache(torch_device)
360361

361362
def test_cogvideox(self):
362363
generator = torch.Generator("cpu").manual_seed(0)
363364

364365
pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16)
365-
pipe.enable_model_cpu_offload()
366+
pipe.enable_model_cpu_offload(device=torch_device)
366367

367368
prompt = self.prompt
368369
image = load_image(

tests/pipelines/cogview3/test_cogview3plus.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424
from diffusers.utils.testing_utils import (
2525
enable_full_determinism,
2626
numpy_cosine_similarity_distance,
27-
require_torch_gpu,
27+
require_torch_accelerator,
2828
slow,
2929
torch_device,
3030
)
@@ -232,7 +232,7 @@ def test_attention_slicing_forward_pass(
232232

233233

234234
@slow
235-
@require_torch_gpu
235+
@require_torch_accelerator
236236
class CogView3PlusPipelineIntegrationTests(unittest.TestCase):
237237
prompt = "A painting of a squirrel eating a burger."
238238

@@ -250,7 +250,7 @@ def test_cogview3plus(self):
250250
generator = torch.Generator("cpu").manual_seed(0)
251251

252252
pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3Plus-3b", torch_dtype=torch.float16)
253-
pipe.enable_model_cpu_offload()
253+
pipe.enable_model_cpu_offload(device=torch_device)
254254
prompt = self.prompt
255255

256256
images = pipe(

0 commit comments

Comments
 (0)