Skip to content

Commit ce3b30f

Browse files
authored
feat: add support for chat sessions (#167)
# What does this PR do? today, to chat with a model, a user has to run one command per completion. add --session which enables a user to have an interactive chat session with the inference model. --session can be passed with or without --message. If no --message is passed, the user is prompted to give the first message. This is useful as it also enables context to be saved between each completion unlike today. ``` llama-stack-client inference chat-completion --session >>> hi whats up! Assistant> Not much! How's your day going so far? Is there something I can help you with or would you like to chat? >>> what color is the sky? Assistant> The color of the sky can vary depending on the time of day and atmospheric conditions. Here are some common colors you might see: * During the daytime, when the sun is overhead, the sky typically appears blue. * At sunrise and sunset, the sky can take on hues of red, orange, pink, and purple due to the scattering of light by atmospheric particles. * On a clear day with no clouds, the sky can appear a bright blue, often referred to as "cerulean." * In areas with high levels of pollution or dust, the sky can appear more hazy or grayish. * At night, the sky can be dark and black, although some stars and moonlight can make it visible. So, what's your favorite color of the sky? >>> ``` ## Test Plan tested locally with and without --message Signed-off-by: Charlie Doern <[email protected]>
1 parent 04bfdbe commit ce3b30f

File tree

1 file changed

+51
-13
lines changed

1 file changed

+51
-13
lines changed

src/llama_stack_client/lib/cli/inference/inference.py

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
66

7-
from typing import Optional
7+
from typing import Optional, List, Dict
8+
import traceback
89

910
import click
1011
from rich.console import Console
@@ -19,30 +20,67 @@ def inference():
1920

2021

2122
@click.command("chat-completion")
@click.option("--message", help="Message")
@click.option("--stream", is_flag=True, help="Streaming", default=False)
@click.option("--session", is_flag=True, help="Start a Chat Session", default=False)
@click.option("--model-id", required=False, help="Model ID")
@click.pass_context
@handle_client_errors("inference chat-completion")
def chat_completion(ctx, message: str, stream: bool, session: bool, model_id: Optional[str]):
    """Run a chat completion against the distribution's inference endpoint.

    With --message, sends a single user message (optionally streamed).
    With --session, drops into an interactive REPL afterwards; the running
    message list is carried into the session so context is preserved.
    At least one of --message / --session is required.
    """
    if not message and not session:
        click.secho(
            "you must specify either --message or --session",
            fg="red",
        )
        raise click.exceptions.Exit(1)
    client = ctx.obj["client"]
    console = Console()

    if not model_id:
        available_models = [model.identifier for model in client.models.list() if model.model_type == "llm"]
        # Fail with a readable error instead of an IndexError when the
        # distribution serves no LLMs.
        if not available_models:
            click.secho("no LLM models available; pass --model-id explicitly", fg="red")
            raise click.exceptions.Exit(1)
        model_id = available_models[0]

    messages = []
    if message:
        messages.append({"role": "user", "content": message})
        # Only issue the one-shot completion when a message was actually
        # provided; with --session alone we must not call the API with an
        # empty messages list (the user is prompted inside the session).
        response = client.inference.chat_completion(
            model_id=model_id,
            messages=messages,
            stream=stream,
        )
        if not stream:
            console.print(response)
        else:
            for event in EventLogger().log(response):
                event.print()
    if session:
        chat_session(client=client, model_id=model_id, messages=messages, console=console)
59+
60+
61+
def chat_session(client, model_id: Optional[str], messages: List[Dict[str, str]], console: Console):
    """Run an interactive chat session with the served model.

    Reads user turns from stdin (prompt ``>>> ``) in a loop, appends each to
    *messages*, and streams the model's reply. Exits on ``\\q`` / ``quit``,
    Ctrl-C, or Ctrl-D. API errors print a traceback and end the session.

    NOTE(review): assistant replies are never appended to *messages*, so the
    accumulated context contains only user turns — confirm whether the
    streamed reply should be collected and added as an "assistant" message.
    """
    while True:
        try:
            message = input(">>> ")
            if message in ("\\q", "quit"):
                console.print("Exiting")
                break
            messages.append({"role": "user", "content": message})
            response = client.inference.chat_completion(
                model_id=model_id,
                messages=messages,
                stream=True,
            )
            for event in EventLogger().log(response):
                event.print()
        # EOFError (Ctrl-D at the prompt) must be handled before the generic
        # Exception branch, which would otherwise dump a traceback at the
        # user; KeyboardInterrupt is a BaseException and needs its own clause.
        except (KeyboardInterrupt, EOFError):
            console.print("\nDetected user interrupt, exiting")
            break
        except Exception as exc:
            traceback.print_exc()
            console.print(f"Error in chat session {exc}")
            break
4684

4785

4886
# Register subcommands

0 commit comments

Comments
 (0)