Skip to content

Commit f60fee6

Browse files
committed
build icd kb
1 parent 7d035d7 commit f60fee6

File tree

1 file changed

+159
-0
lines changed

1 file changed

+159
-0
lines changed

icd.py

+159
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import os
2+
import re
3+
import torch
4+
import pickle
5+
import requests
6+
import numpy as np
7+
import pandas as pd
8+
from tqdm import tqdm
9+
from dotenv import load_dotenv
10+
from langchain_community.vectorstores import FAISS
11+
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
12+
from langchain.text_splitter import RecursiveCharacterTextSplitter
13+
14+
# Load variables from a local .env file into the process environment at
# import time; get_token() reads ICD_CLIENT_ID and ICD_CLIENT_SECRET from
# os.environ.
load_dotenv()
15+
16+
17+
def get_token():
    """Request an OAuth2 access token for the WHO ICD API.

    Performs a client-credentials grant against the WHO token endpoint using
    the ``ICD_CLIENT_ID`` and ``ICD_CLIENT_SECRET`` environment variables.

    Returns:
        The ``access_token`` value from the endpoint's JSON response.

    Raises:
        KeyError: If a credential environment variable is missing or the
            response contains no ``access_token``.
        requests.RequestException: If the request fails, times out, or the
            endpoint answers with an HTTP error status.
    """
    token_endpoint = 'https://icdaccessmanagement.who.int/connect/token'
    payload = {
        'client_id': os.environ['ICD_CLIENT_ID'],
        'client_secret': os.environ['ICD_CLIENT_SECRET'],
        'scope': 'icdapi_access',
        'grant_type': 'client_credentials',
    }

    # NOTE(review): verify=False disables TLS certificate validation while the
    # client secret is on the wire. Kept as in the original; confirm whether
    # it is still required in the deployment environment.
    # The timeout keeps a stalled connection from hanging the caller forever.
    response = requests.post(token_endpoint, data=payload, verify=False, timeout=30)
    # Surface authentication/server errors as an HTTPError instead of an
    # opaque KeyError on the missing 'access_token' field.
    response.raise_for_status()
    return response.json()['access_token']
36+
37+
38+
def augment_icd_info():
    """Enrich the ICD-11 tabulation with details fetched from the WHO ICD API.

    Reads ``SimpleTabulation-ICD-11-MMS-zh.xlsx``, fetches the entity JSON for
    every leaf row that has a ``Foundation URI``, and adds ``full_name``,
    ``definition`` and ``synonym`` (``|``-joined) columns.

    Side effects:
        * Writes the fetched entities to ``icd11.pkl`` (periodically and once
          more when the crawl ends or is interrupted).
        * Writes ``FullTabulation-ICD-11-MMS-zh.xlsx`` (all rows) and
          ``TermTabulation-ICD-11-MMS-zh.xlsx`` (leaf rows only).

    Each URI is tried up to three times; a fresh token is requested after
    every failed attempt. URIs that still fail are skipped and reported, and
    their added columns are left empty.
    """
    max_attempts = 3
    checkpoint_every = 200
    checkpoint_path = 'icd11.pkl'

    def save_checkpoint(fetched):
        with open(checkpoint_path, 'wb') as f:
            pickle.dump(fetched, f)

    icd_tabulation_df = pd.read_excel('SimpleTabulation-ICD-11-MMS-zh.xlsx')

    # Build the two conditions separately: `&` binds tighter than `==`, so the
    # unparenthesised form compared isLeaf against the not-null mask instead
    # of combining the two filters.
    is_leaf = icd_tabulation_df['isLeaf'] == True  # noqa: E712 (element-wise)
    has_uri = icd_tabulation_df['Foundation URI'].notna()
    uris = icd_tabulation_df.loc[is_leaf & has_uri, 'Foundation URI'].tolist()

    token = get_token()
    # To resume an interrupted crawl, preload `results` from icd11.pkl here;
    # URIs that are already present are skipped below.
    results = {}
    failed = []
    try:
        for uri in tqdm(uris):
            if uri in results:
                continue
            for _ in range(max_attempts):
                headers = {'Authorization': 'Bearer ' + token,
                           'Accept': 'application/json',
                           'Accept-Language': 'zh',
                           'API-Version': 'v2'}
                try:
                    # NOTE(review): verify=False sends the bearer token
                    # without TLS certificate validation; kept as in the
                    # original, confirm whether it is still needed.
                    r = requests.get(uri, headers=headers, verify=False, timeout=30)
                    r.raise_for_status()
                    data = r.json()
                except (requests.RequestException, ValueError):
                    # Network error, HTTP error (e.g. expired token) or a
                    # non-JSON body: refresh the token and try again.
                    token = get_token()
                    continue
                results[uri] = data
                # Checkpoint periodically rather than after every entity;
                # rewriting the whole pickle each time is quadratic I/O.
                if len(results) % checkpoint_every == 0:
                    save_checkpoint(results)
                break
            else:
                failed.append(uri)
    finally:
        # Keep whatever was fetched, even on Ctrl-C or an unexpected error,
        # but never clobber an existing checkpoint with an empty dict.
        if results:
            save_checkpoint(results)

    if failed:
        print(f'{len(failed)} ICD entities could not be fetched after {max_attempts} attempts')

    def localized_value(uri, field):
        return results.get(uri, {}).get(field, {}).get('@value', '')

    def joined_synonyms(uri):
        synonyms = results.get(uri, {}).get('synonym', [])
        return '|'.join(item['label']['@value'] for item in synonyms)

    foundation_uris = icd_tabulation_df['Foundation URI']
    icd_tabulation_df['full_name'] = foundation_uris.apply(
        lambda uri: localized_value(uri, 'fullySpecifiedName'))
    icd_tabulation_df['definition'] = foundation_uris.apply(
        lambda uri: localized_value(uri, 'definition'))
    icd_tabulation_df['synonym'] = foundation_uris.apply(joined_synonyms)
    icd_tabulation_df.to_excel('FullTabulation-ICD-11-MMS-zh.xlsx', index=False)

    term_df = icd_tabulation_df.loc[icd_tabulation_df['isLeaf'] == True, :]  # noqa: E712
    term_df.to_excel('TermTabulation-ICD-11-MMS-zh.xlsx', index=False)
69+
70+
71+
def build_vs(text_list, meta_list, vs_path, chunk_size=500, chunk_overlap=50, batch_size=100):
    """Chunk, embed and index texts into a FAISS store saved at ``vs_path``.

    Args:
        text_list: Texts to index.
        meta_list: One metadata dict per entry of ``text_list``; each chunk
            produced by the splitter carries its source text's metadata.
        vs_path: Directory the FAISS store is saved to.
        chunk_size: Maximum chunk length passed to the text splitter.
        chunk_overlap: Overlap between consecutive chunks.
        batch_size: Number of chunks embedded per model call.
    """
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')
    embedder = HuggingFaceBgeEmbeddings(
        model_name='./models/AI-ModelScope/bge-large-zh-v1.5',
        model_kwargs={'device': device},
    )

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        keep_separator=False,
    )
    chunks = splitter.create_documents(text_list, metadatas=meta_list)
    chunk_texts = [chunk.page_content for chunk in chunks]
    chunk_metas = [chunk.metadata for chunk in chunks]

    # Embed in fixed-size batches so the progress bar advances per batch.
    n_batches = int(np.ceil(len(chunks) / batch_size))
    batch_vectors = []
    for batch_no in tqdm(range(n_batches), desc='Embedding'):
        start = batch_no * batch_size
        batch_vectors.append(embedder.embed_documents(chunk_texts[start:start + batch_size]))
    all_vectors = np.concatenate(batch_vectors, axis=0)

    store = FAISS.from_embeddings(list(zip(chunk_texts, all_vectors)), embedder, chunk_metas)
    store.save_local(vs_path)
85+
86+
87+
def create_kb():
    """Build the ICD-11 title, names and description vector stores.

    Reads ``TermTabulation-ICD-11-MMS-zh.xlsx`` and, for every row, derives:

    * ``title``: the ``Title`` column, falling back to ``TitleEN`` when
      ``Title`` is empty, with leading ``"- "`` runs removed;
    * ``names``: the unique, non-empty ``|``-separated entries drawn from
      ``TitleEN``, ``Title``, ``full_name`` and ``synonym``;
    * ``description``: the ``definition`` (when present) followed by the names.

    The three text variants are indexed into ``./vs/title``, ``./vs/names``
    and ``./vs/description``; every entry carries
    ``{'Code': ..., 'Title': ...}`` as metadata.
    """
    # Raw string: '\-' and '\s' are regex escapes, not string escapes.
    # Matches the leading "- " runs on titles (presumably hierarchy-depth
    # markers in the tabulation; confirm against the spreadsheet).
    prefix_re = re.compile(r'^(\-\s)+')

    term_df = pd.read_excel('TermTabulation-ICD-11-MMS-zh.xlsx').fillna('')

    title_texts = []
    name_texts = []
    description_texts = []
    meta_list = []
    rows = zip(term_df['Code'], term_df['TitleEN'], term_df['Title'],
               term_df['full_name'], term_df['synonym'], term_df['definition'])
    for code, title_en, title_zh, full_name, synonym, definition in rows:
        clean_en = prefix_re.sub('', title_en)
        clean_zh = prefix_re.sub('', title_zh)
        title = clean_zh if title_zh else clean_en

        # The synonym column is already '|'-joined, so join everything and
        # split again to get one flat list of candidate names.
        pieces = '|'.join([clean_en, clean_zh, full_name, synonym]).split('|')
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # embedded text is identical from run to run (a set is not ordered).
        names = list(dict.fromkeys(piece for piece in pieces if piece))
        description = [definition] + names if definition else names

        title_texts.append(title)
        name_texts.append('\n'.join(names))
        description_texts.append('\n'.join(description))
        meta_list.append({'Code': code, 'Title': title})

    build_vs(title_texts, meta_list, './vs/title')
    build_vs(name_texts, meta_list, './vs/names')
    build_vs(description_texts, meta_list, './vs/description')
109+
110+
111+
def create_icd10_kb():
    """Build the ICD-10 and ICD-O-3 vector stores from ``ICD-10-ICD-O.xlsx``.

    Rows of the two excluded ICD-O-3 coding systems are dropped first. The
    '释义' column is the indexed text, and each entry's metadata holds its
    'Coding System', 'Code' and '释义' values. The stores are saved to
    ``./vs/icd10`` and ``./vs/icdo3``.
    """
    columns = ['Coding System', 'Code', '释义']
    excluded_systems = ['ICD-O-3行为学编码', 'ICD-O-3组织学等级和分化程度编码']
    # (coding systems that feed the store, directory the store is saved to)
    stores = (
        (['ICD10', 'ICD10-特殊疾病类别'], './vs/icd10'),
        (['ICD-O-3形态学编码', 'ICD-O-3解剖部位编码'], './vs/icdo3'),
    )

    table = pd.read_excel('ICD-10-ICD-O.xlsx')
    keep = ~table['Coding System'].isin(excluded_systems)
    table = table.loc[keep, columns]
    table['meta'] = table.apply(lambda row: {name: row[name] for name in columns}, axis=1)

    for systems, vs_path in stores:
        subset = table[table['Coding System'].isin(systems)]
        build_vs(subset['释义'].tolist(), subset['meta'].tolist(), vs_path)
125+
126+
127+
class SemanticSearch:
    """Similarity search over a FAISS vector store saved by ``build_vs``."""

    def __init__(self, vs_path):
        """Load the embedding model and the FAISS store found at ``vs_path``."""
        device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
        embeddings = HuggingFaceBgeEmbeddings(model_name='./models/AI-ModelScope/bge-large-zh-v1.5',
                                              model_kwargs={'device': device})
        # NOTE(review): allow_dangerous_deserialization=True unpickles the
        # stored docstore; only point this at stores you built yourself.
        self.vector_store = FAISS.load_local(vs_path, embeddings, allow_dangerous_deserialization=True)

    def search(self, question, k=10, titles=None):
        """Return the ``k`` stored chunks most similar to ``question``.

        Args:
            question: Query text.
            k: Maximum number of results.
            titles: Optional title filter; when given, only chunks whose
                ``Title`` metadata matches are returned.

        Returns:
            A list of ``(metadata, page_content)`` tuples in the order the
            vector store returned them.
        """
        if titles is None:
            hits = self.vector_store.similarity_search_with_score(question, k=k)
        else:
            # The metadata written by create_kb() uses the key 'Title'
            # (capitalised); filtering on 'title' can never match anything.
            # fetch_k spans the whole index so the filter sees every chunk.
            hits = self.vector_store.similarity_search_with_score(
                question,
                filter={'Title': titles},
                k=k,
                fetch_k=len(self.vector_store.index_to_docstore_id),
            )
        return [(doc.metadata, doc.page_content) for doc, _score in hits]
142+
143+
144+
if __name__ == '__main__':
    # Earlier pipeline stages; enable and run them one at a time as needed:
    #   augment_icd_info()   # fetch ICD-11 details and write the tabulations
    #   create_kb()          # build ./vs/title, ./vs/names, ./vs/description
    #
    # Querying a finished store:
    #   semantic_search = SemanticSearch('./vs/title')
    #   hits = semantic_search.search(text)
    create_icd10_kb()

0 commit comments

Comments
 (0)