Skip to content

Commit bcf19d6

Browse files
committed
implement RAG
1 parent e76e447 commit bcf19d6

6 files changed

+64
-7
lines changed

FullTabulation-ICD-11-MMS-zh.xlsx

4.21 MB
Binary file not shown.

ICD-10-ICD-O.xlsx

101 KB
Binary file not shown.

medcodegpt.py

+56-7
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import json
44
import yaml
55
import prompts
6+
import pandas as pd
67
import streamlit as st
78
import streamlit_authenticator as stauth
9+
from icd import SemanticSearch
810
from dotenv import load_dotenv
911
from yaml.loader import SafeLoader
1012
from langchain.prompts import PromptTemplate
@@ -15,8 +17,18 @@
1517

1618

1719
load_dotenv()
18-
with open(os.path.join(os.path.dirname(__file__), 'code_book.txt')) as f:
19-
code_book = f.readlines()
20+
# Load the ICD-10 / ICD-O term spreadsheet and build the code -> definition
# lookup (``term_map``) that is consulted when reference text is assembled.
term_df = pd.read_excel('ICD-10-ICD-O.xlsx')
# Rows belonging to these two coding systems are left out of the lookup.
term_df = term_df[
    ~term_df['Coding System'].isin(
        ['ICD-O-3行为学编码', 'ICD-O-3组织学等级和分化程度编码']
    )
]
term_map = {
    code: definition
    for code, definition in zip(term_df['Code'], term_df['释义'])
}

# One search object per store path.
# NOTE(review): presumably these are on-disk vector-store directories --
# confirm against icd.SemanticSearch.
icd10_semantic_search = SemanticSearch('./vs/icd10')
icdo3_semantic_search = SemanticSearch('./vs/icdo3')
25+
26+
27+
def search_reference0(code):
    """Return the code-book lines that mention *code*, joined by newlines.

    If ``code`` contains a dash, everything up to and including the first
    dash is dropped before searching (e.g. ``'X-C34.1'`` is looked up as
    ``'C34.1'``).  The remaining text is matched literally and
    case-insensitively anywhere in each line.

    NOTE(review): this reads the module-level ``code_book`` list of lines;
    confirm it is still defined in this module before calling.

    Args:
        code: Code string to look up, optionally carrying a dash-separated
            prefix.

    Returns:
        The matching lines of ``code_book`` joined with a newline; an empty
        string when nothing matches.
    """
    # Strip the leading "<prefix>-" component, if there is one.
    _prefix, dash, remainder = code.partition('-')
    if dash:
        code = remainder
    # Compile once instead of re-resolving the pattern for every line.
    pattern = re.compile(re.escape(code), re.I)
    return '\n'.join(line for line in code_book if pattern.search(line))
2032

2133

2234
def search_reference(code):
@@ -26,27 +38,53 @@ def search_reference(code):
2638
return '\n'.join([line for line in code_book if re.search(re.escape(code), line, re.I) is not None])
2739

2840

29-
def generate(context, chat_llm, callbacks):
41+
def generate(context, chat_llm, callbacks, output_container):
3042
system_message = SystemMessage(content=prompts.prompt1)
31-
initial_user_prompt = PromptTemplate(template=prompts.prompt2, input_variables=['diagnosis']).format(context=context)
43+
related_icd10 = icd10_semantic_search.search(context, k=5)
44+
related_icdo3 = icdo3_semantic_search.search(context, k=5)
45+
related_code_context = '\n'.join([f'{x[0]["Code"]}\n{x[0]["释义"]}' for x in related_icd10 + related_icdo3])
46+
initial_user_prompt = PromptTemplate(template=prompts.prompt2, input_variables=['context', 'related_codes']).format(context=context, related_codes=related_code_context)
3247
initial_user_message = HumanMessage(content=initial_user_prompt)
48+
output_container.chat_message("user").write(system_message.content.replace('\n', '\n\n'))
49+
output_container.chat_message("user").write(initial_user_message.content.replace('\n', '\n\n'))
3350
initial_result = chat_llm([system_message, initial_user_message], callbacks=callbacks)
3451
second_user_message = HumanMessage(content=prompts.prompt3)
52+
output_container.chat_message("user").write(second_user_message.content.replace('\n', '\n\n'))
3553
second_result = chat_llm([system_message, initial_user_message, initial_result, second_user_message], callbacks=callbacks)
3654
code_result = second_result
3755
try_cnt = 0
3856
while True:
3957
format_user_prompt = HumanMessage(content=prompts.prompt4)
58+
output_container.chat_message("user").write(format_user_prompt.content.replace('\n', '\n\n'))
4059
format_result = chat_llm([code_result, format_user_prompt], callbacks=callbacks)
4160
json_text = re.search('```json(.+)```', format_result.content, re.DOTALL)
4261
if json_text is not None:
4362
json_data = json.loads(json_text.group(1))
4463
references = ''
4564
for code in json_data['code'][:3]:
46-
ref = search_reference(code)
47-
references += f'{code}:\n{ref}\n\n'
65+
if re.search('[\::]', code) is not None:
66+
code = re.split('[\::]', code)[-1].strip()
67+
if re.search('^[A-Z]-', code):
68+
code = '-'.join(code.split('-')[1:])
69+
ref = term_map.get(code)
70+
if ref is not None:
71+
references += f'{code}:\n{ref}\n\n'
72+
if re.search('\d\.\-', code):
73+
related_codes = [x for x in term_map.keys() if x.startswith(code.strip('-'))]
74+
for rel in related_codes:
75+
ref = term_map.get(rel)
76+
references += f'{rel}:\n{ref}\n\n'
77+
if re.search('\d\.\d\-\d', code):
78+
code_compo = re.search('(.*\d\.)(\d)\-(\d)', code)
79+
related_codes = [code_compo.group(1) + str(x) for x in
80+
range(int(code_compo.group(2)), int(code_compo.group(3)) + 1)]
81+
related_codes = [x for x in related_codes if x in term_map]
82+
for rel in related_codes:
83+
ref = term_map.get(rel)
84+
references += f'{rel}:\n{ref}\n\n'
4885
refine_user_prompt = PromptTemplate(template=prompts.prompt5, input_variables=['references']).format(references=references)
4986
refine_user_message = HumanMessage(content=refine_user_prompt)
87+
output_container.chat_message("user").write(refine_user_message.content.replace('\n', '\n\n'))
5088
refine_result = chat_llm([system_message, initial_user_message, initial_result, second_user_message, code_result, refine_user_message], callbacks=callbacks)
5189
code_result = refine_result
5290
if '"confirmed": true' in code_result.content:
@@ -55,6 +93,7 @@ def generate(context, chat_llm, callbacks):
5593
if try_cnt > 5:
5694
break
5795
format_user_prompt = HumanMessage(content=prompts.prompt4)
96+
output_container.chat_message("user").write(format_user_prompt.content.replace('\n', '\n\n'))
5897
format_result = chat_llm([code_result, format_user_prompt], callbacks=callbacks)
5998
return format_result.content
6099

@@ -74,8 +113,18 @@ def demo_page():
74113
st_callback = CustomStreamlitCallbackHandler(output_container)
75114
std_callback = StreamingStdOutCallbackHandler()
76115
callbacks = [st_callback, std_callback]
77-
result = generate(raw_input, chat_llm, callbacks=callbacks)
116+
result = generate(raw_input, chat_llm, callbacks, output_container)
78117
st.markdown(result)
118+
st.write("""
119+
<hr style="border: none; border-top: 1px solid #ccc;">
120+
<p style="text-align: center; font-size: 12px;">
121+
沪ICP备18007075号-2
122+
</p>
123+
""", unsafe_allow_html=True)
124+
# st.write("""
125+
# ---
126+
# 沪ICP备18007075号-2
127+
# """)
79128

80129

81130
if __name__ == '__main__':

models/download_models.py

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
"""Download the AI-ModelScope/bge-large-zh-v1.5 snapshot via ModelScope."""
from modelscope.hub.snapshot_download import snapshot_download

# Fetch the snapshot with the current directory as the cache location; the
# value returned by the call is kept in ``model_dir``.
model_dir = snapshot_download(
    'AI-ModelScope/bge-large-zh-v1.5',
    cache_dir='./',
)
4+

prompts.py

+3
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@
5959
6060
根据实际情况,肿瘤报告的诊断名称往往不能与手册中的名称完全符合,应选择其中符合程度最高的编码名称。
6161
部分疾病(如部分肝癌、黑色素瘤、间皮瘤和淋巴瘤、白血病等)编码较特殊,本手册已在《ICD-10与ICD-O-3解剖部位编码》中列出了ICD-10与ICD-O-3不同的解剖部位编码和/或对应的ICD-O-3形态学编码;同样在《ICD-O-3形态学编码》中也已经列出了对应的解剖部位编码,请使用者酌情参考。
62+
63+
以下编码可能与目标编码相关,但最终结果不限于以下编码,仅供参考。
64+
{related_codes}
6265
'''
6366

6467
prompt3 = '''

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ python-dotenv==1.0.1
55
PyYAML==6.0.1
66
streamlit==1.32.0
77
streamlit_authenticator==0.3.2
8+
sentence_transformers==3.0.1

0 commit comments

Comments
 (0)