
Commit 66a5279

patil-suraj, anton-l, and pcuenca authored
stable diffusion fine-tuning (#356)
* begin text2image script * loading the datasets, preprocessing & transforms * handle input features correctly * add gradient checkpointing support * fix output names * run unet in train mode not text encoder * use no_grad instead of freezing params * default max steps None * pad to longest * don't pad when tokenizing * fix encode on multi gpu * fix stupid bug * add random flip * add ema * fix ema * put ema on cpu * improve EMA model * contiguous_format * don't warp vae and text encode in accelerate * remove no_grad * use randn_like * fix resize * improve few things * log epoch loss * set log level * don't log each step * remove max_length from collate * style * add report_to option * make scale_lr false by default * add grad clipping * add an option to use 8bit adam * fix logging in multi-gpu, log every step * more comments * remove eval for now * adress review comments * add requirements file * begin readme * begin readme * fix typo * fix push to hub * populate readme * update readme * remove use_auth_token from the script * address some review comments * better mixed precision support * remove redundant to * create ema model early * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * better description for train_data_dir * add diffusers in requirements * update dataset_name_mapping * update readme * add inference example Co-authored-by: anton-l <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]>
1 parent 797b290 commit 66a5279

File tree

3 files changed: +729 −0 lines changed

examples/text_to_image/README.md

Lines changed: 101 additions & 0 deletions

# Stable Diffusion text-to-image fine-tuning

The `train_text_to_image.py` script shows how to fine-tune the Stable Diffusion model on your own dataset.

___Note___:

___This script is experimental. The script fine-tunes the whole model, and oftentimes the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best results on your dataset.___

## Running locally

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

```bash
pip install git+https://github.com/huggingface/diffusers.git
pip install -U -r requirements.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```
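
If you prefer to set up 🤗 Accelerate non-interactively (for example on a remote machine or in CI), the library also ships a `write_basic_config` helper. This is a minimal sketch, not part of this example's required steps; the `mixed_precision="fp16"` value is just an illustration and should match your hardware:

```python
# Non-interactive alternative to `accelerate config` (a sketch):
# writes a default config file that `accelerate launch` will pick up.
from accelerate.utils import write_basic_config

write_basic_config(mixed_precision="fp16")
```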

### Pokemon example

You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.

You have to be a registered user on the 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).

Run the following command to authenticate your token:

```bash
huggingface-cli login
```

If you have already cloned the repo, then you won't need to go through these steps.

<br>

#### Hardware

With `gradient_checkpointing` and `mixed_precision`, it should be possible to fine-tune the model on a single 24GB GPU. For a higher `batch_size` and faster training, it's better to use GPUs with more than 30GB of memory.

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/pokemon-blip-captions"

accelerate launch train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --mixed_precision="fp16" \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir="sd-pokemon-model"
```
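
For reference, the memory-saving flags above map to standard features rather than anything bespoke: `--gradient_checkpointing` turns on activation checkpointing for the UNet, recomputing activations during the backward pass instead of storing them, which trades extra compute for a smaller memory footprint. A simplified sketch of the idea (not the script's exact code), using recent versions of `diffusers`:

```python
from diffusers import UNet2DConditionModel

# Load only the UNet component of the Stable Diffusion checkpoint.
unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet"
)

# Recompute activations in the backward pass rather than caching them,
# cutting peak memory at the cost of some extra compute.
unet.enable_gradient_checkpointing()
```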

To run on your own training files, prepare the dataset according to the format required by `datasets`; you can find the instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata).
If you wish to use custom loading logic, you should modify the script; we have left pointers for that in the training script.
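
As a quick sanity check that your folder follows the expected `ImageFolder` layout, you can try loading it with `datasets` before launching training. This is a hedged sketch: the paths and the `text` metadata key are illustrative assumptions, not requirements of this repo:

```python
from datasets import load_dataset

# Assumed layout (illustrative):
#   path_to_your_dataset/
#     train/
#       0001.png
#       0002.png
#       metadata.jsonl  # one JSON object per image, e.g.
#                       # {"file_name": "0001.png", "text": "a caption"}
dataset = load_dataset("imagefolder", data_dir="path_to_your_dataset")

example = dataset["train"][0]
print(example["image"])  # a PIL image
print(example["text"])   # the caption from metadata.jsonl
```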

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export TRAIN_DIR="path_to_your_dataset"

accelerate launch train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$TRAIN_DIR \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --mixed_precision="fp16" \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir="sd-pokemon-model"
```

Once the training is finished, the model will be saved in the `output_dir` specified in the command. In this example it's `sd-pokemon-model`. To load the fine-tuned model for inference, just pass that path to `StableDiffusionPipeline`:

```python
import torch

from diffusers import StableDiffusionPipeline

model_path = "path_to_saved_model"
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to("cuda")

image = pipe(prompt="yoda").images[0]
image.save("yoda-pokemon.png")
```
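
The pipeline call also accepts the usual Stable Diffusion inference knobs; the values below are common defaults, shown purely for illustration:

```python
# More denoising steps trade speed for quality; guidance_scale controls
# how strongly the image follows the prompt.
image = pipe(prompt="yoda", num_inference_steps=50, guidance_scale=7.5).images[0]
```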

examples/text_to_image/requirements.txt

Lines changed: 7 additions & 0 deletions

diffusers==0.4.1
accelerate
torchvision
transformers>=4.21.0
ftfy
tensorboard
modelcards
