|
| 1 | + |
| 2 | +from PIL import Image |
| 3 | +import pytesseract as pt |
| 4 | +import boto3 |
| 5 | +from label_studio_ml.model import LabelStudioMLBase |
| 6 | +import pathlib |
| 7 | +import os |
| 8 | +import logging |
| 9 | + |
| 10 | +logger = logging.getLogger(__name__) |
| 11 | +global OCR_config, aws_credentials |
| 12 | +OCR_config = "--psm 6" |
| 13 | +aws_credentials = {"aws_access_key_id":"", |
| 14 | + "aws_secret_access_key":"", |
| 15 | + "aws_session_token":"" |
| 16 | + } |
| 17 | + |
| 18 | +def split_s3_path(s3_path): |
| 19 | + path_parts=s3_path.replace("s3://","").split("/") |
| 20 | + bucket=path_parts.pop(0) |
| 21 | + key="/".join(path_parts) |
| 22 | + return bucket, key |
| 23 | + |
| 24 | +def download_S3_file(img_path_url=None, aws_credentials=None): |
| 25 | + """ |
| 26 | + download image file from S3 and save in ./tmp.{file_extension} |
| 27 | + """ |
| 28 | + session = boto3.Session( |
| 29 | + aws_access_key_id=aws_credentials["aws_access_key_id"], |
| 30 | + aws_secret_access_key=aws_credentials["aws_secret_access_key"], |
| 31 | + aws_session_token=aws_credentials["aws_session_token"] |
| 32 | + ) |
| 33 | + #Then use the session to get the resource |
| 34 | + # s3 = session.resource('s3') |
| 35 | + resource = session.resource('s3') |
| 36 | + bucket, key = split_s3_path(img_path_url) |
| 37 | + file_extension = pathlib.Path(key).suffix |
| 38 | + key_basename = "tmp{}".format(file_extension) |
| 39 | + my_bucket = resource.Bucket(bucket) |
| 40 | + my_bucket.download_file(key, key_basename) |
| 41 | + return key_basename |
| 42 | + |
| 43 | +class BBOXOCR(LabelStudioMLBase): |
| 44 | + def __init__(self, **kwargs): |
| 45 | + super(BBOXOCR, self).__init__(**kwargs) |
| 46 | + |
| 47 | + def predict(self, tasks, **kwargs): |
| 48 | + # extract task meta data: labels, from_name, to_name and other |
| 49 | + task = tasks[0] |
| 50 | + # print("task", task) |
| 51 | + img_path_url = task["data"]["ocr"] |
| 52 | + # print("img_path_url", img_path_url) |
| 53 | + context = kwargs.get('context') |
| 54 | + # print("context", context) |
| 55 | + if context: |
| 56 | + if not context["result"]: |
| 57 | + return [] |
| 58 | + result = context.get('result')[0] |
| 59 | + # print("result", result) |
| 60 | + meta = self._extract_meta({**task, **result}) |
| 61 | + # print("meta", meta) |
| 62 | + x = meta["x"]*meta["original_width"]/100 |
| 63 | + y = meta["y"]*meta["original_height"]/100 |
| 64 | + w = meta["width"]*meta["original_width"]/100 |
| 65 | + h = meta["height"]*meta["original_height"]/100 |
| 66 | + filepath = download_S3_file(img_path_url, aws_credentials) |
| 67 | + IMG = Image.open(filepath) |
| 68 | + result_text = pt.image_to_string(IMG.crop((x,y,x+w,y+h)), |
| 69 | + config=OCR_config) |
| 70 | + meta["text"] = result_text |
| 71 | + # print(meta["text"]) |
| 72 | + temp = { |
| 73 | + "original_width": meta["original_width"], |
| 74 | + "original_height": meta["original_height"], |
| 75 | + "image_rotation": 0, |
| 76 | + "value": { |
| 77 | + "x": x/meta["original_width"]*100, |
| 78 | + "y": y/meta["original_height"]*100, |
| 79 | + "width": w/meta["original_width"]*100, |
| 80 | + "height": h/meta["original_height"]*100, |
| 81 | + "rotation": 0, |
| 82 | + "text": [ |
| 83 | + meta["text"] |
| 84 | + ] |
| 85 | + }, |
| 86 | + "id": meta["id"], |
| 87 | + "from_name": "transcription", |
| 88 | + "to_name": meta['to_name'], |
| 89 | + "type": "textarea", |
| 90 | + "origin": "manual" |
| 91 | + } |
| 92 | + # print("temp",temp) |
| 93 | + return [{ |
| 94 | + 'result': [result, temp], |
| 95 | + 'score': 0 |
| 96 | + }] |
| 97 | + else: |
| 98 | + return [] |
| 99 | + |
| 100 | + @staticmethod |
| 101 | + def _extract_meta(task): |
| 102 | + meta = dict() |
| 103 | + if task: |
| 104 | + meta['id'] = task['id'] |
| 105 | + meta['from_name'] = task['from_name'] |
| 106 | + meta['to_name'] = task['to_name'] |
| 107 | + meta['type'] = task['type'] |
| 108 | + # meta['text'] = task['value']['text'] |
| 109 | + # meta['data'] = list(task['data'].values())[0] |
| 110 | + meta['x'] = task['value']['x'] |
| 111 | + meta['y'] = task['value']['y'] |
| 112 | + meta['width'] = task['value']['width'] |
| 113 | + meta['height'] = task['value']['height'] |
| 114 | + meta["original_width"] = task['original_width'] |
| 115 | + meta["original_height"] = task['original_height'] |
| 116 | + return meta |
0 commit comments