Skip to content

Commit 53135d0

Browse files
authored
Merge pull request #120 from VikParuchuri/dev
Fix bugs with RGB
2 parents d167369 + 80e44dd commit 53135d0

File tree

11 files changed

+370
-28
lines changed

11 files changed

+370
-28
lines changed

benchmark/detection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from surya.benchmark.metrics import precision_recall
88
from surya.benchmark.tesseract import tesseract_parallel
99
from surya.model.detection.segformer import load_model, load_processor
10-
from surya.input.processing import open_pdf, get_page_images
10+
from surya.input.processing import open_pdf, get_page_images, convert_if_not_rgb
1111
from surya.detection import batch_text_detection
1212
from surya.postprocessing.heatmap import draw_polys_on_image
1313
from surya.postprocessing.util import rescale_bbox
@@ -47,7 +47,7 @@ def main():
4747
# These have already been shuffled randomly, so sampling from the start is fine
4848
dataset = datasets.load_dataset(settings.DETECTOR_BENCH_DATASET_NAME, split=f"train[:{args.max}]")
4949
images = list(dataset["image"])
50-
images = [i.convert("RGB") for i in images]
50+
images = convert_if_not_rgb(images)
5151
correct_boxes = []
5252
for i, boxes in enumerate(dataset["bboxes"]):
5353
img_size = images[i].size

benchmark/layout.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from surya.benchmark.metrics import precision_recall
77
from surya.detection import batch_text_detection
88
from surya.model.detection.segformer import load_model, load_processor
9-
from surya.input.processing import open_pdf, get_page_images
9+
from surya.input.processing import open_pdf, get_page_images, convert_if_not_rgb
1010
from surya.layout import batch_layout_detection
1111
from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image
1212
from surya.postprocessing.util import rescale_bbox
@@ -33,7 +33,7 @@ def main():
3333
# These have already been shuffled randomly, so sampling from the start is fine
3434
dataset = datasets.load_dataset(settings.LAYOUT_BENCH_DATASET_NAME, split=f"train[:{args.max}]")
3535
images = list(dataset["image"])
36-
images = [i.convert("RGB") for i in images]
36+
images = convert_if_not_rgb(images)
3737

3838
start = time.time()
3939
line_predictions = batch_text_detection(images, det_model, det_processor)

benchmark/ordering.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import copy
44
import json
55

6+
from surya.input.processing import convert_if_not_rgb
67
from surya.model.ordering.model import load_model
78
from surya.model.ordering.processor import load_processor
89
from surya.ordering import batch_ordering
@@ -29,7 +30,7 @@ def main():
2930
split = f"train[:{args.max}]"
3031
dataset = datasets.load_dataset(settings.ORDER_BENCH_DATASET_NAME, split=split)
3132
images = list(dataset["image"])
32-
images = [i.convert("RGB") for i in images]
33+
images = convert_if_not_rgb(images)
3334
bboxes = list(dataset["bboxes"])
3435

3536
start = time.time()

benchmark/recognition.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55

66
from benchmark.scoring import overlap_score
7+
from surya.input.processing import convert_if_not_rgb
78
from surya.model.recognition.model import load_model as load_recognition_model
89
from surya.model.recognition.processor import load_processor as load_recognition_processor
910
from surya.ocr import run_recognition
@@ -48,7 +49,7 @@ def main():
4849
dataset = dataset.filter(lambda x: x["language"] in langs)
4950

5051
images = list(dataset["image"])
51-
images = [i.convert("RGB") for i in images]
52+
images = convert_if_not_rgb(images)
5253
bboxes = list(dataset["bboxes"])
5354
line_text = list(dataset["text"])
5455
languages = list(dataset["language"])

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "surya-ocr"
3-
version = "0.4.10"
3+
version = "0.4.11"
44
description = "OCR, layout, reading order, and line detection in 90+ languages"
55
authors = ["Vik Paruchuri <[email protected]>"]
66
readme = "README.md"

surya/detection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from surya.model.detection.segformer import SegformerForRegressionMask
88
from surya.postprocessing.heatmap import get_and_clean_boxes
99
from surya.postprocessing.affinity import get_vertical_lines
10-
from surya.input.processing import prepare_image_detection, split_image, get_total_splits
10+
from surya.input.processing import prepare_image_detection, split_image, get_total_splits, convert_if_not_rgb
1111
from surya.schema import TextDetectionResult
1212
from surya.settings import settings
1313
from tqdm import tqdm
@@ -51,7 +51,7 @@ def batch_detection(images: List, model: SegformerForRegressionMask, processor,
5151
all_preds = []
5252
for batch_idx in tqdm(range(len(batches)), desc="Detecting bboxes"):
5353
batch_image_idxs = batches[batch_idx]
54-
batch_images = [images[j].convert("RGB") for j in batch_image_idxs]
54+
batch_images = convert_if_not_rgb([images[j] for j in batch_image_idxs])
5555

5656
split_index = []
5757
split_heights = []

surya/input/processing.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import os
2-
import random
31
from typing import List
42

53
import cv2
@@ -11,6 +9,15 @@
119
from surya.settings import settings
1210

1311

12+
def convert_if_not_rgb(images: List[Image.Image]) -> List[Image.Image]:
    """Return a new list in which every image is in RGB mode.

    Images whose ``mode`` is already ``"RGB"`` are passed through as the
    same object; any other image is replaced by the result of
    ``image.convert("RGB")``. The input list itself is not modified.

    Args:
        images: Images to normalise.

    Returns:
        A list of the same length and order as ``images``.
    """
    # Only call convert() when the mode differs, so images that are already
    # RGB are returned as the same object rather than a converted copy.
    return [
        image if image.mode == "RGB" else image.convert("RGB")
        for image in images
    ]
19+
20+
1421
def get_total_splits(image_size, processor):
1522
img_height = list(image_size)[1]
1623
max_height = settings.DETECTOR_IMAGE_CHUNK_HEIGHT
@@ -48,6 +55,8 @@ def split_image(img, processor):
4855
def prepare_image_detection(img, processor):
4956
new_size = (processor.size["width"], processor.size["height"])
5057

58+
# This double resize is actually necessary for downstream accuracy
59+
img.thumbnail(new_size, Image.Resampling.LANCZOS)
5160
img = img.resize(new_size, Image.Resampling.LANCZOS) # Stretch smaller dimension to fit new size
5261

5362
img = np.asarray(img, dtype=np.uint8)

0 commit comments

Comments
 (0)