
Commit f6974a9

committed: update training code and readme
1 parent: 98db3bb

33 files changed: +6691 −168 lines

.gitignore

Lines changed: 6 additions & 1 deletion
```diff
@@ -159,4 +159,9 @@ cython_debug/
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
 
-.venv*
+.venv*
+data
+wandb
+result
+slurm
+slurm_test.sh
```
File renamed without changes.

README.md

Lines changed: 77 additions & 2 deletions
The previous two-line README (the paper description without a link, plus the placeholder "We are still working hard to clean up the code for the release. Please stay tuned and the code/model should be ready within a week!") is replaced by the full README below.

# Improving Language Understanding from Screenshots

This repository contains the code, data, and models for the paper [Improving Language Understanding from Screenshots](https://arxiv.org/abs/2402.14073). In this paper, we focus on improving the language understanding ability of "screenshot LMs" (models that process everything -- including text -- within visual inputs) and propose patch-and-text prediction (PTP), a novel pre-training objective for screenshot LMs.

## Quick Links

- [Environment](#environment)
- [Preparing the data](#preparing-the-data)
- [Reproducing our pre-trained models](#reproducing-our-pre-trained-models)
- [Downloading our models](#downloading-our-models)
- [Fine-tuning PTP models](#fine-tuning-ptp-models)
- [Bugs or Questions?](#bugs-or-questions)
- [Citation](#citation)

## Environment

First, install the latest compatible version of [PyTorch](https://pytorch.org).

Then install the remaining required packages by running:
```bash
pip install -r requirements.txt
```

We strongly recommend using the exact same `transformers` and `accelerate` versions for the best reproducibility.
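
To record which versions you actually ended up with (useful when reporting reproducibility issues), here is a quick check, assuming a standard Python ≥3.8 environment:

```python
# Print the installed versions of the packages that matter most for reproducibility.
from importlib.metadata import version

for pkg in ["torch", "transformers", "accelerate"]:
    print(pkg, version(pkg))
```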

## Preparing the data

For our encoder-decoder experiments and the train-from-scratch autoregressive screenshot LM experiments, we use Wikipedia+BookCorpus as the pre-training data. The already-tokenized dataset is available from [this Hugging Face dataset page](https://huggingface.co/datasets/princeton-nlp/ptp_data). You can download the data by
```bash
git clone https://huggingface.co/datasets/princeton-nlp/ptp_data data
```
This folder contains four files (see the loading sketch below):
* `wikibook_256_opt_tk_train.npy` and `wikibook_256_opt_tk_val.npy`: Wiki+Book tokenized with the OPT tokenizer, 256 tokens per example (for the encoder-decoder models).
* `wikibook_512_llama_tk_train.npy` and `wikibook_512_llama_tk_val.npy`: Wiki+Book tokenized with the LLaMA tokenizer, 512 tokens per example (for the train-from-scratch autoregressive models).
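
As a sanity check, the `.npy` files can be loaded directly with NumPy, or through the `NumpyDataset` class added in `data.py` in this commit. A minimal sketch, assuming the dataset was cloned into `./data` as above:

```python
import numpy as np
from data import NumpyDataset

# Each .npy file stores pre-tokenized examples as a 2D integer array:
# one row per example, block_size token ids per row (256 for the OPT-tokenized split).
tokens = np.load("data/wikibook_256_opt_tk_val.npy")
print(tokens.shape)  # (num_examples, 256)

# NumpyDataset simply indexes into this array and returns {"tokens": ..., "font_size": None}.
dataset = NumpyDataset("data/wikibook_256_opt_tk_val.npy")
print(len(dataset), dataset[0]["tokens"].shape)
```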

For continued training of [Sheared-LLaMA](https://github.com/princeton-nlp/LLM-Shearing) to use screenshots, we use Sheared-LLaMA's pipeline for processing the [RedPajama](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T) data. Please follow [this guideline](https://github.com/princeton-nlp/LLM-Shearing/tree/main/llmshearing/data) for processing the data. Our example config uses `./data/sheared-llama-rp/for_ft` for continued pre-training and `./data/sheared-llama-rp/eval` for evaluation.

## Reproducing our pre-trained models

To reproduce our models, run the following command (requires 8 GPUs):
```bash
NUM_GPU=8 bash run_multiple_gpus.sh {CONFIG PATH}
```
There are three example configs:
* `run_configs/ptp.yaml`: our main PTP model (encoder-decoder).
* `run_configs/screenshot-llama-380m.yaml`: train-from-scratch autoregressive model.
* `run_configs/screenshot-llama-1.3b-from-sheared-llama.yaml`: continued pre-training from Sheared-LLaMA.

You can also run the single-GPU script `run_single_gpu.sh` for testing. If you are not using 8 GPUs, or your GPUs cannot fit our preset batch sizes, adjust the per-GPU batch size (`per_device_train_batch_size`) or the gradient accumulation steps (`gradient_accumulation_steps`) accordingly to keep the same effective hyperparameters; see the sketch below.
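
The quantity to preserve is the effective batch size, i.e. number of GPUs × `per_device_train_batch_size` × `gradient_accumulation_steps`. A minimal sketch of the arithmetic (the numbers below are placeholders, not the values from our configs):

```python
# Keep num_gpus * per_device_train_batch_size * gradient_accumulation_steps constant.
# All numbers here are placeholders for illustration.
reference = dict(num_gpus=8, per_device_train_batch_size=32, gradient_accumulation_steps=1)
effective_batch = (reference["num_gpus"]
                   * reference["per_device_train_batch_size"]
                   * reference["gradient_accumulation_steps"])  # 256

num_gpus, per_device_train_batch_size = 2, 16  # what your hardware can fit
gradient_accumulation_steps = effective_batch // (num_gpus * per_device_train_batch_size)
print(gradient_accumulation_steps)  # 8
```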

## Downloading our models

We provide the following pre-trained models on Hugging Face:

* [princeton-nlp/ptp](https://huggingface.co/princeton-nlp/ptp)
* [princeton-nlp/screenshot-llama-380m](https://huggingface.co/princeton-nlp/screenshot-llama-380m)
* [princeton-nlp/screenshot-llama-1.3b-from-sheared-llama](https://huggingface.co/princeton-nlp/screenshot-llama-1.3b-from-sheared-llama)
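
If you only need the checkpoint files on disk (for example, to copy them to a cluster node), one generic option is `huggingface_hub`'s `snapshot_download`. This is only a download sketch; loading the weights is handled by the training and fine-tuning code rather than by this snippet:

```python
from huggingface_hub import snapshot_download

# Download (or reuse the cached copy of) every file in the model repo and
# return the local directory path; repeat for the other checkpoints as needed.
local_dir = snapshot_download(repo_id="princeton-nlp/ptp")
print(local_dir)
```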

## Fine-tuning PTP models

Coming soon!

## Bugs or questions?

If you have any questions related to the paper, feel free to email Tianyu (`[email protected]`). If you encounter any problems when using the code, or want to report a bug, you can open an issue. Please describe the problem in detail so that we can help you better and more quickly!

## Citation

Please cite our paper if you use PTP in your work:

```bibtex
@article{gao2024improving,
   title={Improving Language Understanding from Screenshots},
   author={Gao, Tianyu and Wang, Zirui and Bhaskar, Adithya and Chen, Danqi},
   journal={arXiv preprint arXiv:2402.14073},
   year={2024}
}
```

data.py

Lines changed: 279 additions & 0 deletions

```python
from torch.utils.data import Dataset
import numpy as np
# import gi
# gi.require_version('Pango', '1.0')
# gi.require_version('PangoCairo', '1.0')
# from gi.repository import Pango, PangoCairo
# import cairo
from PIL import Image
from dataclasses import dataclass, field
import torch
from streaming import LocalDataset
from image_utils import *
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
from transformers.image_utils import to_numpy_array
from modeling.span_masking import SpanMaskingGenerator
from random import sample
from image_utils import render_text


class NumpyDataset(Dataset):
    """Dataset backed by a pre-tokenized .npy array (one row of token ids per example)."""

    def __init__(self, path, block_size=None):
        self.tokens = np.load(path)
        self.block_size = self.tokens.shape[1] if block_size is None else block_size
        self.font_size = None

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        return {"tokens": self.tokens[idx][:self.block_size], "font_size": self.font_size}


class RenderTextCollator:
    """Renders token sequences as screenshots and builds model inputs for the
    encoder-decoder (PTP), autoregressive screenshot, and text-only settings."""

    def __init__(
        self,
        processor: object,
        font_size: int,
        line_space: int,
        replace_new_line: bool,
        new_line_token: str,
        width: int,
        height: int,
        block_size: int = 1024,
        rendered_as_target: bool = False,
        patch_width: int = 16,
        patch_height: int = 16,
        text_mask_rate: float = 0,
        merge_text_masks: bool = False,
        ignore_white_patches: bool = False,
        add_black_patch: bool = False,
        add_prefix: bool = False,
        autoregressive: bool = False,
        ar_image_block_size: int = None,
        total_block_size: int = None,
        context_mask: int = None,
        image_mode: str = "RGB",
        sample_mask_at_collator: bool = False,
        mask_ratio: float = 0,
        span_masking: bool = False,
        max_span_length: int = 6,
    ):
        self.processor = processor
        self.font_size = font_size
        self.line_space = line_space
        self.replace_new_line = replace_new_line
        self.new_line_token = new_line_token
        self.width = width
        self.height = height
        self.block_size = block_size
        self.rendered_as_target = rendered_as_target
        self.patch_width = patch_width
        self.patch_height = patch_height
        self.text_mask_rate = text_mask_rate
        self.merge_text_masks = merge_text_masks
        self.ignore_white_patches = ignore_white_patches
        self.add_black_patch = add_black_patch
        self.add_prefix = add_prefix
        self.autoregressive = autoregressive
        self.ar_image_block_size = ar_image_block_size
        self.total_block_size = total_block_size
        self.context_mask = context_mask
        self.image_mode = image_mode
        self.sample_mask_at_collator = sample_mask_at_collator
        self.mask_ratio = mask_ratio
        self.span_masking = span_masking
        self.max_span_length = max_span_length

    def mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
        """
        Text masking
        """
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training (with probability `self.text_mask_rate`)
        probability_matrix = torch.full(labels.shape, self.text_mask_rate)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.processor.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()

        inputs[masked_indices] = self.processor.tokenizer.mask_token_id

        return inputs, labels

    def __call__(self, batch):
        new_batch = {"flattened_patches": [], "attention_mask": [], "labels": []}
        if self.autoregressive:
            # Data for autoregressive mode
            new_batch["input_ids"] = []
            if self.ar_image_block_size == 0:
                # Text only
                new_batch = {"input_ids": [], "attention_mask": [], "labels": []}
        if self.sample_mask_at_collator:
            # Sample patch mask in data collator
            new_batch["patch_mask"] = []

        for item in batch:
            if self.autoregressive and self.ar_image_block_size == 0:
                # Autoregressive: text only
                text_tokens = torch.tensor(item["tokens"].astype(np.int64)).long()

                input_ids = torch.cat([torch.tensor([self.processor.tokenizer.bos_token_id]).long(), text_tokens], 0)
                attention_mask = torch.ones(input_ids.shape).long()
                if self.total_block_size is not None:
                    # Truncate
                    input_ids = input_ids[:self.total_block_size]
                    attention_mask = attention_mask[:self.total_block_size]
                new_batch["input_ids"].append(input_ids)
                new_batch["attention_mask"].append(attention_mask)
                labels = input_ids + 0
                if self.context_mask is not None:
                    # Only predict on the non-masked part (mostly for evaluation)
                    labels[:self.context_mask] = -100
                new_batch["labels"].append(labels)
            elif self.autoregressive:
                # Autoregressive with screenshot
                image_tokens = item["tokens"][:self.ar_image_block_size]  # render these as screenshots

                text = self.processor.decode(image_tokens, skip_special_tokens=True)
                if self.replace_new_line:
                    text = text.replace("\n", self.new_line_token)

                if self.add_prefix:
                    text = "Beginning of the sequence: " + text

                image, rendered_text = render_text(text=text, font_size=self.font_size, line_space=self.line_space, width=self.width, height=self.height)

                # In the case where not all text is rendered into the screenshot, we truncate the text
                if self.replace_new_line:
                    _ = rendered_text.replace(self.new_line_token, "\n").rstrip(" ")
                else:
                    _ = rendered_text.rstrip(" ")
                encoded_num_img_tokens = len(self.processor(text=_, add_special_tokens=False)['input_ids'])
                text_tokens = torch.tensor(item["tokens"][min(encoded_num_img_tokens, self.ar_image_block_size):].astype(np.int64)).long()
                encoding = self.processor(images=image, return_tensors="pt", add_special_tokens=True)

                new_batch["flattened_patches"].append(encoding["flattened_patches"][0])
                patch_attention_mask = encoding["attention_mask"][0]

                assert not self.add_black_patch  # not supported (and not needed with </img>)

                # Mask out the attention to ending white patches
                if self.ignore_white_patches:
                    fpatches = new_batch["flattened_patches"][-1][:, 2:]
                    non_white_patches = ((fpatches - fpatches.mean(dim=-1, keepdim=True)) ** 2 < 1e-6).long().sum(-1) != fpatches.shape[-1]
                    reverse_non_white_patches = non_white_patches.flip(-1)
                    non_white_patches = reverse_non_white_patches.nonzero()
                    if len(non_white_patches) == 0:
                        first_white_patch = 0
                    else:
                        first_white_patch = len(reverse_non_white_patches) - non_white_patches[0][0]

                    patch_attention_mask[first_white_patch:] = 0

                # BOS + image + text
                input_ids = torch.cat([torch.tensor([self.processor.tokenizer.bos_token_id]).long(), encoding["image_input_ids"][0], text_tokens], 0)
                attention_mask = torch.ones(input_ids.shape).long()
                patch_mask = input_ids == self.processor.patch_token_id
                attention_mask[patch_mask] = patch_attention_mask.long()
                if self.total_block_size is not None:
                    input_ids = input_ids[:self.total_block_size]
                    attention_mask = attention_mask[:self.total_block_size]
                new_batch["input_ids"].append(input_ids)
                new_batch["attention_mask"].append(attention_mask)
                new_batch["labels"].append(input_ids)

            else:
                # Encoder-decoder (PTP): optionally mask text tokens before rendering
                if self.text_mask_rate > 0:
                    input_ids = torch.tensor(item["tokens"].astype(np.int32)).long().unsqueeze(0)
                    input_ids, labels = self.mask_tokens(input_ids)
                    input_ids = input_ids.squeeze(0)
                    labels = labels.squeeze(0)
                    text = self.processor.decode(input_ids, skip_special_tokens=False)
                else:
                    text = self.processor.decode(item["tokens"], skip_special_tokens=True)

                if self.replace_new_line:
                    text = text.replace("\n", self.new_line_token)

                if self.merge_text_masks and self.text_mask_rate > 0:
                    # Collapse consecutive <mask> tokens into a single one
                    while True:
                        if "<mask><mask>" not in text:
                            break
                        text = text.replace("<mask><mask>", "<mask>")

                if self.add_prefix:
                    text = "Beginning of the sequence: " + text

                image, rendered_text = render_text(text=text, font_size=self.font_size, line_space=self.line_space, width=self.width, height=self.height)
                image = image.convert(self.image_mode)
                image = to_numpy_array(image)
                if self.image_mode != "RGB":
                    image = np.expand_dims(image, -1)  # h, w, 1
                if self.image_mode == "1":
                    image = image.astype(np.float32)  # bool -> float for clf

                if self.rendered_as_target:
                    if self.text_mask_rate > 0:
                        # This is not exact: with merged masks we can only estimate how many label tokens were actually rendered
                        valid_num_tokens = len(self.processor.tokenizer.tokenize(rendered_text))
                        # Account for the merged masks
                        valid_num_tokens = int(valid_num_tokens / (len(self.processor.tokenizer.tokenize(text)) / len(labels)))
                        labels[valid_num_tokens:] = self.processor.tokenizer.pad_token_id
                    else:
                        labels = self.processor.tokenizer(rendered_text, return_tensors="pt", add_special_tokens=False, max_length=self.block_size, padding="max_length", truncation=True)["input_ids"].squeeze()

                encoding = self.processor(images=image, return_tensors="pt", add_special_tokens=True)
                new_batch["flattened_patches"].append(encoding["flattened_patches"][0])
                new_batch["attention_mask"].append(encoding["attention_mask"][0])
                new_batch["labels"].append(labels)

                if self.add_black_patch:
                    # The black end-of-text patch is only added inside the white-patch handling below
                    assert self.ignore_white_patches

                if self.ignore_white_patches:
                    fpatches = new_batch["flattened_patches"][-1][:, 2:]
                    # White patches should have all pixels = 1 (normalized)
                    non_white_patches = (fpatches > 1 - 1e-6).long().sum(-1) != fpatches.shape[-1]
                    reverse_non_white_patches = non_white_patches.flip(-1)
                    non_white_patches = reverse_non_white_patches.nonzero()
                    if len(non_white_patches) == 0:
                        first_white_patch = 0
                    else:
                        first_white_patch = len(reverse_non_white_patches) - non_white_patches[0][0]

                    new_batch["attention_mask"][-1][first_white_patch:] = 0

                    if self.add_black_patch:
                        if first_white_patch == len(reverse_non_white_patches):
                            first_white_patch -= 1  # if there is no white patch, force changing the last one to black

                        black = 0
                        new_batch["flattened_patches"][-1][first_white_patch, 2:] = black
                        new_batch["attention_mask"][-1][first_white_patch] = 1

                if self.sample_mask_at_collator:
                    assert self.span_masking is True  # we are only doing this for span masking
                    seq_length = new_batch["flattened_patches"][-1].shape[0]
                    len_keep = int(seq_length * (1 - self.mask_ratio))
                    span_masking_generator = SpanMaskingGenerator(
                        num_patches=seq_length,
                        num_masking_patches=seq_length - len_keep,
                        max_span_length=self.max_span_length,
                        spacing="span",
                        cumulative_span_weights=[0.2, 0.4, 0.6, 0.8, 0.9, 1],
                    )
                    patch_mask = torch.tensor(span_masking_generator())
                    new_batch["patch_mask"].append(patch_mask)

        for key in new_batch:
            new_batch[key] = torch.stack(new_batch[key])

        return new_batch
```
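
For orientation, here is a rough sketch of how the pieces in this file could fit together in a data pipeline. The processor and every argument value below are illustrative assumptions, not the settings used for the released models: the real values come from the run configs and the training script, and the processor must be a Pix2Struct-style object that exposes a tokenizer and produces `flattened_patches`/`attention_mask`.

```python
from torch.utils.data import DataLoader
from data import NumpyDataset, RenderTextCollator

# Placeholder: the processor is constructed elsewhere in the repo; it must provide
# a tokenizer, `decode`, and image patching that returns "flattened_patches".
processor = ...

dataset = NumpyDataset("data/wikibook_256_opt_tk_train.npy")
collator = RenderTextCollator(
    processor=processor,
    font_size=10,          # illustrative values, not the paper's settings
    line_space=6,
    replace_new_line=True,
    new_line_token="//",
    width=512,
    height=256,
    text_mask_rate=0.25,
    merge_text_masks=True,
    ignore_white_patches=True,
)
loader = DataLoader(dataset, batch_size=8, collate_fn=collator)
batch = next(iter(loader))
# batch contains "flattened_patches", "attention_mask", and "labels" tensors.
```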
