Skip to content

Commit 9a9a6ef

Browse files
authored
Misc changes for vision arena: change log, fix image moderation (lm-sys#3341)
1 parent 7393624 commit 9a9a6ef

File tree

7 files changed: +266 −69 lines changed

fastchat/conversation.py

Lines changed: 28 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -69,6 +69,8 @@ class Conversation:
6969
stop_str: Union[str, List[str]] = None
7070
# Stops generation if meeting any token in this list
7171
stop_token_ids: List[int] = None
72+
# The maximum image size in megabytes that this model takes in. None means we do not resize the image.
73+
max_image_size_mb: int = None
7274

7375
def get_prompt(self) -> str:
7476
"""Get the prompt for generation."""
@@ -351,10 +353,11 @@ def update_last_message(self, message: str):
351353
"""
352354
self.messages[-1][1] = message
353355

354-
def convert_image_to_base64(self, image, resize_image=False):
356+
def convert_image_to_base64(self, image):
355357
"""Given an image, return the base64 encoded image string."""
356358
from PIL import Image
357359
import requests
360+
from fastchat.utils import resize_image_and_return_image_in_bytes
358361

359362
# Load image if it has not been loaded in yet
360363
if type(image) == str:
@@ -367,23 +370,10 @@ def convert_image_to_base64(self, image, resize_image=False):
367370
else:
368371
image = Image.open(image).convert("RGB")
369372

370-
if resize_image:
371-
max_hw, min_hw = max(image.size), min(image.size)
372-
aspect_ratio = max_hw / min_hw
373-
max_len, min_len = 2048, 2048
374-
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
375-
longest_edge = int(shortest_edge * aspect_ratio)
376-
W, H = image.size
377-
if longest_edge != max(image.size):
378-
if H > W:
379-
H, W = longest_edge, shortest_edge
380-
else:
381-
H, W = shortest_edge, longest_edge
382-
image = image.resize((W, H))
383-
384-
buffered = BytesIO()
385-
image.save(buffered, format="PNG")
386-
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
373+
image_bytes = resize_image_and_return_image_in_bytes(
374+
image, self.max_image_size_mb
375+
)
376+
img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
387377

388378
return img_b64_str
389379

@@ -400,7 +390,7 @@ def to_gradio_chatbot(self):
400390
):
401391
img_str = f'<img src="{img_b64_str}" alt="user upload image" />'
402392
else:
403-
img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
393+
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
404394
msg = img_str + msg.replace("<image>\n", "").strip()
405395

406396
ret.append([msg, None])
@@ -429,7 +419,7 @@ def to_openai_image_format(self, image_urls):
429419
base64.b64encode(base64.b64decode(image_url))
430420
== image_url.encode()
431421
), "The image data is not a valid base64 encoded string"
432-
openai_images.append(f"data:image/jpeg;base64,{image_url}")
422+
openai_images.append(f"data:image/png;base64,{image_url}")
433423
except:
434424
raise ValueError(
435425
f"This file is not valid or not currently supported by the OpenAI API: {image_url}"
@@ -438,12 +428,16 @@ def to_openai_image_format(self, image_urls):
438428

439429
def to_openai_vision_api_messages(self):
440430
"""Convert the conversation to OpenAI vision api completion format"""
441-
ret = [
442-
{
443-
"role": "system",
444-
"content": [{"type": "text", "text": self.system_message}],
445-
}
446-
]
431+
if self.system_message == "":
432+
ret = []
433+
else:
434+
ret = [
435+
{
436+
"role": "system",
437+
"content": [{"type": "text", "text": self.system_message}],
438+
}
439+
]
440+
447441
for i, (_, msg) in enumerate(self.messages[self.offset :]):
448442
if i % 2 == 0:
449443
if type(msg) is tuple:
@@ -598,7 +592,7 @@ def to_reka_api_messages(self):
598592
{
599593
"type": "human",
600594
"text": text,
601-
"media_url": f"data:image/jpeg;base64,{image}",
595+
"media_url": f"data:image/png;base64,{image}",
602596
}
603597
)
604598
else:
@@ -680,6 +674,7 @@ def copy(self):
680674
sep2=self.sep2,
681675
stop_str=self.stop_str,
682676
stop_token_ids=self.stop_token_ids,
677+
max_image_size_mb=self.max_image_size_mb,
683678
)
684679

685680
def dict(self):
@@ -1078,6 +1073,7 @@ def get_conv_template(name: str) -> Conversation:
10781073
roles=("user", "assistant"),
10791074
sep_style=SeparatorStyle.DEFAULT,
10801075
sep=None,
1076+
max_image_size_mb=None, # OpenAI does auto-resizing
10811077
)
10821078
)
10831079

@@ -1115,6 +1111,7 @@ def get_conv_template(name: str) -> Conversation:
11151111
roles=("Human", "Assistant"),
11161112
sep_style=SeparatorStyle.ADD_COLON_SINGLE,
11171113
sep="\n\n",
1114+
max_image_size_mb=5 / 1.35,
11181115
)
11191116
)
11201117

@@ -1138,6 +1135,7 @@ def get_conv_template(name: str) -> Conversation:
11381135
roles=("user", "assistant"),
11391136
sep_style=SeparatorStyle.DEFAULT,
11401137
sep=None,
1138+
max_image_size_mb=5 / 1.35,
11411139
)
11421140
)
11431141

@@ -1161,6 +1159,7 @@ def get_conv_template(name: str) -> Conversation:
11611159
roles=("user", "assistant"),
11621160
sep_style=SeparatorStyle.DEFAULT,
11631161
sep=None,
1162+
max_image_size_mb=5 / 1.35,
11641163
)
11651164
)
11661165

@@ -1193,6 +1192,7 @@ def get_conv_template(name: str) -> Conversation:
11931192
roles=("user", "assistant"),
11941193
sep_style=SeparatorStyle.DEFAULT,
11951194
sep=None,
1195+
max_image_size_mb=5 / 1.35,
11961196
)
11971197
)
11981198

@@ -1288,6 +1288,7 @@ def get_conv_template(name: str) -> Conversation:
12881288
roles=("user", "model"),
12891289
sep_style=SeparatorStyle.DEFAULT,
12901290
sep=None,
1291+
max_image_size_mb=20,
12911292
)
12921293
)
12931294

fastchat/serve/gradio_block_arena_vision.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
image_moderation_filter,
4141
)
4242

43-
logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
43+
logger = build_logger("gradio_web_server", "gradio_web_server.log")
4444

4545
no_change_btn = gr.Button()
4646
enable_btn = gr.Button(interactive=True, visible=True)
@@ -82,7 +82,7 @@ def add_image(textbox):
8282

8383

8484
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
85-
filename = get_conv_log_filename(state.is_vision)
85+
filename = get_conv_log_filename(state.is_vision, state.has_csam_image)
8686
with open(filename, "a") as fout:
8787
data = {
8888
"tstamp": round(time.time(), 4),
@@ -158,7 +158,7 @@ def moderate_input(text, all_conv_text, model_list, images, ip):
158158
elif text_flagged and image_flagged:
159159
text = MODERATION_MSG
160160

161-
return text, csam_flagged
161+
return text, image_flagged, csam_flagged
162162

163163

164164
def add_text(state, model_selector, chat_input, request: gr.Request):
@@ -176,10 +176,17 @@ def add_text(state, model_selector, chat_input, request: gr.Request):
176176
all_conv_text = state.conv.get_prompt()
177177
all_conv_text = all_conv_text[-2000:] + "\nuser: " + text
178178

179-
text, csam_flag = moderate_input(
179+
text, image_flagged, csam_flag = moderate_input(
180180
text, all_conv_text, [state.model_name], images, ip
181181
)
182182

183+
if image_flagged:
184+
logger.info(f"image flagged. ip: {ip}. text: {text}")
185+
state.skip_next = True
186+
return (state, state.to_gradio_chatbot(), {"text": IMAGE_MODERATION_MSG}) + (
187+
no_change_btn,
188+
) * 5
189+
183190
if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
184191
logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
185192
state.skip_next = True

fastchat/serve/gradio_block_arena_vision_anony.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def clear_history_example(request: gr.Request):
168168

169169

170170
def vote_last_response(states, vote_type, model_selectors, request: gr.Request):
171-
filename = get_conv_log_filename(states[0].is_vision)
171+
filename = get_conv_log_filename(states[0].is_vision, states[0].has_csam_image)
172172

173173
with open(filename, "a") as fout:
174174
data = {
@@ -309,7 +309,7 @@ def add_text(
309309
)
310310

311311
model_list = [states[i].model_name for i in range(num_sides)]
312-
text, csam_flag = moderate_input(text, text, model_list, images, ip)
312+
text, image_flagged, csam_flag = moderate_input(text, text, model_list, images, ip)
313313

314314
conv = states[0].conv
315315
if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
@@ -327,6 +327,21 @@ def add_text(
327327
+ [""]
328328
)
329329

330+
if image_flagged:
331+
logger.info(f"image flagged. ip: {ip}. text: {text}")
332+
for i in range(num_sides):
333+
states[i].skip_next = True
334+
return (
335+
states
336+
+ [x.to_gradio_chatbot() for x in states]
337+
+ [{"text": IMAGE_MODERATION_MSG}]
338+
+ [
339+
no_change_btn,
340+
]
341+
* 6
342+
+ [""]
343+
)
344+
330345
text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
331346
for i in range(num_sides):
332347
post_processed_text = _prepare_text_with_image(

fastchat/serve/gradio_block_arena_vision_named.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,7 @@
5353
)
5454

5555

56-
logger = build_logger(
57-
"gradio_web_server_vision_multi", "gradio_web_server_vision_multi.log"
58-
)
56+
logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
5957

6058
num_sides = 2
6159
enable_moderation = False
@@ -72,7 +70,7 @@ def clear_history_example(request: gr.Request):
7270

7371

7472
def vote_last_response(states, vote_type, model_selectors, request: gr.Request):
75-
filename = get_conv_log_filename(states[0].is_vision)
73+
filename = get_conv_log_filename(states[0].is_vision, states[0].has_csam_image)
7674
with open(filename, "a") as fout:
7775
data = {
7876
"tstamp": round(time.time(), 4),
@@ -189,7 +187,9 @@ def add_text(
189187
all_conv_text_left[-1000:] + all_conv_text_right[-1000:] + "\nuser: " + text
190188
)
191189

192-
text, csam_flag = moderate_input(text, all_conv_text, model_list, images, ip)
190+
text, image_flagged, csam_flag = moderate_input(
191+
text, all_conv_text, model_list, images, ip
192+
)
193193

194194
conv = states[0].conv
195195
if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
@@ -206,6 +206,20 @@ def add_text(
206206
* 6
207207
)
208208

209+
if image_flagged:
210+
logger.info(f"image flagged. ip: {ip}. text: {text}")
211+
for i in range(num_sides):
212+
states[i].skip_next = True
213+
return (
214+
states
215+
+ [x.to_gradio_chatbot() for x in states]
216+
+ [{"text": IMAGE_MODERATION_MSG}]
217+
+ [
218+
no_change_btn,
219+
]
220+
* 6
221+
)
222+
209223
text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
210224
for i in range(num_sides):
211225
post_processed_text = _prepare_text_with_image(

fastchat/serve/gradio_web_server.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,11 @@ def dict(self):
131131
{
132132
"conv_id": self.conv_id,
133133
"model_name": self.model_name,
134-
"has_csam_image": self.has_csam_image,
135134
}
136135
)
136+
137+
if self.is_vision:
138+
base.update({"has_csam_image": self.has_csam_image})
137139
return base
138140

139141

@@ -144,11 +146,13 @@ def set_global_vars(controller_url_, enable_moderation_, use_remote_storage_):
144146
use_remote_storage = use_remote_storage_
145147

146148

147-
def get_conv_log_filename(is_vision=False):
149+
def get_conv_log_filename(is_vision=False, has_csam_image=False):
148150
t = datetime.datetime.now()
149151
conv_log_filename = f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json"
150-
if is_vision:
152+
if is_vision and not has_csam_image:
151153
name = os.path.join(LOGDIR, f"vision-tmp-{conv_log_filename}")
154+
elif is_vision and has_csam_image:
155+
name = os.path.join(LOGDIR, f"vision-csam-{conv_log_filename}")
152156
else:
153157
name = os.path.join(LOGDIR, conv_log_filename)
154158

@@ -306,10 +310,8 @@ def _prepare_text_with_image(state, text, images, csam_flag):
306310
# reset convo with new image
307311
state.conv = get_conversation_template(state.model_name)
308312

309-
resize_image = "llava" in state.model_name
310313
image = state.conv.convert_image_to_base64(
311-
image,
312-
resize_image=resize_image,
314+
image
313315
) # PIL type is not JSON serializable
314316

315317
if csam_flag:
@@ -572,7 +574,9 @@ def bot_response(
572574
has_csam_images=state.has_csam_image, use_remote_storage=use_remote_storage
573575
)
574576

575-
filename = get_conv_log_filename(is_vision=state.is_vision)
577+
filename = get_conv_log_filename(
578+
is_vision=state.is_vision, has_csam_image=state.has_csam_image
579+
)
576580

577581
with open(filename, "a") as fout:
578582
data = {

0 commit comments

Comments
 (0)