From 0f92e7a3e51eaab1f552a244f416c7331d87f0c2 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Tue, 17 Dec 2024 16:10:43 -0800 Subject: [PATCH] move deformable detr safe loading code (#1055) * factor around deformable detr loading/lockfile management for use with deformable table extractor Signed-off-by: Henry Lindeman * remove unused global variable Signed-off-by: Henry Lindeman * move .to(device) iniside the lock Signed-off-by: Henry Lindeman * jitpick Signed-off-by: Henry Lindeman * set deformable table extractor choose_device detr=True Signed-off-by: Henry Lindeman * Misc table transformers post-processing (#1077) * misc postprocessing tweaks Signed-off-by: Henry Lindeman * typo Signed-off-by: Henry Lindeman --------- Signed-off-by: Henry Lindeman --------- Signed-off-by: Henry Lindeman --- .../sycamore/transforms/detr_partitioner.py | 18 ++------ .../transforms/table_structure/extract.py | 7 ++- .../table_structure/table_transformers.py | 45 +++++++++++++++---- lib/sycamore/sycamore/utils/model_load.py | 33 ++++++++++++++ 4 files changed, 79 insertions(+), 24 deletions(-) create mode 100644 lib/sycamore/sycamore/utils/model_load.py diff --git a/lib/sycamore/sycamore/transforms/detr_partitioner.py b/lib/sycamore/sycamore/transforms/detr_partitioner.py index 8b6eccceb..cd00c01c1 100644 --- a/lib/sycamore/sycamore/transforms/detr_partitioner.py +++ b/lib/sycamore/sycamore/transforms/detr_partitioner.py @@ -6,7 +6,6 @@ from abc import ABC, abstractmethod from collections.abc import Mapping from typing import Any, BinaryIO, Literal, Union, Optional -from pathlib import Path from itertools import repeat import requests @@ -14,7 +13,6 @@ from tenacity import retry, retry_if_exception, wait_exponential, stop_after_delay import base64 from PIL import Image -import fasteners from pypdf import PdfReader from sycamore.data import Element, BoundingBox, ImageElement, TableElement @@ -34,7 +32,6 @@ from sycamore.transforms.text_extraction.pdf_miner import PdfMinerExtractor logger = logging.getLogger(__name__) -_DETR_LOCK_FILE = f"{Path.home()}/.cache/Aryn-Detr.lock" _VERSION = "0.2024.07.24" @@ -688,18 +685,11 @@ def __init__(self, model_name_or_path, device=None, cache: Optional[Cache] = Non self._model_name_or_path = model_name_or_path self.cache = cache - from sycamore.utils.pytorch_dir import get_pytorch_build_directory + from transformers import AutoImageProcessor + from sycamore.utils.model_load import load_deformable_detr - with fasteners.InterProcessLock(_DETR_LOCK_FILE): - lockfile = Path(get_pytorch_build_directory("MultiScaleDeformableAttention", False)) / "lock" - lockfile.unlink(missing_ok=True) - - from transformers import AutoImageProcessor, DeformableDetrForObjectDetection - - LogTime("loading_model", point=True) - with LogTime("load_model", log_start=True): - self.processor = AutoImageProcessor.from_pretrained(model_name_or_path) - self.model = DeformableDetrForObjectDetection.from_pretrained(model_name_or_path).to(self._get_device()) + self.processor = AutoImageProcessor.from_pretrained(model_name_or_path) + self.model = load_deformable_detr(model_name_or_path, self._get_device()) # Note: We wrap this in a function so that we can execute on both the leader and the workers # to account for heterogeneous systems. Currently, if you pass in an explicit device parameter diff --git a/lib/sycamore/sycamore/transforms/table_structure/extract.py b/lib/sycamore/sycamore/transforms/table_structure/extract.py index bde614ce7..926695b1e 100644 --- a/lib/sycamore/sycamore/transforms/table_structure/extract.py +++ b/lib/sycamore/sycamore/transforms/table_structure/extract.py @@ -205,9 +205,12 @@ def __init__(self, model: str, device=None): super().__init__(model, device) def _init_structure_model(self): - from transformers import DeformableDetrForObjectDetection + from sycamore.utils.model_load import load_deformable_detr - self.structure_model = DeformableDetrForObjectDetection.from_pretrained(self.model).to(self._get_device()) + self.structure_model = load_deformable_detr(self.model, self._get_device()) + + def _get_device(self) -> str: + return choose_device(self.device, detr=True) def extract( self, element: TableElement, doc_image: Image.Image, union_tokens=False, apply_thresholds=True diff --git a/lib/sycamore/sycamore/transforms/table_structure/table_transformers.py b/lib/sycamore/sycamore/transforms/table_structure/table_transformers.py index 6573884cf..881d7d88d 100644 --- a/lib/sycamore/sycamore/transforms/table_structure/table_transformers.py +++ b/lib/sycamore/sycamore/transforms/table_structure/table_transformers.py @@ -53,6 +53,23 @@ def apply_class_thresholds(bboxes, labels, scores, class_names, class_thresholds return bboxes, scores, labels +def apply_class_thresholds_or_take_best(bboxes, labels, scores, class_names, class_thresholds, epsilon=0.05): + """ + Filter out bounding boxes whose confidence is below the confidence threshold for its + associated class threshold, defining the threshold as whichever is lower between what + is written in the class_thresholds dict and the highest score for the class minus epsilon + """ + new_class_thresholds = {k: v for k, v in class_thresholds.items()} + max_row_score = max(sc for (sc, lbl) in zip(scores, labels) if class_names[lbl] == "table row") + max_col_score = max(sc for (sc, lbl) in zip(scores, labels) if class_names[lbl] == "table column") + if max_row_score - epsilon < class_thresholds["table row"]: + new_class_thresholds["table row"] = max_row_score - epsilon + if max_col_score - epsilon < class_thresholds["table column"]: + new_class_thresholds["table column"] = max_col_score - epsilon + new_class_thresholds["table"] = 0.0 + return apply_class_thresholds(bboxes, labels, scores, class_names, new_class_thresholds) + + def iob(coords1, coords2) -> float: return BoundingBox(*coords1).iob(BoundingBox(*coords2)) @@ -83,7 +100,7 @@ def outputs_to_objects(outputs, img_size, id2label, apply_thresholds: bool = Fal pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)] if apply_thresholds: - pred_bboxes, pred_scores, pred_labels = apply_class_thresholds( + pred_bboxes, pred_scores, pred_labels = apply_class_thresholds_or_take_best( pred_bboxes, pred_labels, pred_scores, id2label, DEFAULT_STRUCTURE_CLASS_THRESHOLDS ) @@ -287,20 +304,32 @@ def slot_into_containers( # If the container starts after the package ends, break if not _early_exit_vertical and container["bbox"][0] > package["bbox"][2]: if len(match_scores) == 0: - match_scores.append({"container": container, "container_num": container_num, "score": 0}) + match_scores.append( + {"container": container, "container_num": container_num, "score": 0, "score_2": 0} + ) break elif _early_exit_vertical and container["bbox"][1] > package["bbox"][3]: if len(match_scores) == 0: - match_scores.append({"container": container, "container_num": container_num, "score": 0}) + match_scores.append( + {"container": container, "container_num": container_num, "score": 0, "score_2": 0} + ) break container_rect = BoundingBox(*container["bbox"]) intersect_area = container_rect.intersect(package_rect).area overlap_fraction = intersect_area / package_area - match_scores.append({"container": container, "container_num": container_num, "score": overlap_fraction}) + opposite_overlap_fraction = intersect_area / (container_rect.area or 1) + match_scores.append( + { + "container": container, + "container_num": container_num, + "score": overlap_fraction, + "score_2": opposite_overlap_fraction, + } + ) # Don't sort if you don't have to if unique_assignment: - sorted_match_scores = [max(match_scores, key=lambda x: x["score"])] + sorted_match_scores = [max(match_scores, key=lambda x: (x["score"], x["score_2"]))] else: sorted_match_scores = sort_objects_by_score(match_scores) @@ -330,7 +359,7 @@ def sort_objects_by_score(objects, reverse=True): sign = -1 else: sign = 1 - return sorted(objects, key=lambda k: sign * k["score"]) + return sorted(objects, key=lambda k: (sign * k["score"], sign * k.get("score_2", 0))) def remove_objects_without_content(page_spans, objects): @@ -921,10 +950,10 @@ def objects_to_structures(objects, tokens, class_thresholds): if len(tables) == 0: return {} if len(tables) > 1: - tables.sort(key=lambda x: x["score"], reverse=True) + tables.sort(key=lambda x: BoundingBox(*x["bbox"]).area, reverse=True) import logging - logging.warning("Got multiple tables in document. Using only the highest-scoring one") + logging.warning("Got multiple tables in document. Using only the biggest one") table = tables[0] structure = {} diff --git a/lib/sycamore/sycamore/utils/model_load.py b/lib/sycamore/sycamore/utils/model_load.py new file mode 100644 index 000000000..e8d05861e --- /dev/null +++ b/lib/sycamore/sycamore/utils/model_load.py @@ -0,0 +1,33 @@ +from sycamore.utils.import_utils import requires_modules +from sycamore.utils.time_trace import LogTime +import fasteners +from pathlib import Path + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from transformers import DeformableDetrForObjectDetection + +_DETR_LOCK_FILE = f"{Path.home()}/.cache/Aryn-Detr.lock" + + +@requires_modules("transformers", "local_inference") +def load_deformable_detr(model_name_or_path, device) -> "DeformableDetrForObjectDetection": + """Load deformable detr without getting concurrency issues in + jitc-ing the deformable attention kernel. + + Refactored out of: + https://github.com/aryn-ai/sycamore/blob/7e6b62639ce9b8f63d56cb35a32837d1c97e711e/lib/sycamore/sycamore/transforms/detr_partitioner.py#L686 + """ + from sycamore.utils.pytorch_dir import get_pytorch_build_directory + + with fasteners.InterProcessLock(_DETR_LOCK_FILE): + lockfile = Path(get_pytorch_build_directory("MultiScaleDeformableAttention", False)) / "lock" + lockfile.unlink(missing_ok=True) + + from transformers import DeformableDetrForObjectDetection + + LogTime("loading_model", point=True) + with LogTime("loading_model", log_start=True): + model = DeformableDetrForObjectDetection.from_pretrained(model_name_or_path).to(device) + return model