Skip to content

Commit 7d035d7

Browse files
committed
optimize code lookup
1 parent bcf19d6 commit 7d035d7

File tree

1 file changed

+27
-25
lines changed

1 file changed

+27
-25
lines changed

medcodegpt.py

+27-25
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,28 @@ def search_reference(code):
3838
return '\n'.join([line for line in code_book if re.search(re.escape(code), line, re.I) is not None])
3939

4040

41+
def lookup_code(code):
42+
std_codes = list()
43+
if re.search('[\::]', code) is not None:
44+
code = re.split('[\::]', code)[-1].strip()
45+
if re.search('^[A-Z]-', code):
46+
code = '-'.join(code.split('-')[1:])
47+
if code in term_map:
48+
std_codes.append(code)
49+
if re.search('\d\.\-', code):
50+
related_codes = [x for x in term_map.keys() if x.startswith(code.strip('-'))]
51+
for rel in related_codes:
52+
if rel in term_map:
53+
std_codes.append(rel)
54+
if re.search('\d\.\d\-\d', code):
55+
code_compo = re.search('(.*\d\.)(\d)\-(\d)', code)
56+
related_codes = [code_compo.group(1) + str(x) for x in
57+
range(int(code_compo.group(2)), int(code_compo.group(3)) + 1)]
58+
related_codes = [x for x in related_codes if x in term_map]
59+
std_codes.extend(related_codes)
60+
return std_codes
61+
62+
4163
def generate(context, chat_llm, callbacks, output_container):
4264
system_message = SystemMessage(content=prompts.prompt1)
4365
related_icd10 = icd10_semantic_search.search(context, k=5)
@@ -60,28 +82,12 @@ def generate(context, chat_llm, callbacks, output_container):
6082
json_text = re.search('```json(.+)```', format_result.content, re.DOTALL)
6183
if json_text is not None:
6284
json_data = json.loads(json_text.group(1))
63-
references = ''
6485
for code in json_data['code'][:3]:
65-
if re.search('[\::]', code) is not None:
66-
code = re.split('[\::]', code)[-1].strip()
67-
if re.search('^[A-Z]-', code):
68-
code = '-'.join(code.split('-')[1:])
69-
ref = term_map.get(code)
70-
if ref is not None:
71-
references += f'{code}:\n{ref}\n\n'
72-
if re.search('\d\.\-', code):
73-
related_codes = [x for x in term_map.keys() if x.startswith(code.strip('-'))]
74-
for rel in related_codes:
75-
ref = term_map.get(rel)
76-
references += f'{rel}:\n{ref}\n\n'
77-
if re.search('\d\.\d\-\d', code):
78-
code_compo = re.search('(.*\d\.)(\d)\-(\d)', code)
79-
related_codes = [code_compo.group(1) + str(x) for x in
80-
range(int(code_compo.group(2)), int(code_compo.group(3)) + 1)]
81-
related_codes = [x for x in related_codes if x in term_map]
82-
for rel in related_codes:
83-
ref = term_map.get(rel)
84-
references += f'{rel}:\n{ref}\n\n'
86+
std_codes = lookup_code(code)
87+
if re.search('\s', code) is not None:
88+
for sub_code in re.split('\s', code):
89+
std_codes.extend(lookup_code(sub_code))
90+
references = ''.join([f'{rel}:\n{term_map[rel]}\n\n' for rel in set(std_codes)])
8591
refine_user_prompt = PromptTemplate(template=prompts.prompt5, input_variables=['references']).format(references=references)
8692
refine_user_message = HumanMessage(content=refine_user_prompt)
8793
output_container.chat_message("user").write(refine_user_message.content.replace('\n', '\n\n'))
@@ -121,10 +127,6 @@ def demo_page():
121127
沪ICP备18007075号-2
122128
</p>
123129
""", unsafe_allow_html=True)
124-
# st.write("""
125-
# ---
126-
# 沪ICP备18007075号-2
127-
# """)
128130

129131

130132
if __name__ == '__main__':

0 commit comments

Comments
 (0)