Skip to content

Commit b2e2d14

Browse files
authored
minimal stable diffusion GPU memory usage with accelerate hooks (apple#850)
* add method to enable cuda with minimal gpu usage to stable diffusion * add test to minimal cuda memory usage * ensure all models but unet are onn torch.float32 * move to cpu_offload along with minor internal changes to make it work * make it test against accelerate master branch * coming back, its official: I don't know how to make it test againt the master branch from accelerate * make it install accelerate from master on tests * go back to accelerate>=0.11 * undo prettier formatting on yml files * undo prettier formatting on yml files againn
1 parent 2f0fcf4 commit b2e2d14

File tree

5 files changed

+39
-2
lines changed

5 files changed

+39
-2
lines changed

.github/workflows/pr_tests.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ jobs:
3434
python -m pip install --upgrade pip
3535
python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
3636
python -m pip install -e .[quality,test]
37+
python -m pip install git+https://github.com/huggingface/accelerate
3738
3839
- name: Environment
3940
run: |
@@ -80,6 +81,7 @@ jobs:
8081
${CONDA_RUN} python -m pip install --upgrade pip
8182
${CONDA_RUN} python -m pip install -e .[quality,test]
8283
${CONDA_RUN} python -m pip install --pre torch==${MPS_TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/test/cpu
84+
${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate
8385
8486
- name: Environment
8587
shell: arch -arch arm64 bash {0}

.github/workflows/push_tests.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ jobs:
3636
python -m pip uninstall -y torch torchvision torchtext
3737
python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu116
3838
python -m pip install -e .[quality,test]
39+
python -m pip install git+https://github.com/huggingface/accelerate
3940
4041
- name: Environment
4142
run: |
@@ -58,8 +59,6 @@ jobs:
5859
name: torch_test_reports
5960
path: reports
6061

61-
62-
6362
run_examples_single_gpu:
6463
name: Examples tests
6564
runs-on: [ self-hosted, docker-gpu, single-gpu ]
@@ -83,6 +82,7 @@ jobs:
8382
python -m pip uninstall -y torch torchvision torchtext
8483
python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu116
8584
python -m pip install -e .[quality,test,training]
85+
python -m pip install git+https://github.com/huggingface/accelerate
8686
8787
- name: Environment
8888
run: |

src/diffusers/pipeline_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,8 @@ def device(self) -> torch.device:
223223
for name in module_names.keys():
224224
module = getattr(self, name)
225225
if isinstance(module, torch.nn.Module):
226+
if module.device == torch.device("meta"):
227+
return torch.device("cpu")
226228
return module.device
227229
return torch.device("cpu")
228230

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import torch
55

6+
from diffusers.utils import is_accelerate_available
67
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
78

89
from ...configuration_utils import FrozenDict
@@ -118,6 +119,18 @@ def disable_attention_slicing(self):
118119
# set slice_size = `None` to disable `attention slicing`
119120
self.enable_attention_slicing(None)
120121

122+
def cuda_with_minimal_gpu_usage(self):
123+
if is_accelerate_available():
124+
from accelerate import cpu_offload
125+
else:
126+
raise ImportError("Please install accelerate via `pip install accelerate`")
127+
128+
device = torch.device("cuda")
129+
self.enable_attention_slicing(1)
130+
131+
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
132+
cpu_offload(cpu_offloaded_model, device)
133+
121134
@torch.no_grad()
122135
def __call__(
123136
self,

tests/test_pipelines.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,3 +535,23 @@ def test_stable_diffusion_accelerate_load_reduces_memory_footprint(self):
535535
tracemalloc.stop()
536536

537537
assert peak_accelerate < peak_normal
538+
539+
@slow
540+
@unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
541+
def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self):
542+
torch.cuda.empty_cache()
543+
torch.cuda.reset_max_memory_allocated()
544+
545+
pipeline_id = "CompVis/stable-diffusion-v1-4"
546+
prompt = "Andromeda galaxy in a bottle"
547+
548+
pipeline = StableDiffusionPipeline.from_pretrained(
549+
pipeline_id, revision="fp16", torch_dtype=torch.float32, use_auth_token=True
550+
)
551+
pipeline.cuda_with_minimal_gpu_usage()
552+
553+
_ = pipeline(prompt)
554+
555+
mem_bytes = torch.cuda.max_memory_allocated()
556+
# make sure that less than 0.8 GB is allocated
557+
assert mem_bytes < 0.8 * 10**9

0 commit comments

Comments
 (0)