Skip to content

Commit fa80511

Browse files
committed
DocTR visualizer
1 parent e29ee9a commit fa80511

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

visualize.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import doctr.io
2+
from doctr.models import ocr_predictor
3+
from doctr.io import DocumentFile
4+
import torch
5+
import numpy as np
6+
import matplotlib.pyplot as plt
7+
from doctr.utils.visualization import visualize_page
8+
import argparse
9+
from PIL import Image
10+
import json
11+
import cv2 as cv
12+
13+
14+
class OCR:
15+
def __init__(self, image):
16+
self.reader = ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_mobilenet_v3_large', pretrained=True, detect_orientation=True, paragraph_break=0.015, assume_straight_pages=True).to(torch.device("cuda:0"))
17+
self.image = image
18+
self.results = self.reader(image)
19+
20+
21+
def draw(results: doctr.io.Document, image):
22+
print('drawing')
23+
annotated_img = image.copy()
24+
height, width = results.pages[0].dimensions
25+
print(results.pages[0].blocks)
26+
for block in results.pages[0].blocks:
27+
block_bb = ((int(block.geometry[0][0] * width), int(block.geometry[0][1] * height)),
28+
(int(block.geometry[1][0] * width), int(block.geometry[1][1] * height)))
29+
print(f'drawing block with {block_bb}')
30+
annotated_img = cv.rectangle(annotated_img, block_bb[0], block_bb[1], (255, 0, 0), 3)
31+
for line in block.lines:
32+
line_bb = ((int(line.geometry[0][0] * width ), int(line.geometry[0][1] * height)),
33+
(int(line.geometry[1][0] * width), int(line.geometry[1][1] * height)))
34+
print(f'drawing line with {line_bb}')
35+
annotated_img = cv.rectangle(annotated_img, line_bb[0], line_bb[1], (0, 255, 0), 2)
36+
return annotated_img
37+
38+
39+
if __name__ == '__main__':
40+
parser = argparse.ArgumentParser()
41+
parser.add_argument('image', type=str, help='Path to image')
42+
args = parser.parse_args()
43+
image = DocumentFile.from_images(args.image)
44+
document = np.asarray(Image.open(args.image))
45+
ocr = OCR(image)
46+
with open('ocr_results.json', 'w') as file:
47+
json.dump(ocr.results.export(), file)
48+
cv.imwrite('ocr_draw.jpg', draw(ocr.results, document))
49+
ocr_viz = visualize_page(ocr.results.pages[0].export(), document, words_only=False)
50+
plt.savefig("ocr_visualizer.png")
51+

0 commit comments

Comments
 (0)