
Commit d3881f3

Hzzone and sayakpaul authored

Gligen training (huggingface#7906)

* add training code of gligen
* fix code quality tests.

Co-authored-by: Sayak Paul <[email protected]>

1 parent 48207d6 commit d3881f3

File tree: 7 files changed, +1312 −0 lines changed

@@ -0,0 +1,156 @@
# GLIGEN: Open-Set Grounded Text-to-Image Generation

These scripts contain the code to prepare the grounding data and train the GLIGEN model on the COCO dataset.

### Install the requirements

```bash
conda create -n diffusers python==3.10
conda activate diffusers
pip install -r requirements.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

Or, for a default Accelerate configuration without answering questions about your environment:

```bash
accelerate config default
```

Or, if your environment doesn't support an interactive shell (e.g. a notebook):

```python
from accelerate.utils import write_basic_config

write_basic_config()
```
### Prepare the training data

If you want to make your own grounding data, you need to install the requirements.

I used [RAM](https://github.com/xinyu1205/recognize-anything) to tag images, [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO/issues?q=refer) to detect objects, and [BLIP2](https://huggingface.co/docs/transformers/en/model_doc/blip-2) to caption instances.

Only RAM needs to be installed manually:

```bash
pip install git+https://github.com/xinyu1205/recognize-anything.git --no-deps
```
Download the pre-trained models:

```bash
huggingface-cli download --resume-download xinyu1205/recognize_anything_model ram_swin_large_14m.pth
huggingface-cli download --resume-download IDEA-Research/grounding-dino-base
huggingface-cli download --resume-download Salesforce/blip2-flan-t5-xxl
huggingface-cli download --resume-download openai/clip-vit-large-patch14
huggingface-cli download --resume-download masterful/gligen-1-4-generation-text-box
```
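For orientation, below is a rough, hedged sketch of how the downloaded checkpoints fit together for one image. The authoritative pipeline is `make_datasets.py`; this sketch assumes the `transformers` ports of the checkpoints, uses a placeholder image path and detection thresholds, omits the RAM tagging step, and post-processing argument names may differ between `transformers` versions.

```python
import torch
from PIL import Image
from transformers import (
    AutoModelForZeroShotObjectDetection,
    AutoProcessor,
    Blip2ForConditionalGeneration,
    Blip2Processor,
    CLIPTextModel,
    CLIPTokenizer,
)

image = Image.open("example.jpg").convert("RGB")  # placeholder path
tags = "a car. a truck."  # in practice the tags come from RAM

# 1) Ground the tags with Grounding DINO to get instance boxes (xyxy, in pixels).
dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
dino = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base")
dino_inputs = dino_processor(images=image, text=tags, return_tensors="pt")
with torch.no_grad():
    dino_outputs = dino(**dino_inputs)
detections = dino_processor.post_process_grounded_object_detection(
    dino_outputs,
    dino_inputs.input_ids,
    box_threshold=0.4,  # placeholder thresholds
    text_threshold=0.3,
    target_sizes=[image.size[::-1]],
)[0]

# 2) Caption each detected instance crop with BLIP-2.
blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xxl")
blip = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xxl")

# 3) Encode each instance caption with the CLIP text encoder; the pooled output
#    (before CLIP's final projection layer) is the 768-d vector stored as
#    `text_embeddings_before_projection`.
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

annos = []
for box in detections["boxes"]:
    crop = image.crop(tuple(box.int().tolist()))
    blip_inputs = blip_processor(images=crop, return_tensors="pt")
    caption = blip_processor.batch_decode(blip.generate(**blip_inputs), skip_special_tokens=True)[0].strip()

    text_inputs = clip_tokenizer(caption, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        embedding = clip_text_encoder(**text_inputs).pooler_output[0]  # shape: (768,)

    annos.append({"caption": caption, "bbox": box.tolist(), "text_embeddings_before_projection": embedding})
```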
Make the training data on 8 GPUs:

```bash
torchrun --master_port 17673 --nproc_per_node=8 make_datasets.py \
    --data_root /mnt/workspace/workgroup/zhizhonghuang/dataset/COCO/train2017 \
    --save_root /root/gligen_data \
    --ram_checkpoint /root/.cache/huggingface/hub/models--xinyu1205--recognize_anything_model/snapshots/ebc52dc741e86466202a5ab8ab22eae6e7d48bf1/ram_swin_large_14m.pth
```

Alternatively, you can download the prepared COCO training data with:

```bash
huggingface-cli download --resume-download Hzzone/GLIGEN_COCO coco_train2017.pth
```
It's in the format of:

```json
[
    ...
    {
        'file_path': Path,
        'annos': [
            {
                'caption': Instance Caption,
                'bbox': bbox in xyxy,
                'text_embeddings_before_projection': CLIP text embedding before linear projection
            }
        ]
    }
    ...
]
```
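As a quick sanity check, you can load and inspect the prepared annotations directly. This is a minimal sketch; the `captions` key holds the image-level COCO captions that the `COCODataset` class in this commit samples from.

```python
import torch

data_list = torch.load("coco_train2017.pth", map_location="cpu")
sample = data_list[0]

print(sample["file_path"])   # image file name under the COCO train2017 folder
print(sample["captions"])    # image-level captions used as the text prompt
for anno in sample["annos"]:
    print(anno["caption"], anno["bbox"])                    # instance caption and xyxy box
    print(anno["text_embeddings_before_projection"].shape)  # e.g. torch.Size([768])
```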
### Training commands
The training script is heavily based on [train_controlnet.py](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet.py).

```bash
accelerate launch train_gligen_text.py \
    --data_path /root/data/zhizhonghuang/coco_train2017.pth \
    --image_path /mnt/workspace/workgroup/zhizhonghuang/dataset/COCO/train2017 \
    --train_batch_size 8 \
    --max_train_steps 100000 \
    --checkpointing_steps 1000 \
    --checkpoints_total_limit 10 \
    --learning_rate 5e-5 \
    --dataloader_num_workers 16 \
    --mixed_precision fp16 \
    --report_to wandb \
    --tracker_project_name gligen \
    --output_dir /root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO
```
I trained the model on 8 A100 GPUs for about 11 hours (at least 24GB of GPU memory is needed). The generated images start to follow the layout at roughly 50k iterations.

Note that although the pre-trained GLIGEN model is loaded, the parameters of `fuser` and `position_net` are reset (see line 420 in `train_gligen_text.py`).
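A minimal, hedged sketch of what that reset could look like; the authoritative implementation is in `train_gligen_text.py`, and the module names (`fuser`, `position_net`) follow the diffusers GLIGEN UNet.

```python
def reset_gligen_modules(unet):
    # Re-initialize only the GLIGEN-specific parts of the pre-trained UNet so that
    # the gated self-attention (`fuser`) layers and `position_net` train from scratch.
    for name, module in unet.named_modules():
        if "fuser" in name or "position_net" in name:
            for layer in module.modules():
                if hasattr(layer, "reset_parameters"):
                    layer.reset_parameters()
```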
The trained model can be downloaded with:

```bash
huggingface-cli download --resume-download Hzzone/GLIGEN_COCO config.json diffusion_pytorch_model.safetensors
```

You can run `demo.ipynb` to visualize the generated images.

Example prompts:

```python
prompt = 'A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky'
boxes = [[0.041015625, 0.548828125, 0.453125, 0.859375],
         [0.525390625, 0.552734375, 0.93359375, 0.865234375],
         [0.12890625, 0.015625, 0.412109375, 0.279296875],
         [0.578125, 0.08203125, 0.857421875, 0.27734375]]
gligen_phrases = ['a green car', 'a blue truck', 'a red air balloon', 'a bird']
```
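For reference, here is a sketch of feeding these prompts to the grounded-generation pipeline in diffusers; `demo.ipynb` is the authoritative version. The checkpoint below is the public GLIGEN text-box model listed earlier, and to visualize the retrained weights you would presumably load them with `UNet2DConditionModel.from_pretrained(...)` and pass the result as `unet=` to the pipeline.

```python
import torch
from diffusers import StableDiffusionGLIGENPipeline

pipe = StableDiffusionGLIGENPipeline.from_pretrained(
    "masterful/gligen-1-4-generation-text-box", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    prompt=prompt,
    gligen_phrases=gligen_phrases,
    gligen_boxes=boxes,
    gligen_scheduled_sampling_beta=1.0,
    num_inference_steps=50,
).images[0]
image.save("generated.png")
```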
Example images:

![Example generated images](generated-images-100000-00.png)
### Citation
```bibtex
@article{li2023gligen,
  title={GLIGEN: Open-Set Grounded Text-to-Image Generation},
  author={Li, Yuheng and Liu, Haotian and Wu, Qingyang and Mu, Fangzhou and Yang, Jianwei and Gao, Jianfeng and Li, Chunyuan and Lee, Yong Jae},
  journal={CVPR},
  year={2023}
}
```
@@ -0,0 +1,110 @@
import os
import random

import torch
import torchvision.transforms as transforms
from PIL import Image


def recalculate_box_and_verify_if_valid(x, y, w, h, image_size, original_image_size, min_box_size):
    scale = image_size / min(original_image_size)
    crop_y = (original_image_size[1] * scale - image_size) // 2
    crop_x = (original_image_size[0] * scale - image_size) // 2
    x0 = max(x * scale - crop_x, 0)
    y0 = max(y * scale - crop_y, 0)
    x1 = min((x + w) * scale - crop_x, image_size)
    y1 = min((y + h) * scale - crop_y, image_size)
    if (x1 - x0) * (y1 - y0) / (image_size * image_size) < min_box_size:
        return False, (None, None, None, None)
    return True, (x0, y0, x1, y1)


class COCODataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data_path,
        image_path,
        image_size=512,
        min_box_size=0.01,
        max_boxes_per_data=8,
        tokenizer=None,
    ):
        super().__init__()
        self.min_box_size = min_box_size
        self.max_boxes_per_data = max_boxes_per_data
        self.image_size = image_size
        self.image_path = image_path
        self.tokenizer = tokenizer
        self.transforms = transforms.Compose(
            [
                transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        self.data_list = torch.load(data_path, map_location="cpu")

    def __getitem__(self, index):
        if self.max_boxes_per_data > 99:
            assert False, "Are you sure setting such large number of boxes per image?"

        out = {}

        data = self.data_list[index]
        image = Image.open(os.path.join(self.image_path, data["file_path"])).convert("RGB")
        original_image_size = image.size
        out["pixel_values"] = self.transforms(image)

        annos = data["annos"]

        areas, valid_annos = [], []
        for anno in annos:
            # x, y, w, h = anno['bbox']
            x0, y0, x1, y1 = anno["bbox"]
            x, y, w, h = x0, y0, x1 - x0, y1 - y0
            valid, (x0, y0, x1, y1) = recalculate_box_and_verify_if_valid(
                x, y, w, h, self.image_size, original_image_size, self.min_box_size
            )
            if valid:
                anno["bbox"] = [x0, y0, x1, y1]
                areas.append((x1 - x0) * (y1 - y0))
                valid_annos.append(anno)

        # Sort according to area and choose the largest N objects
        wanted_idxs = torch.tensor(areas).sort(descending=True)[1]
        wanted_idxs = wanted_idxs[: self.max_boxes_per_data]
        valid_annos = [valid_annos[i] for i in wanted_idxs]

        out["boxes"] = torch.zeros(self.max_boxes_per_data, 4)
        out["masks"] = torch.zeros(self.max_boxes_per_data)
        out["text_embeddings_before_projection"] = torch.zeros(self.max_boxes_per_data, 768)

        for i, anno in enumerate(valid_annos):
            out["boxes"][i] = torch.tensor(anno["bbox"]) / self.image_size
            out["masks"][i] = 1
            out["text_embeddings_before_projection"][i] = anno["text_embeddings_before_projection"]

        prob_drop_boxes = 0.1
        if random.random() < prob_drop_boxes:
            out["masks"][:] = 0

        caption = random.choice(data["captions"])

        prob_drop_captions = 0.5
        if random.random() < prob_drop_captions:
            caption = ""
        caption = self.tokenizer(
            caption,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        out["caption"] = caption

        return out

    def __len__(self):
        return len(self.data_list)
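
A hedged usage sketch for `COCODataset`: the paths are placeholders, the training script presumably builds something similar before its training loop, and the tokenizer is assumed to be the CLIP tokenizer downloaded above.

```python
import torch
from torch.utils.data import DataLoader
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
dataset = COCODataset(
    data_path="coco_train2017.pth",
    image_path="/path/to/COCO/train2017",
    image_size=512,
    max_boxes_per_data=8,
    tokenizer=tokenizer,
)


def collate_fn(examples):
    # Stack per-sample tensors into a batch; `caption` is a tokenizer BatchEncoding
    # whose input_ids already carry a leading batch dimension of 1.
    return {
        "pixel_values": torch.stack([e["pixel_values"] for e in examples]),
        "boxes": torch.stack([e["boxes"] for e in examples]),
        "masks": torch.stack([e["masks"] for e in examples]),
        "text_embeddings_before_projection": torch.stack(
            [e["text_embeddings_before_projection"] for e in examples]
        ),
        "caption_input_ids": torch.cat([e["caption"].input_ids for e in examples], dim=0),
    }


loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4, collate_fn=collate_fn)
batch = next(iter(loader))
print(batch["pixel_values"].shape, batch["boxes"].shape)  # (8, 3, 512, 512), (8, 8, 4)
```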
