3
3
import json
4
4
import yaml
5
5
import prompts
6
+ import pandas as pd
6
7
import streamlit as st
7
8
import streamlit_authenticator as stauth
9
+ from icd import SemanticSearch
8
10
from dotenv import load_dotenv
9
11
from yaml .loader import SafeLoader
10
12
from langchain .prompts import PromptTemplate
15
17
16
18
17
19
load_dotenv ()
18
- with open (os .path .join (os .path .dirname (__file__ ), 'code_book.txt' )) as f :
19
- code_book = f .readlines ()
20
+ term_df = pd .read_excel ('ICD-10-ICD-O.xlsx' )
21
+ term_df = term_df .loc [~ term_df ['Coding System' ].isin (['ICD-O-3行为学编码' , 'ICD-O-3组织学等级和分化程度编码' ])]
22
+ term_map = dict (zip (term_df ['Code' ], term_df ['释义' ]))
23
+ icd10_semantic_search = SemanticSearch ('./vs/icd10' )
24
+ icdo3_semantic_search = SemanticSearch ('./vs/icdo3' )
25
+
26
+
27
+ def search_reference0 (code ):
28
+ code_comps = code .split ('-' )
29
+ if len (code_comps ) > 1 :
30
+ code = '-' .join (code .split ('-' )[1 :])
31
+ return '\n ' .join ([line for line in code_book if re .search (re .escape (code ), line , re .I ) is not None ])
20
32
21
33
22
34
def search_reference (code ):
@@ -26,27 +38,53 @@ def search_reference(code):
26
38
return '\n ' .join ([line for line in code_book if re .search (re .escape (code ), line , re .I ) is not None ])
27
39
28
40
29
- def generate (context , chat_llm , callbacks ):
41
+ def generate (context , chat_llm , callbacks , output_container ):
30
42
system_message = SystemMessage (content = prompts .prompt1 )
31
- initial_user_prompt = PromptTemplate (template = prompts .prompt2 , input_variables = ['diagnosis' ]).format (context = context )
43
+ related_icd10 = icd10_semantic_search .search (context , k = 5 )
44
+ related_icdo3 = icdo3_semantic_search .search (context , k = 5 )
45
+ related_code_context = '\n ' .join ([f'{ x [0 ]["Code" ]} \n { x [0 ]["释义" ]} ' for x in related_icd10 + related_icdo3 ])
46
+ initial_user_prompt = PromptTemplate (template = prompts .prompt2 , input_variables = ['context' , 'related_codes' ]).format (context = context , related_codes = related_code_context )
32
47
initial_user_message = HumanMessage (content = initial_user_prompt )
48
+ output_container .chat_message ("user" ).write (system_message .content .replace ('\n ' , '\n \n ' ))
49
+ output_container .chat_message ("user" ).write (initial_user_message .content .replace ('\n ' , '\n \n ' ))
33
50
initial_result = chat_llm ([system_message , initial_user_message ], callbacks = callbacks )
34
51
second_user_message = HumanMessage (content = prompts .prompt3 )
52
+ output_container .chat_message ("user" ).write (second_user_message .content .replace ('\n ' , '\n \n ' ))
35
53
second_result = chat_llm ([system_message , initial_user_message , initial_result , second_user_message ], callbacks = callbacks )
36
54
code_result = second_result
37
55
try_cnt = 0
38
56
while True :
39
57
format_user_prompt = HumanMessage (content = prompts .prompt4 )
58
+ output_container .chat_message ("user" ).write (format_user_prompt .content .replace ('\n ' , '\n \n ' ))
40
59
format_result = chat_llm ([code_result , format_user_prompt ], callbacks = callbacks )
41
60
json_text = re .search ('```json(.+)```' , format_result .content , re .DOTALL )
42
61
if json_text is not None :
43
62
json_data = json .loads (json_text .group (1 ))
44
63
references = ''
45
64
for code in json_data ['code' ][:3 ]:
46
- ref = search_reference (code )
47
- references += f'{ code } :\n { ref } \n \n '
65
+ if re .search ('[\::]' , code ) is not None :
66
+ code = re .split ('[\::]' , code )[- 1 ].strip ()
67
+ if re .search ('^[A-Z]-' , code ):
68
+ code = '-' .join (code .split ('-' )[1 :])
69
+ ref = term_map .get (code )
70
+ if ref is not None :
71
+ references += f'{ code } :\n { ref } \n \n '
72
+ if re .search ('\d\.\-' , code ):
73
+ related_codes = [x for x in term_map .keys () if x .startswith (code .strip ('-' ))]
74
+ for rel in related_codes :
75
+ ref = term_map .get (rel )
76
+ references += f'{ rel } :\n { ref } \n \n '
77
+ if re .search ('\d\.\d\-\d' , code ):
78
+ code_compo = re .search ('(.*\d\.)(\d)\-(\d)' , code )
79
+ related_codes = [code_compo .group (1 ) + str (x ) for x in
80
+ range (int (code_compo .group (2 )), int (code_compo .group (3 )) + 1 )]
81
+ related_codes = [x for x in related_codes if x in term_map ]
82
+ for rel in related_codes :
83
+ ref = term_map .get (rel )
84
+ references += f'{ rel } :\n { ref } \n \n '
48
85
refine_user_prompt = PromptTemplate (template = prompts .prompt5 , input_variables = ['references' ]).format (references = references )
49
86
refine_user_message = HumanMessage (content = refine_user_prompt )
87
+ output_container .chat_message ("user" ).write (refine_user_message .content .replace ('\n ' , '\n \n ' ))
50
88
refine_result = chat_llm ([system_message , initial_user_message , initial_result , second_user_message , code_result , refine_user_message ], callbacks = callbacks )
51
89
code_result = refine_result
52
90
if '"confirmed": true' in code_result .content :
@@ -55,6 +93,7 @@ def generate(context, chat_llm, callbacks):
55
93
if try_cnt > 5 :
56
94
break
57
95
format_user_prompt = HumanMessage (content = prompts .prompt4 )
96
+ output_container .chat_message ("user" ).write (format_user_prompt .content .replace ('\n ' , '\n \n ' ))
58
97
format_result = chat_llm ([code_result , format_user_prompt ], callbacks = callbacks )
59
98
return format_result .content
60
99
@@ -74,8 +113,18 @@ def demo_page():
74
113
st_callback = CustomStreamlitCallbackHandler (output_container )
75
114
std_callback = StreamingStdOutCallbackHandler ()
76
115
callbacks = [st_callback , std_callback ]
77
- result = generate (raw_input , chat_llm , callbacks = callbacks )
116
+ result = generate (raw_input , chat_llm , callbacks , output_container )
78
117
st .markdown (result )
118
+ st .write ("""
119
+ <hr style="border: none; border-top: 1px solid #ccc;">
120
+ <p style="text-align: center; font-size: 12px;">
121
+ 沪ICP备18007075号-2
122
+ </p>
123
+ """ , unsafe_allow_html = True )
124
+ # st.write("""
125
+ # ---
126
+ # 沪ICP备18007075号-2
127
+ # """)
79
128
80
129
81
130
if __name__ == '__main__' :
0 commit comments