Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move deformable detr safe loading code #1055

Merged
merged 6 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 4 additions & 14 deletions lib/sycamore/sycamore/transforms/detr_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Any, BinaryIO, Literal, Union, Optional
from pathlib import Path
from itertools import repeat

import requests
import json
from tenacity import retry, retry_if_exception, wait_exponential, stop_after_delay
import base64
from PIL import Image
import fasteners
from pypdf import PdfReader

from sycamore.data import Element, BoundingBox, ImageElement, TableElement
Expand All @@ -34,7 +32,6 @@
from sycamore.transforms.text_extraction.pdf_miner import PdfMinerExtractor

logger = logging.getLogger(__name__)
_DETR_LOCK_FILE = f"{Path.home()}/.cache/Aryn-Detr.lock"
_VERSION = "0.2024.07.24"


Expand Down Expand Up @@ -683,18 +680,11 @@ def __init__(self, model_name_or_path, device=None, cache: Optional[Cache] = Non
self._model_name_or_path = model_name_or_path
self.cache = cache

from sycamore.utils.pytorch_dir import get_pytorch_build_directory
from transformers import AutoImageProcessor
from sycamore.utils.model_load import load_deformable_detr

with fasteners.InterProcessLock(_DETR_LOCK_FILE):
lockfile = Path(get_pytorch_build_directory("MultiScaleDeformableAttention", False)) / "lock"
lockfile.unlink(missing_ok=True)

from transformers import AutoImageProcessor, DeformableDetrForObjectDetection

LogTime("loading_model", point=True)
with LogTime("load_model", log_start=True):
self.processor = AutoImageProcessor.from_pretrained(model_name_or_path)
self.model = DeformableDetrForObjectDetection.from_pretrained(model_name_or_path).to(self._get_device())
self.processor = AutoImageProcessor.from_pretrained(model_name_or_path)
self.model = load_deformable_detr(model_name_or_path, self._get_device())

# Note: We wrap this in a function so that we can execute on both the leader and the workers
# to account for heterogeneous systems. Currently, if you pass in an explicit device parameter
Expand Down
7 changes: 5 additions & 2 deletions lib/sycamore/sycamore/transforms/table_structure/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,12 @@ def __init__(self, model: str, device=None):
super().__init__(model, device)

def _init_structure_model(self):
    """Load the structure model and store it on ``self.structure_model``.

    The model is loaded exactly once, through ``load_deformable_detr``, which
    is handed ``self.model`` as the model name/path and the result of
    ``self._get_device()`` as the target device. A second, direct
    ``from_pretrained`` load would be overwritten immediately and would not go
    through that shared loader, so it is not performed.
    """
    # Imported lazily so the module can be imported without the loader's dependencies.
    from sycamore.utils.model_load import load_deformable_detr

    self.structure_model = load_deformable_detr(self.model, self._get_device())

def _get_device(self) -> str:
    """Return the device selected by ``choose_device`` for ``self.device``.

    ``detr=True`` is always passed to ``choose_device``.
    """
    requested_device = self.device
    return choose_device(requested_device, detr=True)

def extract(
self, element: TableElement, doc_image: Image.Image, union_tokens=False, apply_thresholds=True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,23 @@ def apply_class_thresholds(bboxes, labels, scores, class_names, class_thresholds
return bboxes, scores, labels


def apply_class_thresholds_or_take_best(bboxes, labels, scores, class_names, class_thresholds, epsilon=0.05):
    """
    Filter out bounding boxes whose confidence is below the threshold of their class,
    relaxing the "table row" and "table column" thresholds so the best candidate survives.

    For "table row" and "table column" the effective threshold is whichever is lower:
    the value in ``class_thresholds`` or the highest score seen for that class minus
    ``epsilon``. The "table" threshold is always set to 0.0.

    Args:
        bboxes: Bounding boxes, parallel to ``labels`` and ``scores``.
        labels: Class ids, used as keys/indices into ``class_names``.
        scores: Confidence scores, parallel to ``labels``.
        class_names: Mapping from class id to class name.
        class_thresholds: Mapping from class name to minimum score. It is not mutated.
        epsilon: Margin subtracted from the best score of a class when relaxing
            that class's threshold.

    Returns:
        Whatever ``apply_class_thresholds`` returns for the adjusted thresholds.
    """
    # Work on a copy so the caller's thresholds (typically module defaults) stay untouched.
    new_class_thresholds = dict(class_thresholds)

    for class_name in ("table row", "table column"):
        # default=None: with no detection of this class there is no "best" score to
        # relax towards; a bare max() would raise ValueError on the empty sequence.
        best_score = max(
            (sc for sc, lbl in zip(scores, labels) if class_names[lbl] == class_name),
            default=None,
        )
        if best_score is None:
            continue
        relaxed = best_score - epsilon
        if relaxed < class_thresholds[class_name]:
            new_class_thresholds[class_name] = relaxed

    new_class_thresholds["table"] = 0.0
    return apply_class_thresholds(bboxes, labels, scores, class_names, new_class_thresholds)


def iob(coords1, coords2) -> float:
    """Return ``BoundingBox.iob`` of the box built from ``coords1`` against the one from ``coords2``.

    Each argument is unpacked positionally into ``BoundingBox``.
    """
    first_box = BoundingBox(*coords1)
    second_box = BoundingBox(*coords2)
    return first_box.iob(second_box)

Expand Down Expand Up @@ -83,7 +100,7 @@ def outputs_to_objects(outputs, img_size, id2label, apply_thresholds: bool = Fal
pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)]

if apply_thresholds:
pred_bboxes, pred_scores, pred_labels = apply_class_thresholds(
pred_bboxes, pred_scores, pred_labels = apply_class_thresholds_or_take_best(
pred_bboxes, pred_labels, pred_scores, id2label, DEFAULT_STRUCTURE_CLASS_THRESHOLDS
)

Expand Down Expand Up @@ -277,20 +294,32 @@ def slot_into_containers(
# If the container starts after the package ends, break
if not _early_exit_vertical and container["bbox"][0] > package["bbox"][2]:
if len(match_scores) == 0:
match_scores.append({"container": container, "container_num": container_num, "score": 0})
match_scores.append(
{"container": container, "container_num": container_num, "score": 0, "score_2": 0}
)
break
elif _early_exit_vertical and container["bbox"][1] > package["bbox"][3]:
if len(match_scores) == 0:
match_scores.append({"container": container, "container_num": container_num, "score": 0})
match_scores.append(
{"container": container, "container_num": container_num, "score": 0, "score_2": 0}
)
break
container_rect = BoundingBox(*container["bbox"])
intersect_area = container_rect.intersect(package_rect).area
overlap_fraction = intersect_area / package_area
match_scores.append({"container": container, "container_num": container_num, "score": overlap_fraction})
opposite_overlap_fraction = intersect_area / (container_rect.area or 1)
match_scores.append(
{
"container": container,
"container_num": container_num,
"score": overlap_fraction,
"score_2": opposite_overlap_fraction,
}
)

# Don't sort if you don't have to
if unique_assignment:
sorted_match_scores = [max(match_scores, key=lambda x: x["score"])]
sorted_match_scores = [max(match_scores, key=lambda x: (x["score"], x["score_2"]))]
else:
sorted_match_scores = sort_objects_by_score(match_scores)

Expand Down Expand Up @@ -320,7 +349,7 @@ def sort_objects_by_score(objects, reverse=True):
sign = -1
else:
sign = 1
return sorted(objects, key=lambda k: sign * k["score"])
return sorted(objects, key=lambda k: (sign * k["score"], sign * k.get("score_2", 0)))


def remove_objects_without_content(page_spans, objects):
Expand Down Expand Up @@ -911,10 +940,10 @@ def objects_to_structures(objects, tokens, class_thresholds):
if len(tables) == 0:
return {}
if len(tables) > 1:
tables.sort(key=lambda x: x["score"], reverse=True)
tables.sort(key=lambda x: BoundingBox(*x["bbox"]).area, reverse=True)
import logging

logging.warning("Got multiple tables in document. Using only the highest-scoring one")
logging.warning("Got multiple tables in document. Using only the biggest one")

table = tables[0]
structure = {}
Expand Down
33 changes: 33 additions & 0 deletions lib/sycamore/sycamore/utils/model_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from sycamore.utils.import_utils import requires_modules
from sycamore.utils.time_trace import LogTime
import fasteners
from pathlib import Path

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from transformers import DeformableDetrForObjectDetection

# Path of the lock file handed to fasteners.InterProcessLock in load_deformable_detr;
# it lives under the current user's home directory.
_DETR_LOCK_FILE = f"{Path.home()}/.cache/Aryn-Detr.lock"


@requires_modules("transformers", "local_inference")
def load_deformable_detr(model_name_or_path, device) -> "DeformableDetrForObjectDetection":
    """Load a deformable DETR model while holding an inter-process lock.

    The lock avoids concurrency issues when the deformable attention kernel is
    JIT-compiled by several processes at once.
    Refactored out of:
    https://github.com/aryn-ai/sycamore/blob/7e6b62639ce9b8f63d56cb35a32837d1c97e711e/lib/sycamore/sycamore/transforms/detr_partitioner.py#L686

    Args:
        model_name_or_path: Passed unchanged to ``from_pretrained``.
        device: Passed unchanged to the loaded model's ``.to()``.

    Returns:
        The loaded model after it has been moved to ``device``.
    """
    from sycamore.utils.pytorch_dir import get_pytorch_build_directory

    with fasteners.InterProcessLock(_DETR_LOCK_FILE):
        # Remove any "lock" file left in the extension's build directory before loading.
        build_dir = get_pytorch_build_directory("MultiScaleDeformableAttention", False)
        stale_lock = Path(build_dir) / "lock"
        stale_lock.unlink(missing_ok=True)

        # Imported only once the lock is held.
        from transformers import DeformableDetrForObjectDetection

        LogTime("loading_model", point=True)
        with LogTime("loading_model", log_start=True):
            pretrained = DeformableDetrForObjectDetection.from_pretrained(model_name_or_path)
            model = pretrained.to(device)
    return model
Loading