|
| 1 | +import time |
| 2 | +from dataclasses import dataclass |
| 3 | +from typing import Callable, Generator, Literal |
| 4 | + |
| 5 | +import mesop as me |
| 6 | + |
| 7 | +_ROLE_USER = "user" |
| 8 | +_ROLE_ASSISTANT = "assistant" |
| 9 | + |
| 10 | +_BOT_USERNAME_DEFAULT = "mesop-bot" |
| 11 | + |
| 12 | +_COLOR_BACKGROUND = "#f0f4f8" |
| 13 | +_COLOR_CHAT_BUBBLE_YOU = "#f2f2f2" |
| 14 | +_COLOR_CHAT_BUBBLE_BOT = "#ebf3ff" |
| 15 | + |
| 16 | +_DEFAULT_PADDING = me.Padding(top=20, left=20, right=20, bottom=20) |
| 17 | +_DEFAULT_BORDER_SIDE = me.BorderSide( |
| 18 | + width="1px", style="solid", color="#ececec" |
| 19 | +) |
| 20 | + |
| 21 | +_LABEL_BUTTON = "Send prompt" |
| 22 | +_LABEL_BUTTON_IN_PROGRESS = "Processing prompt..." |
| 23 | +_LABEL_INPUT = "Enter your prompt" |
| 24 | + |
| 25 | +_STYLE_APP_CONTAINER = me.Style( |
| 26 | + background=_COLOR_BACKGROUND, |
| 27 | + display="grid", |
| 28 | + height="100vh", |
| 29 | + grid_template_columns="repeat(1, 1fr)", |
| 30 | +) |
| 31 | +_STYLE_TITLE = me.Style(padding=me.Padding(left=10)) |
| 32 | +_STYLE_CHAT_BOX = me.Style( |
| 33 | + height="100%", |
| 34 | + overflow_y="scroll", |
| 35 | + padding=_DEFAULT_PADDING, |
| 36 | + margin=me.Margin(bottom="20px"), |
| 37 | + border_radius="10px", |
| 38 | + border=me.Border( |
| 39 | + left=_DEFAULT_BORDER_SIDE, |
| 40 | + right=_DEFAULT_BORDER_SIDE, |
| 41 | + top=_DEFAULT_BORDER_SIDE, |
| 42 | + bottom=_DEFAULT_BORDER_SIDE, |
| 43 | + ), |
| 44 | +) |
| 45 | +_STYLE_CHAT_INPUT = me.Style(width="100%") |
| 46 | +_STYLE_CHAT_INPUT_BOX = me.Style(padding=me.Padding(top="30px")) |
| 47 | +_STYLE_CHAT_BUBBLE_NAME = me.Style( |
| 48 | + font_weight="bold", |
| 49 | + font_size="12px", |
| 50 | + padding=me.Padding(left="15px", right="15px", bottom="5px"), |
| 51 | +) |
| 52 | +_STYLE_CHAT_BUBBLE_PLAINTEXT = me.Style( |
| 53 | + margin=me.Margin(top="15px", bottom="15px") |
| 54 | +) |
| 55 | + |
| 56 | + |
| 57 | +def _make_style_chat_ui_container(has_title: bool) -> me.Style: |
| 58 | + """Generates styles for chat UI container depending on if there is a title or not. |
| 59 | +
|
| 60 | + Args: |
| 61 | + has_title: Whether the Chat UI is display a title or not. |
| 62 | + """ |
| 63 | + return me.Style( |
| 64 | + display="grid", |
| 65 | + grid_template_columns="repeat(1, 1fr)", |
| 66 | + grid_template_rows="1fr 14fr 1fr" if has_title else "5fr 1fr", |
| 67 | + margin=me.Margin(top=0, bottom=0, left="auto", right="auto"), |
| 68 | + width="min(1024px, 100%)", |
| 69 | + height="100vh", |
| 70 | + background="#fff", |
| 71 | + box_shadow=( |
| 72 | + "0 3px 1px -2px #0003, 0 2px 2px #00000024, 0 1px 5px #0000001f" |
| 73 | + ), |
| 74 | + padding=_DEFAULT_PADDING, |
| 75 | + ) |
| 76 | + |
| 77 | + |
| 78 | +def _make_style_chat_bubble_wrapper( |
| 79 | + role: Literal["user", "assistant"] |
| 80 | +) -> me.Style: |
| 81 | + """Generates styles for chat bubble position. |
| 82 | +
|
| 83 | + Args: |
| 84 | + role: Chat bubble alignment depends on the role |
| 85 | + """ |
| 86 | + align_items = "end" if role == _ROLE_USER else "start" |
| 87 | + return me.Style( |
| 88 | + display="flex", |
| 89 | + flex_direction="column", |
| 90 | + align_items=align_items, |
| 91 | + ) |
| 92 | + |
| 93 | + |
| 94 | +def _make_chat_bubble_style(role: Literal["user", "assistant"]) -> me.Style: |
| 95 | + """Generates styles for chat bubble. |
| 96 | +
|
| 97 | + Args: |
| 98 | + role: Chat bubble background color depends on the role |
| 99 | + """ |
| 100 | + background = ( |
| 101 | + _COLOR_CHAT_BUBBLE_YOU if role == _ROLE_USER else _COLOR_CHAT_BUBBLE_BOT |
| 102 | + ) |
| 103 | + return me.Style( |
| 104 | + width="80%", |
| 105 | + font_size="13px", |
| 106 | + background=background, |
| 107 | + border_radius="15px", |
| 108 | + padding=me.Padding( |
| 109 | + top="0px", |
| 110 | + right="15px", |
| 111 | + left="15px", |
| 112 | + bottom="3px", |
| 113 | + ), |
| 114 | + margin=me.Margin(bottom="10px"), |
| 115 | + border=me.Border( |
| 116 | + left=_DEFAULT_BORDER_SIDE, |
| 117 | + right=_DEFAULT_BORDER_SIDE, |
| 118 | + top=_DEFAULT_BORDER_SIDE, |
| 119 | + bottom=_DEFAULT_BORDER_SIDE, |
| 120 | + ), |
| 121 | + ) |
| 122 | + |
| 123 | + |
| 124 | +@dataclass |
| 125 | +class Message: |
| 126 | + """Chat message metadata.""" |
| 127 | + |
| 128 | + role: Literal["assistant", "user"] = "user" |
| 129 | + username: str = "" |
| 130 | + content: str = "" |
| 131 | + |
| 132 | + |
| 133 | +@me.stateclass |
| 134 | +class State: |
| 135 | + input: str |
| 136 | + output: list[Message] |
| 137 | + in_progress: bool = False |
| 138 | + |
| 139 | + |
| 140 | +def on_input_update(State): |
| 141 | + """Generic on input handler that saves input to State using the given key. |
| 142 | +
|
| 143 | + This helper only works if you have one state instance. If use multiple state classes |
| 144 | + with this helper, then only the last event handler will be stored. For more info, see |
| 145 | + https://google.github.io/mesop/guides/troubleshooting/#avoid-using-closure-variables-in-event-handler. |
| 146 | + """ |
| 147 | + |
| 148 | + def _on_update( |
| 149 | + e: ( |
| 150 | + me.InputEvent |
| 151 | + | me.SelectSelectionChangeEvent |
| 152 | + | me.RadioChangeEvent |
| 153 | + | me.CheckboxChangeEvent |
| 154 | + ), |
| 155 | + ): |
| 156 | + state = me.state(State) |
| 157 | + setattr(state, e.key, e.value) |
| 158 | + |
| 159 | + return _on_update |
| 160 | + |
| 161 | + |
| 162 | +def chat( |
| 163 | + transform: Callable[[str], Generator[str, None, None] | str], |
| 164 | + *, |
| 165 | + title: str | None = None, |
| 166 | + bot_username: str = _BOT_USERNAME_DEFAULT, |
| 167 | +): |
| 168 | + """Creates a simple chat UI which takes in a prompt and returns a response to the |
| 169 | + prompt. |
| 170 | +
|
| 171 | + This function creates event handlers for text input and output operations |
| 172 | + using the provided function `transform` to process the input and generate the output. |
| 173 | +
|
| 174 | + Args: |
| 175 | + transform: Function that takes in a prompt and returns a response to the prompt. |
| 176 | + title: Headline text to display at the top of the UI. |
| 177 | + bot_username: Name of your bot / assistant. |
| 178 | + """ |
| 179 | + state = me.state(State) |
| 180 | + |
| 181 | + def on_click(e: me.ClickEvent): |
| 182 | + state = me.state(State) |
| 183 | + if state.in_progress or not state.input: |
| 184 | + return |
| 185 | + |
| 186 | + output = state.output |
| 187 | + if output is None: |
| 188 | + output = [] |
| 189 | + output.append(Message(role=_ROLE_USER, content=state.input)) |
| 190 | + state.in_progress = True |
| 191 | + yield |
| 192 | + |
| 193 | + start_time = time.time() |
| 194 | + output_message = transform(state.input) |
| 195 | + assistant_message = Message(role=_ROLE_ASSISTANT, username=bot_username) |
| 196 | + output.append(assistant_message) |
| 197 | + state.output = output |
| 198 | + for content in output_message: |
| 199 | + assistant_message.content += content |
| 200 | + if (time.time() - start_time) >= 0.5: |
| 201 | + start_time = time.time() |
| 202 | + yield |
| 203 | + state.in_progress = False |
| 204 | + yield |
| 205 | + |
| 206 | + with me.box(style=_STYLE_APP_CONTAINER): |
| 207 | + with me.box(style=_make_style_chat_ui_container(bool(title))): |
| 208 | + if title: |
| 209 | + me.text(title, type="headline-5", style=_STYLE_TITLE) |
| 210 | + with me.box(style=_STYLE_CHAT_BOX): |
| 211 | + for msg in state.output: |
| 212 | + with me.box(style=_make_style_chat_bubble_wrapper(msg.role)): |
| 213 | + if msg.role == _ROLE_ASSISTANT: |
| 214 | + me.text(msg.username, style=_STYLE_CHAT_BUBBLE_NAME) |
| 215 | + with me.box(style=_make_chat_bubble_style(msg.role)): |
| 216 | + if msg.role == _ROLE_USER: |
| 217 | + me.text(msg.content, style=_STYLE_CHAT_BUBBLE_PLAINTEXT) |
| 218 | + else: |
| 219 | + me.markdown(msg.content) |
| 220 | + |
| 221 | + with me.box(style=_STYLE_CHAT_INPUT_BOX): |
| 222 | + me.input( |
| 223 | + label=_LABEL_INPUT, |
| 224 | + key="input", |
| 225 | + on_input=on_input_update(State), |
| 226 | + style=_STYLE_CHAT_INPUT, |
| 227 | + ) |
| 228 | + with me.box(): |
| 229 | + me.button( |
| 230 | + _LABEL_BUTTON_IN_PROGRESS if state.in_progress else _LABEL_BUTTON, |
| 231 | + color="primary", |
| 232 | + type="flat", |
| 233 | + disabled=state.in_progress, |
| 234 | + on_click=on_click, |
| 235 | + ) |
0 commit comments