Skip to content

Commit 8d8198b

Browse files
fix: DEV-2279: Fix to use get_image_local_path and add docker file (#110)
1 parent 01ca0af commit 8d8198b

File tree

4 files changed

+170
-43
lines changed

4 files changed

+170
-43
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
FROM python:3.8-slim
2+
3+
ENV PYTHONUNBUFFERED=True \
4+
PORT=9090
5+
6+
WORKDIR /app
7+
COPY requirements.txt .
8+
9+
RUN pip install --no-cache-dir -r requirements.txt
10+
11+
COPY . ./
12+
13+
CMD exec gunicorn --bind :$PORT --workers 1 --threads 8 --timeout 0 _wsgi:app
+126
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import os
2+
import argparse
3+
import logging
4+
import logging.config
5+
6+
logging.config.dictConfig({
7+
"version": 1,
8+
"formatters": {
9+
"standard": {
10+
"format": "[%(asctime)s] [%(levelname)s] [%(name)s::%(funcName)s::%(lineno)d] %(message)s"
11+
}
12+
},
13+
"handlers": {
14+
"console": {
15+
"class": "logging.StreamHandler",
16+
"level": "DEBUG",
17+
"stream": "ext://sys.stdout",
18+
"formatter": "standard"
19+
}
20+
},
21+
"root": {
22+
"level": "ERROR",
23+
"handlers": [
24+
"console"
25+
],
26+
"propagate": True
27+
}
28+
})
29+
30+
from label_studio_ml.api import init_app
31+
from tesseract import BBOXOCR
32+
33+
34+
_DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(__file__), 'config.json')
35+
36+
37+
def get_kwargs_from_config(config_path=_DEFAULT_CONFIG_PATH):
38+
if not os.path.exists(config_path):
39+
return dict()
40+
with open(config_path) as f:
41+
config = json.load(f)
42+
assert isinstance(config, dict)
43+
return config
44+
45+
46+
if __name__ == "__main__":
47+
parser = argparse.ArgumentParser(description='Label studio')
48+
parser.add_argument(
49+
'-p', '--port', dest='port', type=int, default=9090,
50+
help='Server port')
51+
parser.add_argument(
52+
'--host', dest='host', type=str, default='0.0.0.0',
53+
help='Server host')
54+
parser.add_argument(
55+
'--kwargs', '--with', dest='kwargs', metavar='KEY=VAL', nargs='+', type=lambda kv: kv.split('='),
56+
help='Additional LabelStudioMLBase model initialization kwargs')
57+
parser.add_argument(
58+
'-d', '--debug', dest='debug', action='store_true',
59+
help='Switch debug mode')
60+
parser.add_argument(
61+
'--log-level', dest='log_level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], default=None,
62+
help='Logging level')
63+
parser.add_argument(
64+
'--model-dir', dest='model_dir', default=os.path.dirname(__file__),
65+
help='Directory where models are stored (relative to the project directory)')
66+
parser.add_argument(
67+
'--check', dest='check', action='store_true',
68+
help='Validate model instance before launching server')
69+
70+
args = parser.parse_args()
71+
72+
# setup logging level
73+
if args.log_level:
74+
logging.root.setLevel(args.log_level)
75+
76+
def isfloat(value):
77+
try:
78+
float(value)
79+
return True
80+
except ValueError:
81+
return False
82+
83+
def parse_kwargs():
84+
param = dict()
85+
for k, v in args.kwargs:
86+
if v.isdigit():
87+
param[k] = int(v)
88+
elif v == 'True' or v == 'true':
89+
param[k] = True
90+
elif v == 'False' or v == 'False':
91+
param[k] = False
92+
elif isfloat(v):
93+
param[k] = float(v)
94+
else:
95+
param[k] = v
96+
return param
97+
98+
kwargs = get_kwargs_from_config()
99+
100+
if args.kwargs:
101+
kwargs.update(parse_kwargs())
102+
103+
if args.check:
104+
print('Check "' + BBOXOCR.__name__ + '" instance creation..')
105+
model = BBOXOCR(**kwargs)
106+
107+
app = init_app(
108+
model_class=BBOXOCR,
109+
model_dir=os.environ.get('MODEL_DIR', args.model_dir),
110+
redis_queue=os.environ.get('RQ_QUEUE_NAME', 'default'),
111+
redis_host=os.environ.get('REDIS_HOST', 'localhost'),
112+
redis_port=os.environ.get('REDIS_PORT', 6379),
113+
**kwargs
114+
)
115+
116+
app.run(host=args.host, port=args.port, debug=args.debug)
117+
118+
else:
119+
# for uWSGI use
120+
app = init_app(
121+
model_class=BBOXOCR,
122+
model_dir=os.environ.get('MODEL_DIR', os.path.dirname(__file__)),
123+
redis_queue=os.environ.get('RQ_QUEUE_NAME', 'default'),
124+
redis_host=os.environ.get('REDIS_HOST', 'localhost'),
125+
redis_port=os.environ.get('REDIS_PORT', 6379)
126+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
version: "3.5"
2+
3+
services:
4+
redis:
5+
image: redis:alpine
6+
container_name: redis
7+
hostname: redis
8+
volumes:
9+
- "./data/redis:/data"
10+
expose:
11+
- 6379
12+
server:
13+
container_name: server
14+
build: .
15+
environment:
16+
- MODEL_DIR=/data/models
17+
- RQ_QUEUE_NAME=default
18+
- REDIS_HOST=redis
19+
- REDIS_PORT=6379
20+
ports:
21+
- 9090:9090
22+
depends_on:
23+
- redis
24+
links:
25+
- redis
26+
volumes:
27+
- "./data/server:/data"
28+
- "./logs:/tmp"

label_studio_ml/examples/tesseract/tesseract.py

+3-43
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,13 @@
1-
21
from PIL import Image
32
import pytesseract as pt
4-
import boto3
53
from label_studio_ml.model import LabelStudioMLBase
6-
import pathlib
7-
import os
4+
from utils import get_image_local_path
85
import logging
96

107
logger = logging.getLogger(__name__)
11-
global OCR_config, aws_credentials
8+
global OCR_config
129
OCR_config = "--psm 6"
13-
aws_credentials = {"aws_access_key_id":"",
14-
"aws_secret_access_key":"",
15-
"aws_session_token":""
16-
}
17-
18-
def split_s3_path(s3_path):
19-
path_parts=s3_path.replace("s3://","").split("/")
20-
bucket=path_parts.pop(0)
21-
key="/".join(path_parts)
22-
return bucket, key
2310

24-
def download_S3_file(img_path_url=None, aws_credentials=None):
25-
"""
26-
download image file from S3 and save in ./tmp.{file_extension}
27-
"""
28-
session = boto3.Session(
29-
aws_access_key_id=aws_credentials["aws_access_key_id"],
30-
aws_secret_access_key=aws_credentials["aws_secret_access_key"],
31-
aws_session_token=aws_credentials["aws_session_token"]
32-
)
33-
#Then use the session to get the resource
34-
# s3 = session.resource('s3')
35-
resource = session.resource('s3')
36-
bucket, key = split_s3_path(img_path_url)
37-
file_extension = pathlib.Path(key).suffix
38-
key_basename = "tmp{}".format(file_extension)
39-
my_bucket = resource.Bucket(bucket)
40-
my_bucket.download_file(key, key_basename)
41-
return key_basename
4211

4312
class BBOXOCR(LabelStudioMLBase):
4413
def __init__(self, **kwargs):
@@ -47,28 +16,22 @@ def __init__(self, **kwargs):
4716
def predict(self, tasks, **kwargs):
4817
# extract task meta data: labels, from_name, to_name and other
4918
task = tasks[0]
50-
# print("task", task)
5119
img_path_url = task["data"]["ocr"]
52-
# print("img_path_url", img_path_url)
5320
context = kwargs.get('context')
54-
# print("context", context)
5521
if context:
5622
if not context["result"]:
5723
return []
5824
result = context.get('result')[0]
59-
# print("result", result)
6025
meta = self._extract_meta({**task, **result})
61-
# print("meta", meta)
6226
x = meta["x"]*meta["original_width"]/100
6327
y = meta["y"]*meta["original_height"]/100
6428
w = meta["width"]*meta["original_width"]/100
6529
h = meta["height"]*meta["original_height"]/100
66-
filepath = download_S3_file(img_path_url, aws_credentials)
30+
filepath = get_image_local_path(img_path_url)
6731
IMG = Image.open(filepath)
6832
result_text = pt.image_to_string(IMG.crop((x,y,x+w,y+h)),
6933
config=OCR_config)
7034
meta["text"] = result_text
71-
# print(meta["text"])
7235
temp = {
7336
"original_width": meta["original_width"],
7437
"original_height": meta["original_height"],
@@ -89,7 +52,6 @@ def predict(self, tasks, **kwargs):
8952
"type": "textarea",
9053
"origin": "manual"
9154
}
92-
# print("temp",temp)
9355
return [{
9456
'result': [result, temp],
9557
'score': 0
@@ -105,8 +67,6 @@ def _extract_meta(task):
10567
meta['from_name'] = task['from_name']
10668
meta['to_name'] = task['to_name']
10769
meta['type'] = task['type']
108-
# meta['text'] = task['value']['text']
109-
# meta['data'] = list(task['data'].values())[0]
11070
meta['x'] = task['value']['x']
11171
meta['y'] = task['value']['y']
11272
meta['width'] = task['value']['width']

0 commit comments

Comments
 (0)