123
123
}
124
124
125
125
126
+ # Highlight: Part 1. Post-Processing Functions for LLM Outputs
127
+
126
128
def get_answer (text , answer_prompt , ignore_case = False ):
127
129
if ignore_case :
128
130
idx = text .lower ().rfind (answer_prompt .lower ())
@@ -142,22 +144,6 @@ def get_consensus(answers):
142
144
return max (counts , key = counts .get )
143
145
144
146
145
- def compare_molecule (smi1 , smi2 ) -> bool :
146
- from rdkit import Chem
147
- from rdkit .Chem import AllChem
148
-
149
- mol1 = Chem .MolFromSmiles (smi1 )
150
- mol2 = Chem .MolFromSmiles (smi2 )
151
- if mol1 is None or mol2 is None :
152
- return False
153
- else :
154
- return Chem .MolToSmiles (Chem .RemoveHs (mol1 )) == Chem .MolToSmiles (Chem .RemoveHs (mol2 ))
155
- # return False
156
- # fp1 = AllChem.GetMorganFingerprint(mol1, 2)
157
- # fp2 = AllChem.GetMorganFingerprint(mol2, 2)
158
- # return DataStructs.TanimotoSimilarity(fp1, fp2)
159
-
160
-
161
147
def normalize (s : str ) -> str :
162
148
"""Lower text and remove punctuation, articles and extra whitespace."""
163
149
s = s .lower ()
@@ -168,6 +154,87 @@ def normalize(s: str) -> str:
168
154
return s
169
155
170
156
157
def fuzzy_normalize_name(s):
    """Normalize a field/column name so that near-identical names compare equal.

    A name starting with ``"Unnamed"`` normalizes to the empty string.
    Otherwise every character that is not a word character, whitespace or one
    of ``% . - ( )`` is removed, the result is swapped for its entry in the
    module-level ``synonyms`` mapping when it is a key there, and the activity
    keywords listed below are moved to the end of the name (relative order of
    both groups is preserved).

    Args:
        s: Raw name; ``str`` methods are called on it directly.

    Returns:
        The normalized name with its words joined by single spaces.
    """
    # presumably these are pandas' placeholder column names -- verify against caller
    if s.startswith("Unnamed"):
        return ""

    # Measurement keywords that are pushed to the end of the name.
    keywords = ["pIC50", "IC50", "EC50", "TC50", "GI50", "Ki", "Kd", "Kb", "pKb"]

    # Keep only word characters, whitespace and the symbols % . - ( ).
    cleaned = re.sub(r'[^\w\s%.\-\(\)]', '', s)
    if cleaned in synonyms:
        cleaned = synonyms[cleaned]

    # Stable partition: ordinary words first, keywords last.
    tokens = cleaned.split()
    ordinary = [token for token in tokens if token not in keywords]
    trailing = [token for token in tokens if token in keywords]
    return ' '.join(ordinary + trailing)
184
+
185
+
186
def fuzzy_normalize_value(vi):
    """Normalize a table cell value so that equivalent notations compare equal.

    The value is stringified and lower-cased, then:

    * anything containing ``"bal"``, ``"remainder"`` or ``"bas"`` becomes
      ``"bal"``;
    * missing-value markers (``""``, ``"/"``, or text containing ``"na"`` /
      ``"n/a"``) become ``"0"``;
    * text with exactly two numbers becomes the range ``"a-b"``; text with
      exactly one number is reduced to that number;
    * ``<`` and ``>`` are dropped;
    * if the result parses as a float it is returned rounded to 3 decimals.

    Args:
        vi: Any value; it is converted with ``str()`` first.

    Returns:
        ``"bal"``, a ``float`` rounded to 3 decimals, or the normalized string
        when it is not numeric. If ``str(vi)`` itself fails, ``vi`` is
        returned unchanged.
    """
    try:
        text = str(vi).lower()
    except Exception:
        # Best effort: an object that cannot be stringified is passed through.
        return vi

    if "bal" in text or "remainder" in text or "bas" in text:
        return "bal"

    # NOTE(review): '"na" in text' also matches any text that merely contains
    # "na" (and makes the "nan" clause redundant); kept as-is because callers
    # may rely on it -- confirm whether an exact match was intended.
    if (
        ("nan" in text and "–" not in text)
        or text == "/"
        or "n/a" in text
        or "na" in text
        or text == ""
    ):
        text = "0"
    text = text.replace("nan", "–").replace("~", "-")

    # Unsigned integers/decimals only: signs and exponents are not captured.
    numbers = re.findall(r"\d+(?:\.\d+)?", text)
    if len(numbers) == 2:
        text = f"{numbers[0]}-{numbers[1]}"
    elif len(numbers) == 1:
        text = numbers[0]

    text = text.replace("<", "").replace(">", "")

    try:
        return round(float(text), 3)
    except (ValueError, OverflowError):
        # Not a plain number (e.g. a range or free text): keep the string.
        return text
220
+
221
+
222
def extract_choice_and_value(sampled):
    """Extract the first ``<choice>) <value>`` answer from a model response.

    The answer must look like a single word character followed by ``)``, one
    whitespace character and a number, optionally followed by ``: <number>``
    and an optional ``°``/``C``/``K`` unit suffix (for example ``"B) 25.5 °C"``
    or ``"a) 1:2"``).

    Args:
        sampled: The raw response text.

    Returns:
        The first matching answer with ``°`` replaced by a space and the
        resulting double space collapsed, or ``"No answer."`` when nothing in
        ``sampled`` matches.
    """
    pattern = re.compile(r'\w\)\s\d+(?:\.\d+)?(?:\s?:\s?\d+(?:\.\d+)?)?\s?[°]?[CK]?')
    match = pattern.search(sampled)
    if match is None:
        return "No answer."
    # Replacing the degree sign can leave two adjacent spaces ("25.5  C").
    return match.group(0).replace("°", " ").replace("  ", " ")
235
+
236
+ # Part 2. Comparison Functions for Post-Processed LLM Outputs
237
+
171
238
def fuzzy_match (s1 : str , s2 : str ) -> bool :
172
239
s1 = normalize (s1 )
173
240
s2 = normalize (s2 )
@@ -264,69 +331,32 @@ def is_float(str):
264
331
pass
265
332
266
333
267
- def fuzzy_normalize_name (s ):
268
- if s .startswith ("Unnamed" ):
269
- return ""
270
- else :
271
- """ 标准化字符串 """
272
- # # 定义需要移除的单位和符号
273
- # units = ["µM", "µg/mL", "nM", "%", "wt.%", "at.%", "at%", "wt%"]
274
- # for unit in units:
275
- # s = s.replace(unit, "")
276
-
277
- # 定义特定关键字
278
- keywords = ["pIC50" , "IC50" , "EC50" , "TC50" , "GI50" , "Ki" , "Kd" , "Kb" , "pKb" ]
279
-
280
- # 移除非字母数字的字符,除了空格
281
- s = re .sub (r'[^\w\s%.\-\(\)]' , '' , s )
282
- if s in synonyms :
283
- s = synonyms [s ]
284
-
285
- # 分割字符串为单词列表
286
- words = s .split ()
287
-
288
- # 将关键字移到末尾
289
- reordered_words = [word for word in words if word not in keywords ]
290
- keywords_in_string = [word for word in words if word in keywords ]
291
- reordered_words .extend (keywords_in_string )
292
- # 重新组合为字符串
293
- return ' ' .join (reordered_words )
294
-
295
-
296
- def fuzzy_normalize_value (vi ):
297
- try :
298
- vi = str (vi ).lower ()
299
-
300
- if "bal" in vi or "remainder" in vi or "bas" in vi :
301
- vi = "bal"
302
- return "bal"
334
def compare_molecule_similarity(smi1, smi2) -> dict:
    """Score how similar two SMILES strings are.

    Each input is stringified, stripped of surrounding backticks and of any
    ``<...>`` span, then parsed with RDKit. When both parse, the score is the
    Tanimoto similarity of their Morgan fingerprints (second argument 2).

    Args:
        smi1: First SMILES (any object; ``str()`` is applied).
        smi2: Second SMILES (any object; ``str()`` is applied).

    Returns:
        ``{"score": similarity}``; the score is ``0.0`` when either input
        fails to parse.
    """
    from rdkit import Chem, DataStructs
    from rdkit.Chem import AllChem

    def _to_mol(raw):
        # NOTE(review): '<.*>' is greedy, so everything between the first '<'
        # and the last '>' is removed -- confirm this is the intended clean-up.
        cleaned = re.sub(r'<.*>', '', str(raw).strip("`"))
        return Chem.MolFromSmiles(cleaned)

    first = _to_mol(smi1)
    second = _to_mol(smi2)
    if first is None or second is None:
        return {"score": 0.0}

    fingerprints = [AllChem.GetMorganFingerprint(mol, 2) for mol in (first, second)]
    return {"score": DataStructs.TanimotoSimilarity(fingerprints[0], fingerprints[1])}
314
349
315
- if "<" in vi :
316
- vi = vi .replace ("<" , "" )
317
- if ">" in vi :
318
- vi = vi .replace (">" , "" )
319
350
320
- try :
321
- vi = float (vi )
322
- vi = round (vi , 3 )
323
- except :
324
- # print(vi)
325
- pass
326
- except :
327
- pass
351
def compare_molecule_strict(smi1, smi2) -> bool:
    """Return whether two SMILES strings describe the same molecule.

    Both inputs are parsed with RDKit, explicit hydrogens are removed, and the
    SMILES strings RDKit writes back out are compared for equality.

    Args:
        smi1: First SMILES string.
        smi2: Second SMILES string.

    Returns:
        ``True`` when the regenerated SMILES are identical; ``False`` when
        they differ or when either input cannot be parsed.
    """
    from rdkit import Chem

    molecules = [Chem.MolFromSmiles(smiles) for smiles in (smi1, smi2)]
    if any(mol is None for mol in molecules):
        return False

    left, right = (Chem.MolToSmiles(Chem.RemoveHs(mol)) for mol in molecules)
    return left == right
330
360
331
361
332
362
def tableMatching (df_ref , df_prompt , index = 'Compound' , compare_fields = [], record = True , file_name = None ):
@@ -350,7 +380,7 @@ def match_indices(ind0, ind1, threshold=0.9) -> dict:
350
380
Match the indices of two dataframes.
351
381
"""
352
382
renames = {}
353
- name2query = lambda name : name if type (name ) != tuple else name [0 ] if len (name ) == 1 or name [1 ] == "" else name [1 ]
383
+ name2query = lambda name : name if type (name ) != tuple else name [0 ] if len (name ) == 1 or name [- 1 ] == "" else name [- 1 ]
354
384
similarities = np .array (np .ones ([len (ind0 ) + 15 , len (ind1 ) + 15 ]), dtype = np .float64 )
355
385
querys0 = [name2query (name ) for name in ind0 ]
356
386
querys1 = [name2query (name ) for name in ind1 ]
@@ -434,7 +464,7 @@ def match_indices(ind0, ind1, threshold=0.9) -> dict:
434
464
except :
435
465
p = 'not found'
436
466
437
- _is_matching = fuzzy_compare_name (gt , p , compare_value = True ) if col != "SMILES" else compare_molecule (gt , p )
467
+ _is_matching = fuzzy_compare_name (gt , p , compare_value = True ) if col != "SMILES" else compare_molecule_strict (gt , p )
438
468
if col == "SMILES" :
439
469
smiles_match_score += float (_is_matching )
440
470
if record :
@@ -558,6 +588,38 @@ def count_leaves(d, count=0):
558
588
return 0
559
589
ratio = total_diff_leaves / total_leaves_dict1
560
590
591
+ if total_diff_leaves == total_leaves_dict1 and len (list (dict_ref .keys ())) == len (list (dict_prompt .keys ())):
592
+ values1 = list (dict_ref .values ())
593
+ values2 = list (dict_prompt .values ())
594
+
595
+ # Initialize containers for differences
596
+ differences = []
597
+
598
+ # The maximum length to iterate over
599
+ max_length = max (len (values1 ), len (values2 ))
600
+
601
+ total_diff_leaves = 0
602
+
603
+ for i in range (max_length ):
604
+ try :
605
+ value1 = values1 [i ]
606
+ value2 = values2 [i ]
607
+ except IndexError :
608
+ # Handle cases where the lists have different lengths
609
+ differences .append ('Different number of elements.' )
610
+ break
611
+
612
+ # If both values are dictionaries, use DeepDiff to compare them deeply
613
+ if isinstance (value1 , dict ) and isinstance (value2 , dict ):
614
+ diff = DeepDiff (value1 , value2 , ignore_order = True , report_repetition = True )
615
+ if diff :
616
+ total_diff_leaves += sum (len (diff .get (key , {})) for key in diff_keys )
617
+ differences .append (diff )
618
+ elif value1 != value2 :
619
+ total_diff_leaves += 1
620
+ # For non-dictionary values, just compare them directly
621
+ differences .append ({'different_values' : (value1 , value2 )})
622
+
561
623
return 1.0 - ratio , diff
562
624
563
625
@@ -863,6 +925,7 @@ def macro_f1_score_3(model, prediction: List[List[Any]], answers: List[List[Any]
863
925
except :
864
926
return 0.0
865
927
928
+
866
929
def scrub_formatting_from_prompt (prompt ):
867
930
scrubbed_prompt = copy .copy (prompt )
868
931
0 commit comments