Skip to content

Commit 74aa8ae

Browse files
authored
Add code tagger (lm-sys#3218)
1 parent c5223e3 commit 74aa8ae

File tree

1 file changed

+180
-0
lines changed

1 file changed

+180
-0
lines changed

fastchat/serve/monitor/code_tagger.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
import re
2+
import json
3+
import argparse
4+
import multiprocessing as mp
5+
6+
import nltk
7+
from tqdm import tqdm
8+
from nltk.tokenize import word_tokenize
9+
10+
11+
def _strip_fenced_block(text: str, fence: str) -> str:
    """Remove the first *fence*-opened fenced block (e.g. ```plaintext) from *text*.

    The opening fence line, the fenced contents, and the closing ``` line are
    all dropped.  If no closing fence exists, everything from the opening
    fence onward is dropped.
    """
    lines = text.split("\n")
    open_idx = next(idx for idx, line in enumerate(lines) if fence in line)
    # Search for the closing ``` strictly on the lines AFTER the opening
    # fence (the original code mistakenly sliced each line's characters,
    # `line[open_idx + 1:]`, instead of slicing the list of lines).
    close_idx = next(
        (
            open_idx + 1 + offset
            for offset, line in enumerate(lines[open_idx + 1 :])
            if "```" in line
        ),
        None,
    )
    if close_idx is None:
        return "\n".join(lines[:open_idx])
    # Concatenate the line lists before joining so the newline between the
    # two halves is not lost.
    return "\n".join(lines[:open_idx] + lines[close_idx + 1 :])


def is_code_conversation(text: str) -> tuple[bool, list[str]]:
    """Heuristically decide whether *text* is a programming conversation.

    Returns ``(is_code, matches)`` where *matches* lists the signals that
    fired: ``["backticks"]`` for a code fence, or the matched concept /
    language / substring / regex signals.
    """
    # ```plaintext and ```markdown fences are not code: strip the fenced
    # region and re-check the remainder, so a prose-only fenced block does
    # not trip the backtick check below.
    for non_code_fence in ("```plaintext", "```markdown"):
        if non_code_fence in text:
            return is_code_conversation(_strip_fenced_block(text, non_code_fence))

    if "ascii art" in text.lower():
        return False, []

    # 1. Code formatting: any remaining triple-backtick fence.
    if "```" in text:
        return True, ["backticks"]

    # Tokenize once; all token checks below are case-insensitive.
    tokens = {token.lower() for token in word_tokenize(text)}
    lowered = text.lower()

    # 2. Programming concepts.  Multi-word concepts ("pull request") are
    #    matched against the raw lowercased text because word_tokenize
    #    splits them into separate tokens and an exact token match could
    #    never succeed.
    concepts = ["git", "github", "pull request", "dataframe", "nginx", "pip"]
    matched_concepts = [
        concept
        for concept in concepts
        if (concept in lowered if " " in concept else concept in tokens)
    ]
    if matched_concepts:
        return True, matched_concepts

    # 3. Programming language / platform names as whole tokens.
    languages = [
        "python",
        "c++",
        "cpp",
        "java",
        "javascript",
        "typescript",
        "html",
        "css",
        "sql",
        "bash",
        "powershell",
        "matlab",
        "golang",
        "linux",
        "ubuntu",
    ]
    matched_languages = [language for language in languages if language in tokens]
    if matched_languages:
        return True, matched_languages

    # 4. Literal substrings that strongly indicate code (case-sensitive,
    #    matching the original behavior).
    strings = [
        "import pandas",
        "import numpy",
        "import torch",
        "jax",
        "tensorflow",
        "pytorch",
        "keras",
        "scikit-learn",
        "sklearn",
        " apt-get ",
    ]
    matched_strings = [string for string in strings if string in text]
    if matched_strings:
        return True, matched_strings

    # 5. Code-shaped regexes: import statements, package-manager commands,
    #    C/C++ includes.
    regexes = [
        r"from \w+ import \w+",
        r"conda install \w+",
        r"pip install -r \w+",
        r"conda install -c \w+ \w+",
        r"#include <\w+>",
        r"import \w+ as \w+",
        r"#include \"\w+\.h\"",
    ]
    matched_regexes = [regex for regex in regexes if re.search(regex, text)]
    if matched_regexes:
        return True, matched_regexes

    return False, []
120+
121+
122+
def check_code_conv(conv) -> tuple[bool, list[str]]:
    """Return the first positive code-detection result over a conversation.

    *conv* is a list of message dicts; each message's ``"content"`` string is
    checked with :func:`is_code_conversation`.  Non-string content (e.g.
    multimodal payloads) is skipped.  Returns ``(False, [])`` when no message
    looks like code.
    """
    # The original looped with `for _, msg in enumerate(conv)` and discarded
    # the index; iterate the messages directly.
    for msg in conv:
        content = msg["content"]
        if not isinstance(content, str):
            continue
        is_code_conv_res = is_code_conversation(content)
        if is_code_conv_res[0]:
            return is_code_conv_res
    return False, []
132+
133+
134+
def check_conv_row(conv_row):
    """Tag one battle row: code flag plus all matched signals from both sides.

    Returns ``(is_code, matches)`` where *is_code* is True when either
    ``conversation_a`` or ``conversation_b`` looks like code, and *matches*
    concatenates the signals found on each side (a's signals first).
    """
    results = [
        check_code_conv(conv_row[side])
        for side in ("conversation_a", "conversation_b")
    ]
    flags, match_lists = zip(*results)
    return any(flags), [match for matches in match_lists for match in matches]
139+
140+
141+
def process_battle_file(battle_file_path: str, n_cpus: int):
    """Load a clean-battle JSON file and keep only the code-related rows.

    Every row is tagged in parallel with :func:`check_conv_row` across
    *n_cpus* worker processes; a row is kept when either of its two
    conversations looks like code.  The matched signal lists are discarded.
    """
    with open(battle_file_path, "r") as f:
        rows = json.load(f)

    # imap preserves input order, so tags line up with rows below.
    with mp.Pool(n_cpus) as pool:
        tags = list(tqdm(pool.imap(check_conv_row, rows), total=len(rows)))

    return [row for row, (is_code, _) in zip(rows, tags) if is_code]
151+
152+
153+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--clean-battle-file", type=str)
    parser.add_argument("--output-clean-battle-file", type=str, default=None)
    parser.add_argument("--n-cpus", type=int, default=-1)
    args = parser.parse_args()

    input_path = args.clean_battle_file
    # Default: overwrite the input file in place.
    output_path = (
        args.output_clean_battle_file
        if args.output_clean_battle_file is not None
        else input_path
    )
    # -1 means "use every available core".
    worker_count = mp.cpu_count() if args.n_cpus == -1 else args.n_cpus

    print(
        f"Processing {input_path} and saving to {output_path} with {worker_count} cpus"
    )

    output_data = process_battle_file(input_path, worker_count)

    with open(output_path, "w") as f:
        json.dump(output_data, f, indent=4)

    print(f"Total code conversations: {len(output_data)}")
    print("Done!")

    # Re-read the written file to confirm it parses as JSON (result unused).
    with open(output_path, "r") as f:
        data = json.load(f)

0 commit comments

Comments
 (0)