app.py
"""
wrapper for DocTR end to end OCR
"""
import argparse
import logging
from concurrent.futures import ThreadPoolExecutor
from math import floor, ceil

import numpy as np
import torch
from doctr.models import ocr_predictor

# Imports needed for Clams and MMIF.
# Non-NLP Clams applications will require AnnotationTypes.
from clams import ClamsApp, Restifier
from lapps.discriminators import Uri
from mmif import Mmif, View, Annotation, Document, AnnotationTypes, DocumentTypes
from mmif.utils import video_document_helper as vdh


def rel_coords_to_abs(coords, width, height):
    """
    Simple conversion from relative coordinates (percentages) to absolute coordinates (pixels).
    Assumes the passed shape is a rectangle represented by its top-left and bottom-right corners,
    and computes floor and ceiling based on the geometry.
    """
    # DocTR geometries are ((xmin, ymin), (xmax, ymax)) fractions of the page, so
    # x values scale by the image width and y values by the image height.
    x1, y1 = coords[0]
    x2, y2 = coords[1]
    return [(floor(x1 * width), floor(y1 * height)), (ceil(x2 * width), ceil(y2 * height))]
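
# A quick worked example of the conversion above (hypothetical numbers): for a
# 640x480 frame, a DocTR geometry of ((0.25, 0.5), (0.75, 0.9)) maps to
# [(160, 240), (480, 432)] in pixels.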


def create_bbox(view: View, coordinates, box_type, time_point):
    bbox = view.new_annotation(AnnotationTypes.BoundingBox)
    bbox.add_property("coordinates", coordinates)
    bbox.add_property("label", box_type)
    bbox.add_property("timePoint", time_point)
    return bbox


def create_alignment(view: View, source, target) -> None:
    alignment = view.new_annotation(AnnotationTypes.Alignment)
    alignment.add_property("source", source)
    alignment.add_property("target", target)


class DoctrWrapper(ClamsApp):

    def __init__(self):
        super().__init__()
        self.reader = ocr_predictor(det_arch='db_resnet50', reco_arch='parseq',
                                    pretrained=True, detect_orientation=True,
                                    paragraph_break=0.035, assume_straight_pages=True)
        if torch.cuda.is_available():
            self.gpu = True
            self.reader = self.reader.cuda().half()
        else:
            self.gpu = False
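
    # Note: .half() switches the predictor to float16, which roughly halves GPU
    # memory use and typically speeds up inference, at a small cost in numeric
    # precision; the CPU path stays in full float32.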

    def _appmetadata(self):
        # using metadata.py
        pass

    class LingUnit(object):
        """
        A thin wrapper for LAPPS linguistic unit annotations that
        represent the different geometric levels of the DocTR OCR output.
        """

        def __init__(self, region: Annotation, document: Document):
            self.region = region
            self.region.add_property("document", document.id)
            self.children = []

        def add_child(self, sentence):
            self.children.append(sentence)

        def collect_targets(self):
            self.region.add_property("targets", [child.region.id for child in self.children])

    class Token:
        """
        Span annotation corresponding to a DocTR Word object.
        Start and end are character offsets into the text document.
        """

        def __init__(self, region: Annotation, document: Document, start: int, end: int):
            self.region = region
            self.region.add_property("document", document.id)
            self.region.add_property("start", start)
            self.region.add_property("end", end)

    def process_timepoint(self, representative: Annotation, new_view: View, video_doc: Document):
        rep_frame_index = vdh.convert(representative.get("timePoint"),
                                      representative.get("timeUnit"), "frame",
                                      video_doc.get("fps"))
        image: np.ndarray = vdh.extract_frames_as_images(video_doc, [rep_frame_index], as_PIL=False)[0]
        result = self.reader([image])
        # Assume only one page, as we are passing one image at a time.
        blocks = result.pages[0].blocks
        text_document: Document = new_view.new_textdocument(result.render())

        h, w = image.shape[:2]
        for block in blocks:
            try:
                self.process_block(block, new_view, text_document, representative, w, h)
            except Exception as e:
                self.logger.error(f"Error processing block: {e}")
                continue
        return text_document, representative
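
    # DocTR structures a page as blocks -> lines -> words; the two methods below
    # mirror that hierarchy, mapping blocks to LAPPS paragraphs, lines to
    # sentences, and words to tokens, each aligned to its bounding box.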

    def process_block(self, block, view, text_document, representative, img_width, img_height):
        paragraph = self.LingUnit(view.new_annotation(at_type=Uri.PARAGRAPH), text_document)
        paragraph_bb = create_bbox(view, rel_coords_to_abs(block.geometry, img_width, img_height),
                                   "text", representative.id)
        create_alignment(view, paragraph.region.id, paragraph_bb.id)
        for line in block.lines:
            try:
                sentence = self.process_line(line, view, text_document, representative, img_width, img_height)
            except Exception as e:
                self.logger.error(f"Error processing line: {e}")
                continue
            paragraph.add_child(sentence)
        paragraph.collect_targets()

    def process_line(self, line, view, text_document, representative, img_width, img_height):
        sentence = self.LingUnit(view.new_annotation(at_type=Uri.SENTENCE), text_document)
        sentence_bb = create_bbox(view, rel_coords_to_abs(line.geometry, img_width, img_height),
                                  "text", representative.id)
        create_alignment(view, sentence.region.id, sentence_bb.id)
        for word in line.words:
            # Skip low-confidence words to cut down on OCR noise.
            if word.confidence > 0.4:
                # Caveat: str.find returns the first occurrence, so a word that
                # repeats in the rendered text gets the offsets of its first match.
                start = text_document.text_value.find(word.value)
                end = start + len(word.value)
                token = self.Token(view.new_annotation(at_type=Uri.TOKEN), text_document, start, end)
                token_bb = create_bbox(view, rel_coords_to_abs(word.geometry, img_width, img_height),
                                       "text", representative.id)
                create_alignment(view, token.region.id, token_bb.id)
                sentence.add_child(token)
        sentence.collect_targets()
        return sentence

    def _annotate(self, mmif: Mmif, **parameters) -> Mmif:
        if self.gpu:
            self.logger.debug("running app on GPU")
        else:
            self.logger.debug("running app on CPU")
        video_doc: Document = mmif.get_documents_by_type(DocumentTypes.VideoDocument)[0]
        input_view: View = mmif.get_views_for_document(video_doc.properties.id)[-1]
        new_view: View = mmif.new_view()
        self.sign_view(new_view, parameters)
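        # Each representative frame is OCR'd on its own worker thread; results
        # are collected via future.result(), which re-raises any exception from
        # a worker so a bad frame is logged and skipped, not fatal to the batch.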
        with ThreadPoolExecutor() as executor:
            futures = []
            for timeframe in input_view.get_annotations(AnnotationTypes.TimeFrame):
                for rep_id in timeframe.get("representatives"):
                    if Mmif.id_delimiter not in rep_id:
                        rep_id = f'{input_view.id}{Mmif.id_delimiter}{rep_id}'
                    representative = mmif[rep_id]
                    futures.append(executor.submit(self.process_timepoint, representative, new_view, video_doc))
            if len(futures) == 0:
                # TODO (krim @ 4/18/24): if "representatives" is not present, process just the middle frame
                pass
            for future in futures:
                try:
                    text_document, representative = future.result()
                    self.logger.debug(text_document.get('text'))
                    create_alignment(new_view, representative.id, text_document.id)
                except Exception as e:
                    self.logger.error(f"Error processing timeframe: {e}")
                    continue
        return mmif


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", action="store", default="5000", help="set port to listen")
    parser.add_argument("--production", action="store_true", help="run gunicorn server")
    # add more arguments as needed
    # parser.add_argument(more_arg...)
    parsed_args = parser.parse_args()

    # create the app instance
    app = DoctrWrapper()
    http_app = Restifier(app, port=int(parsed_args.port))
    # for running the application in production mode
    if parsed_args.production:
        http_app.serve_production()
    # development mode
    else:
        app.logger.setLevel(logging.DEBUG)
        http_app.run()
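
# A minimal sketch of how this app is typically exercised, assuming a standard
# CLAMS HTTP interface and an input.mmif containing a VideoDocument plus a view
# with TimeFrame annotations that carry "representatives":
#
#   python app.py --port 5000
#   curl -X POST -d @input.mmif http://localhost:5000
#
# The response is the same MMIF with a new view holding the rendered text
# documents, paragraph/sentence/token annotations, bounding boxes, and alignments.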