from typing import Any, Optional, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset
from transformers.image_utils import to_numpy_array

# Direct Pango/Cairo imports are unused here; rendering goes through
# image_utils.render_text:
# import gi
# gi.require_version('Pango', '1.0')
# gi.require_version('PangoCairo', '1.0')
# from gi.repository import Pango, PangoCairo
# import cairo

from image_utils import render_text
from modeling.span_masking import SpanMaskingGenerator

class NumpyDataset(Dataset):
    """Dataset over a pre-tokenized numpy array of shape [num_sequences, seq_len]."""

    def __init__(self, path, block_size=None):
        self.tokens = np.load(path)
        # Default to the full stored sequence length.
        self.block_size = self.tokens.shape[1] if block_size is None else block_size
        self.font_size = None  # no per-item font size; the collator supplies its own

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        return {"tokens": self.tokens[idx][:self.block_size], "font_size": self.font_size}

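# Usage sketch (hedged: "tokens.npy" is a hypothetical path to a
# [num_sequences, seq_len] array of token ids saved with np.save):
#
#   dataset = NumpyDataset("tokens.npy", block_size=512)
#   item = dataset[0]   # {"tokens": <np.ndarray, len <= 512>, "font_size": None}
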
class RenderTextCollator:
    """Collator that renders pre-tokenized text into screenshot images and
    prepares patch (and optionally autoregressive text) model inputs."""

    def __init__(self,
        processor: object,
        font_size: int,
        line_space: int,
        replace_new_line: bool,
        new_line_token: str,
        width: int,
        height: int,
        block_size: int = 1024,
        rendered_as_target: bool = False,
        patch_width: int = 16,
        patch_height: int = 16,
        text_mask_rate: float = 0,
        merge_text_masks: bool = False,
        ignore_white_patches: bool = False,
        add_black_patch: bool = False,
        add_prefix: bool = False,
        autoregressive: bool = False,
        ar_image_block_size: Optional[int] = None,
        total_block_size: Optional[int] = None,
        context_mask: Optional[int] = None,
        image_mode: str = "RGB",
        sample_mask_at_collator: bool = False,
        mask_ratio: float = 0,
        span_masking: bool = False,
        max_span_length: int = 6,
    ):
        self.processor = processor
        self.font_size = font_size
        self.line_space = line_space
        self.replace_new_line = replace_new_line
        self.new_line_token = new_line_token
        self.width = width
        self.height = height
        self.block_size = block_size
        self.rendered_as_target = rendered_as_target
        self.patch_width = patch_width
        self.patch_height = patch_height
        self.text_mask_rate = text_mask_rate
        self.merge_text_masks = merge_text_masks
        self.ignore_white_patches = ignore_white_patches
        self.add_black_patch = add_black_patch
        self.add_prefix = add_prefix
        self.autoregressive = autoregressive
        self.ar_image_block_size = ar_image_block_size
        self.total_block_size = total_block_size
        self.context_mask = context_mask
        self.image_mode = image_mode
        self.sample_mask_at_collator = sample_mask_at_collator
        self.mask_ratio = mask_ratio
        self.span_masking = span_masking
        self.max_span_length = max_span_length

    def mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
        """BERT-style text masking."""
        labels = inputs.clone()
        # Sample tokens in each sequence for MLM training with probability `self.text_mask_rate`.
        probability_matrix = torch.full(labels.shape, self.text_mask_rate)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.processor.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        # Never mask special tokens.
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()

        inputs[masked_indices] = self.processor.tokenizer.mask_token_id

        return inputs, labels

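    # Toy illustration of mask_tokens (hedged, values are made up): with
    # text_mask_rate=0.25 and inputs = tensor([[5, 6, 7, 8]]), roughly one
    # position per sequence is replaced by tokenizer.mask_token_id, e.g.
    # inputs -> [[5, <mask_id>, 7, 8]], while the returned labels keep the
    # original ids [[5, 6, 7, 8]] for supervision.
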
    def __call__(self, batch):
        new_batch = {"flattened_patches": [], "attention_mask": [], "labels": []}
        if self.autoregressive:
            # Data for autoregressive mode
            new_batch["input_ids"] = []
            if self.ar_image_block_size == 0:
                # Text only
                new_batch = {"input_ids": [], "attention_mask": [], "labels": []}
        if self.sample_mask_at_collator:
            # Sample patch mask in data collator
            new_batch["patch_mask"] = []

        for item in batch:
            if self.autoregressive and self.ar_image_block_size == 0:
                # Autoregressive: text only
                text_tokens = torch.tensor(item["tokens"].astype(np.int64)).long()

                input_ids = torch.cat([torch.tensor([self.processor.tokenizer.bos_token_id]).long(), text_tokens], 0)
                attention_mask = torch.ones(input_ids.shape).long()
                if self.total_block_size is not None:
                    # Truncate
                    input_ids = input_ids[:self.total_block_size]
                    attention_mask = attention_mask[:self.total_block_size]
                new_batch["input_ids"].append(input_ids)
                new_batch["attention_mask"].append(attention_mask)
                labels = input_ids.clone()
                if self.context_mask is not None:
                    # Only predict on the non-masked part (mostly for evaluation)
                    labels[:self.context_mask] = -100
                new_batch["labels"].append(labels)
            elif self.autoregressive:
                # Autoregressive with screenshot
                image_tokens = item["tokens"][:self.ar_image_block_size]  # render these as screenshots

                text = self.processor.decode(image_tokens, skip_special_tokens=True)
                if self.replace_new_line:
                    text = text.replace("\n", self.new_line_token)

                if self.add_prefix:
                    text = "Beginning of the sequence: " + text

                image, rendered_text = render_text(text=text, font_size=self.font_size, line_space=self.line_space, width=self.width, height=self.height)

                # If not all of the text fits into the screenshot, truncate the text part accordingly
                if self.replace_new_line:
                    rendered_plain = rendered_text.replace(self.new_line_token, "\n").rstrip(" ")
                else:
                    rendered_plain = rendered_text.rstrip(" ")
                encoded_num_img_tokens = len(self.processor(text=rendered_plain, add_special_tokens=False)['input_ids'])
                text_tokens = torch.tensor(item["tokens"][min(encoded_num_img_tokens, self.ar_image_block_size):].astype(np.int64)).long()
                encoding = self.processor(images=image, return_tensors="pt", add_special_tokens=True)

                new_batch["flattened_patches"].append(encoding["flattened_patches"][0])
                patch_attention_mask = encoding["attention_mask"][0]

                assert not self.add_black_patch  # not supported (and not needed with </img>)

                # Mask out the attention to trailing white patches
                if self.ignore_white_patches:
                    fpatches = new_batch["flattened_patches"][-1][:, 2:]
                    # A patch counts as blank if all its pixels are (near-)identical
                    non_white_patches = ((fpatches - fpatches.mean(dim=-1, keepdim=True)) ** 2 < 1e-6).long().sum(-1) != fpatches.shape[-1]
                    reverse_non_white_patches = non_white_patches.flip(-1)
                    non_white_patches = reverse_non_white_patches.nonzero()
                    if len(non_white_patches) == 0:
                        first_white_patch = 0
                    else:
                        first_white_patch = len(reverse_non_white_patches) - non_white_patches[0][0]

                    patch_attention_mask[first_white_patch:] = 0

                # BOS + image + text
                input_ids = torch.cat([torch.tensor([self.processor.tokenizer.bos_token_id]).long(), encoding["image_input_ids"][0], text_tokens], 0)
                attention_mask = torch.ones(input_ids.shape).long()
                patch_mask = input_ids == self.processor.patch_token_id
                attention_mask[patch_mask] = patch_attention_mask.long()
                if self.total_block_size is not None:
                    input_ids = input_ids[:self.total_block_size]
                    attention_mask = attention_mask[:self.total_block_size]
                new_batch["input_ids"].append(input_ids)
                new_batch["attention_mask"].append(attention_mask)
                new_batch["labels"].append(input_ids)

            else:
                if self.text_mask_rate > 0:
                    input_ids = torch.tensor(item["tokens"].astype(np.int64)).long().unsqueeze(0)
                    input_ids, labels = self.mask_tokens(input_ids)
                    input_ids = input_ids.squeeze(0)
                    labels = labels.squeeze(0)
                    text = self.processor.decode(input_ids, skip_special_tokens=False)
                else:
                    text = self.processor.decode(item["tokens"], skip_special_tokens=True)

                if self.replace_new_line:
                    text = text.replace("\n", self.new_line_token)

                # Collapse runs of consecutive mask tokens into a single mask
                if self.merge_text_masks and self.text_mask_rate > 0:
                    while "<mask><mask>" in text:
                        text = text.replace("<mask><mask>", "<mask>")

                if self.add_prefix:
                    text = "Beginning of the sequence: " + text

                image, rendered_text = render_text(text=text, font_size=self.font_size, line_space=self.line_space, width=self.width, height=self.height)
                image = image.convert(self.image_mode)
                image = to_numpy_array(image)
                if self.image_mode != "RGB":
                    image = np.expand_dims(image, -1)  # h, w, 1
                if self.image_mode == "1":
                    image = image.astype(np.float32)  # mode "1" yields booleans; cast to float

                if self.rendered_as_target:
                    if self.text_mask_rate > 0:
                        # This is approximate: with merged masks we can only estimate how much text made it into the labels
                        valid_num_tokens = len(self.processor.tokenizer.tokenize(rendered_text))
                        # Rescale to account for the merged masks
                        valid_num_tokens = int(valid_num_tokens / (len(self.processor.tokenizer.tokenize(text)) / len(labels)))
                        labels[valid_num_tokens:] = self.processor.tokenizer.pad_token_id
                    else:
                        labels = self.processor.tokenizer(rendered_text, return_tensors="pt", add_special_tokens=False, max_length=self.block_size, padding="max_length", truncation=True)["input_ids"].squeeze()

                encoding = self.processor(images=image, return_tensors="pt", add_special_tokens=True)
                new_batch["flattened_patches"].append(encoding["flattened_patches"][0])
                new_batch["attention_mask"].append(encoding["attention_mask"][0])
                # NOTE: assumes text_mask_rate > 0 or rendered_as_target; otherwise labels is undefined here
                new_batch["labels"].append(labels)

                if self.add_black_patch:
                    # The black end-of-text patch is placed via the white-patch scan below
                    assert self.ignore_white_patches

                if self.ignore_white_patches:
                    fpatches = new_batch["flattened_patches"][-1][:, 2:]
                    # White patches should have all pixels = 1 (normalized)
                    non_white_patches = (fpatches > 1 - 1e-6).long().sum(-1) != fpatches.shape[-1]
                    reverse_non_white_patches = non_white_patches.flip(-1)
                    non_white_patches = reverse_non_white_patches.nonzero()
                    if len(non_white_patches) == 0:
                        first_white_patch = 0
                    else:
                        first_white_patch = len(reverse_non_white_patches) - non_white_patches[0][0]

                    new_batch["attention_mask"][-1][first_white_patch:] = 0

                    if self.add_black_patch:
                        if first_white_patch == len(reverse_non_white_patches):
                            first_white_patch -= 1  # if there is no white patch, force changing the last one to black

                        black = 0
                        new_batch["flattened_patches"][-1][first_white_patch, 2:] = black
                        new_batch["attention_mask"][-1][first_white_patch] = 1

| 261 | + |
| 262 | + if self.sample_mask_at_collator: |
| 263 | + assert self.span_masking is True # we are only doing this for span masking |
| 264 | + seq_length = new_batch["flattened_patches"][-1].shape[0] |
| 265 | + len_keep = int(seq_length * (1 - self.mask_ratio)) |
| 266 | + span_masking_generator = SpanMaskingGenerator( |
| 267 | + num_patches=seq_length, |
| 268 | + num_masking_patches=seq_length-len_keep, |
| 269 | + max_span_length=self.max_span_length, |
| 270 | + spacing="span", |
| 271 | + cumulative_span_weights=[0.2,0.4,0.6,0.8,0.9,1] |
| 272 | + ) |
| 273 | + patch_mask = torch.tensor(span_masking_generator()) |
| 274 | + new_batch["patch_mask"].append(patch_mask) |
| 275 | + |
        for key in new_batch:
            new_batch[key] = torch.stack(new_batch[key])

        return new_batch
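

# A minimal end-to-end sketch, not part of the module proper. Hedged
# assumptions: "tokens.npy" is a hypothetical pre-tokenized file, and
# `processor` must be the project's screenshot processor (exposing
# `tokenizer`, `decode`, and patch encoding as used above); it is left as a
# placeholder here, and the collator arguments are illustrative values only.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    processor = ...  # supply the project's processor here
    dataset = NumpyDataset("tokens.npy", block_size=1024)
    collator = RenderTextCollator(
        processor=processor,
        font_size=10,
        line_space=6,
        replace_new_line=True,
        new_line_token="//",
        width=512,
        height=256,
        rendered_as_target=True,
        ignore_white_patches=True,
    )
    loader = DataLoader(dataset, batch_size=8, collate_fn=collator)
    batch = next(iter(loader))
    # Expected keys: flattened_patches [B, N, 2 + patch_h * patch_w * C],
    # attention_mask [B, N], labels [B, block_size]
    print({k: v.shape for k, v in batch.items()})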