|
| 1 | +import os |
| 2 | +import re |
| 3 | +import torch |
| 4 | +import pickle |
| 5 | +import requests |
| 6 | +import numpy as np |
| 7 | +import pandas as pd |
| 8 | +from tqdm import tqdm |
| 9 | +from dotenv import load_dotenv |
| 10 | +from langchain_community.vectorstores import FAISS |
| 11 | +from langchain_community.embeddings import HuggingFaceBgeEmbeddings |
| 12 | +from langchain.text_splitter import RecursiveCharacterTextSplitter |
| 13 | + |
| 14 | +load_dotenv() |
| 15 | + |
| 16 | + |
| 17 | +def get_token(): |
| 18 | + token_endpoint = 'https://icdaccessmanagement.who.int/connect/token' |
| 19 | + client_id = os.environ['ICD_CLIENT_ID'] |
| 20 | + client_secret = os.environ['ICD_CLIENT_SECRET'] |
| 21 | + scope = 'icdapi_access' |
| 22 | + grant_type = 'client_credentials' |
| 23 | + |
| 24 | + # get the OAUTH2 token |
| 25 | + |
| 26 | + # set data to post |
| 27 | + payload = {'client_id': client_id, |
| 28 | + 'client_secret': client_secret, |
| 29 | + 'scope': scope, |
| 30 | + 'grant_type': grant_type} |
| 31 | + |
| 32 | + # make request |
| 33 | + r = requests.post(token_endpoint, data=payload, verify=False).json() |
| 34 | + token = r['access_token'] |
| 35 | + return token |
| 36 | + |
| 37 | + |
| 38 | +def augment_icd_info(): |
| 39 | + icd_tabulation_df = pd.read_excel('SimpleTabulation-ICD-11-MMS-zh.xlsx') |
| 40 | + leaf_terms = icd_tabulation_df.loc[icd_tabulation_df['isLeaf']==True & pd.notnull(icd_tabulation_df['Foundation URI']), :] |
| 41 | + uris = leaf_terms['Foundation URI'].tolist() |
| 42 | + token = get_token() |
| 43 | + results = dict() |
| 44 | + # with open('icd11.pkl', 'rb') as f: |
| 45 | + # results = pickle.load(f) |
| 46 | + for uri in tqdm(uris): |
| 47 | + if uri in results: |
| 48 | + continue |
| 49 | + for _ in range(3): |
| 50 | + try: |
| 51 | + headers = {'Authorization': 'Bearer ' + token, |
| 52 | + 'Accept': 'application/json', |
| 53 | + 'Accept-Language': 'zh', |
| 54 | + 'API-Version': 'v2'} |
| 55 | + r = requests.get(uri, headers=headers, verify=False) |
| 56 | + data = r.json() |
| 57 | + results[uri] = data |
| 58 | + with open('icd11.pkl', 'wb') as f: |
| 59 | + pickle.dump(results, f) |
| 60 | + break |
| 61 | + except: |
| 62 | + token = get_token() |
| 63 | + icd_tabulation_df['full_name'] = icd_tabulation_df['Foundation URI'].apply(lambda x: details.get(x, {}).get('fullySpecifiedName', {}).get('@value', '')) |
| 64 | + icd_tabulation_df['definition'] = icd_tabulation_df['Foundation URI'].apply(lambda x: details.get(x, {}).get('definition', {}).get('@value', '')) |
| 65 | + icd_tabulation_df['synonym'] = icd_tabulation_df['Foundation URI'].apply(lambda x: '|'.join([i['label']['@value'] for i in details.get(x, {}).get('synonym', [])])) |
| 66 | + icd_tabulation_df.to_excel('FullTabulation-ICD-11-MMS-zh.xlsx', index=False) |
| 67 | + term_df = icd_tabulation_df.loc[icd_tabulation_df['isLeaf']==True, :] |
| 68 | + term_df.to_excel('TermTabulation-ICD-11-MMS-zh.xlsx', index=False) |
| 69 | + |
| 70 | + |
| 71 | +def build_vs(text_list, meta_list, vs_path, chunk_size=500, chunk_overlap=50, batch_size=100): |
| 72 | + device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') |
| 73 | + embeddings = HuggingFaceBgeEmbeddings(model_name='./models/AI-ModelScope/bge-large-zh-v1.5', |
| 74 | + model_kwargs={'device': device}) |
| 75 | + text_splitter = RecursiveCharacterTextSplitter( |
| 76 | + chunk_size=chunk_size, chunk_overlap=chunk_overlap, keep_separator=False) |
| 77 | + docs = text_splitter.create_documents(text_list, metadatas=meta_list) |
| 78 | + text_embeddings = list() |
| 79 | + for i in tqdm(range(int(np.ceil(len(docs) / batch_size))), desc='Embedding'): |
| 80 | + embeds = embeddings.embed_documents([x.page_content for x in docs[i * batch_size: (i + 1) * batch_size]]) |
| 81 | + text_embeddings.append(embeds) |
| 82 | + text_embedding_pairs = list(zip([x.page_content for x in docs], np.concatenate(text_embeddings, axis=0))) |
| 83 | + vector_store = FAISS.from_embeddings(text_embedding_pairs, embeddings, [x.metadata for x in docs]) |
| 84 | + vector_store.save_local(vs_path) |
| 85 | + |
| 86 | + |
| 87 | +def create_kb(): |
| 88 | + term_df = pd.read_excel('TermTabulation-ICD-11-MMS-zh.xlsx') |
| 89 | + term_df.fillna('', inplace=True) |
| 90 | + term_df['names'] = term_df.apply(lambda x: [re.sub('^(\-\s)+', '', x['TitleEN']), |
| 91 | + re.sub('^(\-\s)+', '', x['Title']), |
| 92 | + x['full_name'], |
| 93 | + x['synonym']], |
| 94 | + axis=1) |
| 95 | + term_df['title'] = term_df.apply(lambda x: re.sub('^(\-\s)+', '', x['Title']) if len(x['Title'])>0 else re.sub('^(\-\s)+', '', x['TitleEN']), |
| 96 | + axis=1) |
| 97 | + term_df['names'] = term_df['names'].apply(lambda x: list(set([i for i in '|'.join(x).split('|') if len(i) > 0]))) |
| 98 | + term_df['description'] = term_df.apply(lambda x: [x['definition']] + x['names'] if len(x['definition'])>0 else x['names'], axis=1) |
| 99 | + term_df['meta'] = term_df.apply(lambda x: {'Code': x['Code'], 'Title': x['title']}, axis=1) |
| 100 | + text_list = term_df['title'].tolist() |
| 101 | + meta_list = term_df['meta'].tolist() |
| 102 | + build_vs(text_list, meta_list, './vs/title') |
| 103 | + term_df['names'] = term_df['names'].apply(lambda x: '\n'.join(x)) |
| 104 | + text_list = term_df['names'].tolist() |
| 105 | + build_vs(text_list, meta_list, './vs/names') |
| 106 | + term_df['description'] = term_df['description'].apply(lambda x: '\n'.join(x)) |
| 107 | + text_list = term_df['description'].tolist() |
| 108 | + build_vs(text_list, meta_list, './vs/description') |
| 109 | + |
| 110 | + |
| 111 | +def create_icd10_kb(): |
| 112 | + term_df = pd.read_excel('ICD-10-ICD-O.xlsx') |
| 113 | + term_df = term_df.loc[~term_df['Coding System'].isin(['ICD-O-3行为学编码', 'ICD-O-3组织学等级和分化程度编码']), ['Coding System', 'Code', '释义']] |
| 114 | + term_df['meta'] = term_df.apply(lambda x: {'Coding System': x['Coding System'], |
| 115 | + 'Code': x['Code'], |
| 116 | + '释义': x['释义']}, axis=1) |
| 117 | + icd10_term_df = term_df[term_df['Coding System'].isin(['ICD10', 'ICD10-特殊疾病类别'])] |
| 118 | + icdo3_term_df = term_df[term_df['Coding System'].isin(['ICD-O-3形态学编码', 'ICD-O-3解剖部位编码'])] |
| 119 | + text_list = icd10_term_df['释义'].tolist() |
| 120 | + meta_list = icd10_term_df['meta'].tolist() |
| 121 | + build_vs(text_list, meta_list, './vs/icd10') |
| 122 | + text_list = icdo3_term_df['释义'].tolist() |
| 123 | + meta_list = icdo3_term_df['meta'].tolist() |
| 124 | + build_vs(text_list, meta_list, './vs/icdo3') |
| 125 | + |
| 126 | + |
| 127 | +class SemanticSearch: |
| 128 | + def __init__(self, vs_path): |
| 129 | + device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') |
| 130 | + embeddings = HuggingFaceBgeEmbeddings(model_name='./models/AI-ModelScope/bge-large-zh-v1.5', |
| 131 | + model_kwargs={'device': device}) |
| 132 | + self.vector_store = FAISS.load_local(vs_path, embeddings, allow_dangerous_deserialization=True) |
| 133 | + |
| 134 | + def search(self, question, k=10, titles=None): |
| 135 | + if titles is None: |
| 136 | + related_docs_with_score = self.vector_store.similarity_search_with_score(question, k=k) |
| 137 | + else: |
| 138 | + related_docs_with_score = self.vector_store.similarity_search_with_score( |
| 139 | + question, filter={'title': titles}, k=k, fetch_k=len(self.vector_store.index_to_docstore_id)) |
| 140 | + related_docs = [(doc[0].metadata, doc[0].page_content) for doc in related_docs_with_score] |
| 141 | + return related_docs |
| 142 | + |
| 143 | + |
| 144 | +if __name__ == '__main__': |
| 145 | + # token = get_token() |
| 146 | + # get_entity(token, '257068234') |
| 147 | + # augment_icd_info() |
| 148 | + # create_kb() |
| 149 | + |
| 150 | + # text = '结合免疫组化及前次基因重排检测结果诊断:(肝肿块)淋巴组织增生性病变,考虑为黏膜相关淋巴组织结外边缘区B细胞淋巴瘤,伴肝门部淋巴结转移;慢性肝血吸虫病;慢性胆囊炎。' |
| 151 | + # semantic_search = SemanticSearch('./vs/title') |
| 152 | + # a = semantic_search.search(text) |
| 153 | + # semantic_search = SemanticSearch('./vs/names') |
| 154 | + # b = semantic_search.search(text) |
| 155 | + # semantic_search = SemanticSearch('./vs/description') |
| 156 | + # c = semantic_search.search(text) |
| 157 | + # print() |
| 158 | + |
| 159 | + create_icd10_kb() |
0 commit comments