# -*- coding: utf-8 -*-
import sys
import os
sys.path.append('./deep_text_recognition_benchmark')
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "./i-can-read-379204-ce1c5c2f12f5.json"

import uvicorn
import torch
import pickle
import configparser
import requests
import torch.backends.cudnn as cudnn
import torch.utils.data
import torch.nn.functional as F
from fastapi import FastAPI, UploadFile
from transformModel import transform_data
from preprocessImage import crop_image, preprocess_image
from deep_text_recognition_benchmark import demo
from deep_text_recognition_benchmark.model import Model
from deep_text_recognition_benchmark.utils import CTCLabelConverter, AttnLabelConverter
from deep_text_recognition_benchmark.dataset import RawDataset, AlignCollate
# from deep_text_recognition_benchmark.modules import transformation

# NOTE: FastAPI has no max_request_size parameter; upload size limits must be
# enforced by the ASGI server or a reverse proxy in front of it.
app = FastAPI()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_path = './saved_models/pretrained/best_accuracy.pth'

# Directory where cropped word images are written before recognition.
dir_name = "words"
if not os.path.exists(dir_name):
    os.mkdir(dir_name)


class Options:
    """Inference options mirroring the argparse flags of
    deep-text-recognition-benchmark's demo.py."""

    def __init__(self):
        self.num_fiducial = 20
        self.batch_max_length = 25
        self.imgH = 32
        self.imgW = 100
        self.rgb = False
        self.input_channel = 1
        self.output_channel = 512
        self.hidden_size = 256
        # Recognizable character set: the Korean syllables the model was
        # trained on, plus basic punctuation.
        self.character = '가각간갇갈감갑값갓강갖같갚갛개객걀걔거걱건걷걸검겁것겉게겨격겪견결겹경곁계고곡곤곧골곰곱곳공과관광괜괴굉교구국군굳굴굵굶굽궁권귀귓규균귤그극근글긁금급긋긍기긴길김깅깊까깍깎깐깔깜깝깡깥깨꺼꺾껌껍껏껑께껴꼬꼭꼴꼼꼽꽂꽃꽉꽤꾸꾼꿀꿈뀌끄끈끊끌끓끔끗끝끼낌나낙낚난날낡남납낫낭낮낯낱낳내냄냇냉냐냥너넉넌널넓넘넣네넥넷녀녁년념녕노녹논놀놈농높놓놔뇌뇨누눈눕뉘뉴늄느늑는늘늙능늦늬니닐님다닥닦단닫달닭닮담답닷당닿대댁댐댓더덕던덜덟덤덥덧덩덮데델도독돈돌돕돗동돼되된두둑둘둠둡둥뒤뒷드득든듣들듬듭듯등디딩딪따딱딴딸땀땅때땜떠떡떤떨떻떼또똑뚜뚫뚱뛰뜨뜩뜯뜰뜻띄라락란람랍랑랗래랜램랫략량러럭런럴럼럽럿렁렇레렉렌려력련렬렵령례로록론롬롭롯료루룩룹룻뤄류륙률륭르른름릇릎리릭린림립릿링마막만많말맑맘맙맛망맞맡맣매맥맨맵맺머먹먼멀멈멋멍멎메멘멩며면멸명몇모목몬몰몸몹못몽묘무묵묶문묻물뭄뭇뭐뭘뭣므미민믿밀밉밌및밑바박밖반받발밝밟밤밥방밭배백뱀뱃뱉버번벌범법벗베벤벨벼벽변별볍병볕보복볶본볼봄봇봉뵈뵙부북분불붉붐붓붕붙뷰브븐블비빌빔빗빚빛빠빡빨빵빼뺏뺨뻐뻔뻗뼈뼉뽑뿌뿐쁘쁨사삭산살삶삼삿상새색샌생샤서석섞선설섬섭섯성세섹센셈셋셔션소속손솔솜솟송솥쇄쇠쇼수숙순숟술숨숫숭숲쉬쉰쉽슈스슨슬슴습슷승시식신싣실싫심십싯싱싶싸싹싼쌀쌍쌓써썩썰썹쎄쏘쏟쑤쓰쓴쓸씀씌씨씩씬씹씻아악안앉않알앓암압앗앙앞애액앨야약얀얄얇양얕얗얘어억언얹얻얼엄업없엇엉엊엌엎에엔엘여역연열엷염엽엿영옆예옛오옥온올옮옳옷옹와완왕왜왠외왼요욕용우욱운울움웃웅워원월웨웬위윗유육율으윽은을음응의이익인일읽잃임입잇있잊잎자작잔잖잘잠잡잣장잦재쟁쟤저적전절젊점접젓정젖제젠젯져조족존졸좀좁종좋좌죄주죽준줄줌줍중쥐즈즉즌즐즘증지직진질짐집짓징짙짚짜짝짧째쨌쩌쩍쩐쩔쩜쪽쫓쭈쭉찌찍찢차착찬찮찰참찻창찾채책챔챙처척천철첩첫청체쳐초촉촌촛총촬최추축춘출춤춥춧충취츠측츰층치칙친칠침칫칭카칸칼캄캐캠커컨컬컴컵컷케켓켜코콘콜콤콩쾌쿄쿠퀴크큰클큼키킬타탁탄탈탑탓탕태택터턱턴털텅테텍텔템토톤톨톱통퇴투툴툼퉁튀튜트특튼튿틀틈티틱팀팅파팎판팔팝패팩팬퍼퍽페펜펴편펼평폐포폭폰표푸푹풀품풍퓨프플픔피픽필핏핑하학한할함합항해핵핸햄햇행향허헌험헤헬혀현혈협형혜호혹혼홀홈홉홍화확환활황회획횟횡효후훈훌훨휘휴흉흐흑흔흘흙흡흥흩희흰히힘?!.,()'
        self.Transformation = 'TPS'
        self.FeatureExtraction = 'ResNet'
        self.SequenceModeling = 'BiLSTM'
        self.Prediction = 'CTC'
        self.saved_model = model_path
        self.image_folder = "./words"
        self.workers = 0
        self.batch_size = 64
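
# The Options above mirror a deep-text-recognition-benchmark demo.py run
# along these lines (illustrative invocation, paths assumed):
#   python demo.py --Transformation TPS --FeatureExtraction ResNet \
#       --SequenceModeling BiLSTM --Prediction CTC \
#       --image_folder ./words --saved_model ./saved_models/pretrained/best_accuracy.pth
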
config = configparser.ConfigParser()
config.read('config.ini')

host = config.get('server', 'host')
port = config.getint('server', 'port')
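
# Expected config.ini layout (assumed from the keys read above; the values
# shown are placeholders):
#   [server]
#   host = 0.0.0.0
#   port = 8000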


@app.get("/")
async def root():
    return {"message": f"Server running on {host}:{port}"}


@app.post('/api/v1/menu/extract')
async def extract_text(file: UploadFile):
    results = []
    # Preprocess the uploaded menu image and crop it into per-word images
    # under ./words (helpers from the local preprocessImage module).
    menu = preprocess_image(file)
    crop_image(menu)

    opt = Options()

    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
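    # CTC decoding collapses repeats and blank tokens, while the attention
    # decoder emits an explicit '[s]' end-of-sequence token (pruned below).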

    if opt.rgb:
        opt.input_channel = 3

    model = Model(opt)
    # model = torch.nn.DataParallel(model).to(device)
    # model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # transform_data is a local helper (transformModel.py); it is assumed to
    # load the transformation weights saved at the given path into the model.
    transform_data(model, './saved_models/cafe/transform.pth')

    # Load the pickled state dict.
    with open('./saved_models/new_model.pkl', 'rb') as f:
        state_dict = pickle.load(f)

    # Alternative: load from a plain .pth checkpoint. Checkpoints saved from a
    # DataParallel-wrapped model carry a 'module.' prefix on every key, which
    # has to be stripped before loading into a bare model:
    # torch.save(model.state_dict(), new_model_path)
    # model.load_state_dict(torch.load(new_model_path))
    #
    # state_dict = torch.load(new_model_path)
    # new_state_dict = {}
    # for k, v in state_dict.items():
    #     name = k.replace('module.', '')  # remove 'module.' prefix from key name
    #     new_state_dict[name] = v
    # model.load_state_dict(new_state_dict)

    # Truncate the checkpoint's Prediction.bias so it matches the current
    # model's output size (the character sets may differ).
    if 'Prediction.bias' in model.state_dict():
        new_bias_size = model.state_dict()['Prediction.bias'].size()
        old_bias_size = state_dict['Prediction.bias'].size()

        if old_bias_size != new_bias_size:
            state_dict['Prediction.bias'] = state_dict['Prediction.bias'][:new_bias_size[0]]
    model.load_state_dict(state_dict, strict=False)
    model.to(device)  # keep the model on the same device as the input tensors
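    # Note: strict=False only tolerates missing or unexpected keys; a tensor
    # present under the same key must still match in shape, which is why the
    # bias is truncated above.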

    # Batch the cropped word images; AlignCollate resizes each crop to
    # imgH x imgW and normalizes it before stacking.
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    # predict
    model.eval()

    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)

            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, preds_size)

            else:
                preds = model(image, text_for_pred, is_train=False)
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
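            # preds_max_prob holds per-step character confidences; the
            # upstream demo.py reduces them to a word-level score via
            # pred_max_prob.cumprod(dim=0)[-1] if one is needed.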

            for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
                if 'Attn' in opt.Prediction:
                    pred_EOS = pred.find('[s]')
                    pred = pred[:pred_EOS]  # prune after the "end of sentence" token ([s])

                result = pred
                # print(f'{img_name:25s}\t{pred:25s}')
                results.append(result)
    print(results)
    return results


if __name__ == '__main__':
    uvicorn.run(app, host=host, port=port)
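
# Example request against the running server (illustrative; host/port come
# from config.ini and menu.jpg is a placeholder file name):
#   curl -X POST -F "file=@menu.jpg" http://<host>:<port>/api/v1/menu/extract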