
Commit f121918

Updates (6/3/2024) (#9)
1. Improved support for NIMs, now in GA! Sign up [here](https://catalog.ngc.nvidia.com/orgs/nim/teams/meta/containers/llama3-8b-instruct/tags) for access.
   - Replaced the model name field with a full container image/tag field for ease of use (e.g. copy-paste).
   - Improved local NIM switchability by replacing the model field with the full NIM container path. Users can copy and paste their NIM container directly into the chat UI.
   - Improved the local NIM flow by replacing the model-repo-generate step with a NIM sidecar container pull step, to better align with the new NIM release.
   - Fixed an issue with remote NIM support returning a null token for vLLM-backend NIMs.
   - Set project settings defaults to better align with the quickstart contents in the NIM documentation (now uses the vLLM backend).
2. Improved Metrics Tracking (a rough sketch of how these metrics can be approximated follows this list).
   - Removed the "clear query" button to accommodate the Show Metrics panel functionality.
   - Added support for new metrics:
     - retrieval time (ms)
     - TTFT (ms)
     - generation time (ms)
     - E2E (ms)
     - approx. tokens in response
     - approx. tokens generated per second
     - approx. inter-token latency (ITL)
3. Expanded Cloud-supported models (12 -> 18)
   - Added support for IBM's Granite Code models to better align with NVIDIA's API Catalog:
     - Granite 8B Code Instruct
     - Granite 34B Code Instruct
   - Widened support for Microsoft's Phi-3 models to better align with NVIDIA's API Catalog:
     - Phi-3 Mini (4k)
     - Phi-3 Small (8k)
     - Phi-3 Small (128k)
     - Phi-3 Medium (4k)
   - Implemented a temporary workaround for an issue with Microsoft's Phi-3 models not supporting penalty parameters.
4. Expanded local model selection for locally-running RAG
   - Added an ungated model for local HF TGI: microsoft/Phi-3-mini-128k-instruct
   - Added a filtering option to filter the local models dropdown by gated vs. ungated models.
5. Additional Output Customization
   - Added support for new Output Settings parameters:
     - top_p
     - frequency penalty
     - presence penalty
   - Increased the maximum new tokens to generate from 512 to 2048.
   - Dynamic max-new-tokens limits are now set based on automatic system introspection.
6. General Usability
   - Reduced UI clutter by making some major UI components collapsible:
     - The right-hand inference settings panel can collapse and expand.
     - Output parameter sliders are now hidden by default, but can be expanded.
   - Improved error messaging and forwarding of issues to the frontend UI.
   - Increased timeouts to capture a broader range of user setups.
   - Ongoing improvements to code documentation.
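As a rough illustration of how the new approximate metrics above (TTFT, generation time, E2E, token counts, tokens/sec, ITL) can be derived from a streamed response. This is a sketch, not the chat UI's actual implementation; the 4-characters-per-token heuristic and the function name are assumptions.

```python
import time
from typing import Dict, Iterable


def approximate_stream_metrics(chunks: Iterable[str]) -> Dict[str, float]:
    """Collect rough timing/token metrics while consuming a stream of text chunks.

    Assumes ~4 characters per token as a crude, tokenizer-free estimate.
    """
    start = time.time()
    first_token_time = None
    text = ""
    for chunk in chunks:
        if first_token_time is None:
            first_token_time = time.time()  # marks time to first token (TTFT)
        text += chunk
    end = time.time()
    if first_token_time is None:  # empty stream
        first_token_time = end

    approx_tokens = max(1, len(text) // 4)              # approx. tokens in response
    ttft_ms = (first_token_time - start) * 1000         # TTFT (ms)
    generation_ms = (end - first_token_time) * 1000     # generation time (ms)
    e2e_ms = (end - start) * 1000                       # E2E (ms)
    tokens_per_sec = approx_tokens / max(e2e_ms / 1000, 1e-6)
    itl_ms = generation_ms / max(approx_tokens - 1, 1)  # approx. inter-token latency
    return {
        "ttft_ms": ttft_ms,
        "generation_ms": generation_ms,
        "e2e_ms": e2e_ms,
        "approx_tokens": approx_tokens,
        "approx_tokens_per_sec": tokens_per_sec,
        "approx_itl_ms": itl_ms,
    }
```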
1 parent 88aa53a commit f121918

18 files changed (+754, -315 lines)

.project/spec.yaml

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ execution:
  start_command: cd /project/code/ && PROXY_PREFIX=$PROXY_PREFIX $HOME/.conda/envs/ui-env/bin/python3
    -m chatui
  health_check_command: curl -f "http://localhost:8080/"
-  stop_command: pkill -f '^$HOME/.conda/envs/ui-env/bin/python3 -m chatui'
+  stop_command: pkill -f "^$HOME/.conda/envs/ui-env/bin/python3 -m chatui"
  user_msg: ""
  logfile_path: ""
  timeout_seconds: 0

README.md

Lines changed: 27 additions & 26 deletions
Large diffs are not rendered by default.

code/chain_server/chains.py

Lines changed: 119 additions & 41 deletions
@@ -98,6 +98,18 @@
     "Assistant: "
 )
 
+MICROSOFT_CHAT_TEMPLATE = (
+    "<|user|>\n"
+    "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, dangerous, or illegal content. If you don't know the answer to a question, please don't share false information. Please ensure that your responses are positive in nature.\n"
+    "The user's question is: {context_str} {query_str} <|end|> \n"
+    "<|assistant|>"
+)
+
+GENERIC_CHAT_TEMPLATE = (
+    "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, dangerous, or illegal content. If you don't know the answer to a question, please don't share false information. Please ensure that your responses are positive in nature.\n"
+    "The user's question is: {context_str} {query_str} <|end|> \n"
+)
+
 MISTRAL_RAG_TEMPLATE = (
     "<s>[INST] <<SYS>>"
     "Use the following context to answer the user's question. If you don't know the answer,"
@@ -132,6 +144,22 @@
     "Assistant: "
 )
 
+MICROSOFT_RAG_TEMPLATE = (
+    "<|user|>\n"
+    "Use the following context to answer the question. If you don't know the answer,"
+    "just say that you don't know, don't try to make up an answer.\n"
+    "Context: {context_str} Question: {query_str} Only return the helpful"
+    " answer below and nothing else. <|end|> \n"
+    "<|assistant|>"
+)
+
+GENERIC_RAG_TEMPLATE = (
+    "Use the following context to answer the question. If you don't know the answer,"
+    "just say that you don't know, don't try to make up an answer.\n"
+    "Context: {context_str} Question: {query_str} Only return the helpful"
+    " answer below and nothing else. \n"
+)
+
 
 class LimitRetrievedNodesLength(BaseNodePostprocessor):
     """Llama Index chain filter to limit token lengths."""
@@ -169,7 +197,7 @@ def get_config() -> "ConfigWizard":
 
 
 @lru_cache
-def get_llm(inference_mode: str, nvcf_model_id: str, nim_model_ip: str, num_tokens: int, temp: float) -> LangChainLLM:
+def get_llm(inference_mode: str, nvcf_model_id: str, nim_model_ip: str, num_tokens: int, temp: float, top_p: float, freq_pen: float) -> LangChainLLM:
     """Create the LLM connection."""
 
     if inference_mode == "local":
@@ -179,23 +207,23 @@ def get_llm(inference_mode: str, nvcf_model_id: str, nim_model_ip: str, num_toke
             inference_server_url=inference_server_url_local,
             max_new_tokens=num_tokens,
             top_k=10,
-            top_p=0.95,
+            top_p=top_p,
             typical_p=0.95,
             temperature=temp,
-            repetition_penalty=1.03,
+            repetition_penalty=(freq_pen/8) + 1, # Reasonable mapping of OpenAI API Spec to HF Spec
             streaming=True
         )
     else:
-        inference_server_url_local = "https://integrate.api.nvidia.com/v1/" if inference_mode == "cloud" else "http://" + nim_model_ip + ":9999/v1/"
+        inference_server_url_local = "https://integrate.api.nvidia.com/v1/" if inference_mode == "cloud" else "http://" + nim_model_ip + ":8000/v1/"
 
         llm_local = HuggingFaceTextGenInference(
             inference_server_url=inference_server_url_local,
             max_new_tokens=num_tokens,
             top_k=10,
-            top_p=0.95,
+            top_p=top_p,
             typical_p=0.95,
             temperature=temp,
-            repetition_penalty=1.03,
+            repetition_penalty=(freq_pen/8) + 1, # Reasonable mapping of OpenAI API Spec to HF Spec
             streaming=True
         )
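The `repetition_penalty=(freq_pen/8) + 1` change above maps an OpenAI-style `frequency_penalty` (additive, roughly -2.0 to 2.0, neutral at 0.0) onto the multiplicative `repetition_penalty` that HuggingFace TGI expects (neutral at 1.0). A standalone sketch of that mapping, for illustration only:

```python
def freq_pen_to_repetition_penalty(freq_pen: float) -> float:
    """Map an OpenAI-style frequency_penalty onto a HF-style repetition_penalty.

    Same expression as in get_llm above: -2.0 -> 0.75, 0.0 -> 1.0, 2.0 -> 1.25,
    so the neutral defaults of both APIs line up.
    """
    return (freq_pen / 8) + 1


assert freq_pen_to_repetition_penalty(0.0) == 1.0
assert freq_pen_to_repetition_penalty(2.0) == 1.25
```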

@@ -237,10 +265,10 @@ def get_doc_retriever(num_nodes: int = 4) -> "BaseRetriever":
 
 
 @lru_cache
-def set_service_context(inference_mode: str, nvcf_model_id: str, nim_model_ip: str, num_tokens: int, temp: float) -> None:
+def set_service_context(inference_mode: str, nvcf_model_id: str, nim_model_ip: str, num_tokens: int, temp: float, top_p: float, freq_pen: float) -> None:
     """Set the global service context."""
     service_context = ServiceContext.from_defaults(
-        llm=get_llm(inference_mode, nvcf_model_id, nim_model_ip, num_tokens, temp), embed_model=get_embedding_model()
+        llm=get_llm(inference_mode, nvcf_model_id, nim_model_ip, num_tokens, temp, top_p, freq_pen), embed_model=get_embedding_model()
     )
     set_global_service_context(service_context)
 
@@ -255,23 +283,30 @@ def llm_chain_streaming(
     nim_model_port: str,
     nim_model_id: str,
     temp: float,
+    top_p: float,
+    freq_pen: float,
+    pres_pen: float,
 ) -> Generator[str, None, None]:
     """Execute a simple LLM chain using the components defined above."""
-    set_service_context(inference_mode, nvcf_model_id, nim_model_ip, num_tokens, temp)
+    set_service_context(inference_mode, nvcf_model_id, nim_model_ip, num_tokens, temp, top_p, freq_pen)
 
     if inference_mode == "local":
-        if local_model_id == "nvidia/Llama3-ChatQA-1.5-8B":
+        if "nvidia" in local_model_id:
             prompt = NVIDIA_CHAT_TEMPLATE.format(context_str=context, query_str=question)
-        elif local_model_id == "meta-llama/Meta-Llama-3-8B-Instruct":
+        elif "Llama-3" in local_model_id:
             prompt = LLAMA_3_CHAT_TEMPLATE.format(context_str=context, query_str=question)
-        elif local_model_id == "meta-llama/Llama-2-7b-chat-hf":
+        elif "Llama-2" in local_model_id:
             prompt = LLAMA_2_CHAT_TEMPLATE.format(context_str=context, query_str=question)
-        else:
+        elif "microsoft" in local_model_id:
+            prompt = MICROSOFT_CHAT_TEMPLATE.format(context_str=context, query_str=question)
+        elif "mistralai" in local_model_id:
             prompt = MISTRAL_CHAT_TEMPLATE.format(context_str=context, query_str=question)
+        else:
+            prompt = NVIDIA_CHAT_TEMPLATE.format(context_str=context, query_str=question)
         start = time.time()
-        response = get_llm(inference_mode, nvcf_model_id, nim_model_ip, num_tokens, temp).stream_complete(prompt, max_new_tokens=num_tokens)
+        response = get_llm(inference_mode, nvcf_model_id, nim_model_ip, num_tokens, temp, top_p, freq_pen).stream_complete(prompt, max_new_tokens=num_tokens)
         perf = time.time() - start
-        yield str(perf * 1000).split('.', 1)[0] + "ms"
+        yield str(perf * 1000).split('.', 1)[0]
         gen_response = (resp.delta for resp in response)
         for chunk in gen_response:
             if "<|eot_id|>" not in chunk:
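The exact-match model IDs are replaced above with substring checks, so any model ID containing e.g. "nvidia" or "Llama-3" picks up the corresponding prompt template, with the NVIDIA template as the fallback. A hypothetical table-driven equivalent is sketched below; it is not part of this commit, and the template constants in the commented usage are the ones defined earlier in chains.py.

```python
from typing import Sequence, Tuple


def select_template(model_id: str, table: Sequence[Tuple[str, str]], default: str) -> str:
    """Pick a prompt template by substring match; the first match wins, like the elif chain above."""
    for needle, template in table:
        if needle in model_id:
            return template
    return default


# Usage sketch, with the template constants defined earlier in chains.py:
# prompt_template = select_template(local_model_id, [
#     ("nvidia", NVIDIA_CHAT_TEMPLATE),
#     ("Llama-3", LLAMA_3_CHAT_TEMPLATE),
#     ("Llama-2", LLAMA_2_CHAT_TEMPLATE),
#     ("microsoft", MICROSOFT_CHAT_TEMPLATE),
#     ("mistralai", MISTRAL_CHAT_TEMPLATE),
# ], default=NVIDIA_CHAT_TEMPLATE)
```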
@@ -285,26 +320,42 @@ def llm_chain_streaming(
             prompt = LLAMA_2_CHAT_TEMPLATE.format(context_str=context, query_str=question)
         elif inference_mode == "cloud" and "mistral" in nvcf_model_id:
             prompt = MISTRAL_CHAT_TEMPLATE.format(context_str=context, query_str=question)
-        elif inference_mode == "cloud" and "google" in nvcf_model_id:
-            prompt = MISTRAL_CHAT_TEMPLATE.format(context_str=context, query_str=question)
+        elif inference_mode == "cloud" and "microsoft" in nvcf_model_id:
+            prompt = MICROSOFT_CHAT_TEMPLATE.format(context_str=context, query_str=question)
         else:
-            prompt = MISTRAL_CHAT_TEMPLATE.format(context_str=context, query_str=question)
+            prompt = GENERIC_CHAT_TEMPLATE.format(context_str=context, query_str=question)
         openai.api_key = os.environ.get('NVCF_RUN_KEY') if inference_mode == "cloud" else "xyz"
-        openai.base_url = "https://integrate.api.nvidia.com/v1/" if inference_mode == "cloud" else "http://" + nim_model_ip + ":" + ("9999" if len(nim_model_port) == 0 else nim_model_port) + "/v1/"
+        openai.base_url = "https://integrate.api.nvidia.com/v1/" if inference_mode == "cloud" else "http://" + nim_model_ip + ":" + ("8000" if len(nim_model_port) == 0 else nim_model_port) + "/v1/"
 
         start = time.time()
         completion = openai.chat.completions.create(
-            model= nvcf_model_id if inference_mode == "cloud" else nim_model_id,
+            model= nvcf_model_id if inference_mode == "cloud" else ("meta/llama3-8b-instruct" if len(nim_model_id) == 0 else nim_model_id),
             temperature=temp,
+            top_p=top_p,
+            # frequency_penalty=freq_pen, # Some models have yet to roll out support for these params
+            # presence_penalty=pres_pen,
             messages=[{"role": "user", "content": prompt}],
             max_tokens=num_tokens,
-            stream=True
+            stream=True,
+        ) if inference_mode == "cloud" and "microsoft" in nvcf_model_id else openai.chat.completions.create(
+            model= nvcf_model_id if inference_mode == "cloud" else ("meta/llama3-8b-instruct" if len(nim_model_id) == 0 else nim_model_id),
+            temperature=temp,
+            top_p=top_p,
+            frequency_penalty=freq_pen,
+            presence_penalty=pres_pen,
+            messages=[{"role": "user", "content": prompt}],
+            max_tokens=num_tokens,
+            stream=True,
         )
         perf = time.time() - start
-        yield str(perf * 1000).split('.', 1)[0] + "ms"
+        yield str(perf * 1000).split('.', 1)[0]
 
         for chunk in completion:
-            yield chunk.choices[0].delta.content
+            content = chunk.choices[0].delta.content
+            if content is not None:
+                yield str(content)
+            else:
+                continue
 
 def rag_chain_streaming(prompt: str,
                         num_tokens: int,
@@ -314,32 +365,41 @@ def rag_chain_streaming(prompt: str,
                         nim_model_ip: str,
                         nim_model_port: str,
                         nim_model_id: str,
-                        temp: float) -> "TokenGen":
+                        temp: float,
+                        top_p: float,
+                        freq_pen: float,
+                        pres_pen: float) -> "TokenGen":
     """Execute a Retrieval Augmented Generation chain using the components defined above."""
-    set_service_context(inference_mode, nvcf_model_id, nim_model_ip, num_tokens, temp)
+    set_service_context(inference_mode, nvcf_model_id, nim_model_ip, num_tokens, temp, top_p, freq_pen)
 
     if inference_mode == "local":
-        get_llm(inference_mode, nvcf_model_id, nim_model_ip, num_tokens, temp).llm.max_new_tokens = num_tokens # type: ignore
-        start = time.time()
+        get_llm(inference_mode, nvcf_model_id, nim_model_ip, num_tokens, temp, top_p, freq_pen).llm.max_new_tokens = num_tokens # type: ignore
         nodes = get_doc_retriever(num_nodes=2).retrieve(prompt)
         docs = []
         for node in nodes:
             docs.append(node.get_text())
-        if local_model_id == "nvidia/Llama3-ChatQA-1.5-8B":
+        if "nvidia" in local_model_id:
             prompt = NVIDIA_RAG_TEMPLATE.format(context_str=", ".join(docs), query_str=prompt)
-        elif local_model_id == "meta-llama/Meta-Llama-3-8B-Instruct":
+        elif "Llama-3" in local_model_id:
             prompt = LLAMA_3_RAG_TEMPLATE.format(context_str=", ".join(docs), query_str=prompt)
-        elif local_model_id == "meta-llama/Llama-2-7b-chat-hf":
+        elif "Llama-2" in local_model_id:
             prompt = LLAMA_2_RAG_TEMPLATE.format(context_str=", ".join(docs), query_str=prompt)
-        else:
+        elif "microsoft" in local_model_id:
+            prompt = MICROSOFT_RAG_TEMPLATE.format(context_str=", ".join(docs), query_str=prompt)
+        elif "mistralai" in local_model_id:
             prompt = MISTRAL_RAG_TEMPLATE.format(context_str=", ".join(docs), query_str=prompt)
+        else:
+            prompt = NVIDIA_RAG_TEMPLATE.format(context_str=", ".join(docs), query_str=prompt)
+        start = time.time()
         response = get_llm(inference_mode,
                            nvcf_model_id,
                            nim_model_ip,
                            num_tokens,
-                           temp).stream_complete(prompt, max_new_tokens=num_tokens)
+                           temp,
+                           top_p,
+                           freq_pen).stream_complete(prompt, max_new_tokens=num_tokens)
         perf = time.time() - start
-        yield str(perf * 1000).split('.', 1)[0] + "ms"
+        yield str(perf * 1000).split('.', 1)[0]
         gen_response = (resp.delta for resp in response)
         for chunk in gen_response:
             if "<|eot_id|>" not in chunk:
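The cloud/NIM streaming loops in this file now guard against `None` deltas, which OpenAI-compatible endpoints emit for role-only and final chunks. The same guard, pulled out as a standalone generator wrapper (a sketch only; the wrapper name is illustrative):

```python
from typing import Generator, Iterable


def stream_text_only(completion: Iterable) -> Generator[str, None, None]:
    """Yield only non-None text deltas from a streamed chat completion.

    Mirrors the content-is-not-None check added to llm_chain_streaming and
    rag_chain_streaming in this commit.
    """
    for chunk in completion:
        content = chunk.choices[0].delta.content
        if content is not None:
            yield str(content)
```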
@@ -348,9 +408,9 @@ def rag_chain_streaming(prompt: str,
                 break
     else:
         openai.api_key = os.environ.get('NVCF_RUN_KEY') if inference_mode == "cloud" else "xyz"
-        openai.base_url = "https://integrate.api.nvidia.com/v1/" if inference_mode == "cloud" else "http://" + nim_model_ip + ":" + ("9999" if len(nim_model_port) == 0 else nim_model_port) + "/v1/"
+        openai.base_url = "https://integrate.api.nvidia.com/v1/" if inference_mode == "cloud" else "http://" + nim_model_ip + ":" + ("8000" if len(nim_model_port) == 0 else nim_model_port) + "/v1/"
         num_nodes = 1 if ((inference_mode == "cloud" and nvcf_model_id == "playground_llama2_13b") or (inference_mode == "cloud" and nvcf_model_id == "playground_llama2_70b")) else 2
-        start = time.time()
+
         nodes = get_doc_retriever(num_nodes=num_nodes).retrieve(prompt)
         docs = []
         for node in nodes:
@@ -361,21 +421,39 @@ def rag_chain_streaming(prompt: str,
             prompt = LLAMA_2_RAG_TEMPLATE.format(context_str=", ".join(docs), query_str=prompt)
         elif inference_mode == "cloud" and "mistral" in nvcf_model_id:
             prompt = MISTRAL_RAG_TEMPLATE.format(context_str=", ".join(docs), query_str=prompt)
-        elif inference_mode == "cloud" and "google" in nvcf_model_id:
-            prompt = MISTRAL_RAG_TEMPLATE.format(context_str=", ".join(docs), query_str=prompt)
+        elif inference_mode == "cloud" and "microsoft" in nvcf_model_id:
+            prompt = MICROSOFT_RAG_TEMPLATE.format(context_str=", ".join(docs), query_str=prompt)
         else:
-            prompt = MISTRAL_RAG_TEMPLATE.format(context_str=", ".join(docs), query_str=prompt)
+            prompt = GENERIC_RAG_TEMPLATE.format(context_str=", ".join(docs), query_str=prompt)
+        start = time.time()
         completion = openai.chat.completions.create(
-            model=nvcf_model_id if inference_mode == "cloud" else nim_model_id,
+            model= nvcf_model_id if inference_mode == "cloud" else ("meta/llama3-8b-instruct" if len(nim_model_id) == 0 else nim_model_id),
+            temperature=temp,
+            top_p=top_p,
+            # frequency_penalty=freq_pen, # Some models have yet to roll out support for these params
+            # presence_penalty=pres_pen,
+            messages=[{"role": "user", "content": prompt}],
+            max_tokens=num_tokens,
+            stream=True,
+        ) if inference_mode == "cloud" and "microsoft" in nvcf_model_id else openai.chat.completions.create(
+            model=nvcf_model_id if inference_mode == "cloud" else ("meta/llama3-8b-instruct" if len(nim_model_id) == 0 else nim_model_id),
             temperature=temp,
+            top_p=top_p,
+            frequency_penalty=freq_pen,
+            presence_penalty=pres_pen,
             messages=[{"role": "user", "content": prompt}],
             max_tokens=num_tokens,
             stream=True
         )
         perf = time.time() - start
-        yield str(perf * 1000).split('.', 1)[0] + "ms"
+        yield str(perf * 1000).split('.', 1)[0]
+
         for chunk in completion:
-            yield chunk.choices[0].delta.content
+            content = chunk.choices[0].delta.content
+            if content is not None:
+                yield str(content)
+            else:
+                continue
 
 def is_base64_encoded(s: str) -> bool:
     """Check if a string is base64 encoded."""

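The twin `openai.chat.completions.create(...)` calls above exist only because the cloud Microsoft (Phi-3) endpoints do not yet accept `frequency_penalty`/`presence_penalty` (the temporary workaround noted in the commit message). A hypothetical way to express the same logic with a single call and conditional kwargs; the helper name is an assumption, not part of the codebase:

```python
def build_completion_kwargs(inference_mode: str, nvcf_model_id: str, nim_model_id: str,
                            prompt: str, temp: float, top_p: float,
                            freq_pen: float, pres_pen: float, num_tokens: int) -> dict:
    """Assemble kwargs for openai.chat.completions.create, omitting the penalty
    parameters for cloud Microsoft (Phi-3) models that do not yet support them."""
    kwargs = {
        "model": nvcf_model_id if inference_mode == "cloud"
        else ("meta/llama3-8b-instruct" if len(nim_model_id) == 0 else nim_model_id),
        "temperature": temp,
        "top_p": top_p,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": num_tokens,
        "stream": True,
    }
    if not (inference_mode == "cloud" and "microsoft" in nvcf_model_id):
        kwargs["frequency_penalty"] = freq_pen
        kwargs["presence_penalty"] = pres_pen
    return kwargs


# completion = openai.chat.completions.create(**build_completion_kwargs(...))
```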
code/chain_server/server.py

Lines changed: 12 additions & 3 deletions
@@ -33,7 +33,7 @@
 # prestage the embedding model
 _ = chains.get_embedding_model()
 # set the global service context for Llama Index
-chains.set_service_context("local", "playground_mistral_7b", "10.123.45.678", 256, 0.7)
+chains.set_service_context("local", "playground_mistral_7b", "10.123.45.678", 256, 0.7, 1.0, 0.0)
 
 
 class Prompt(BaseModel):
@@ -50,6 +50,9 @@ class Prompt(BaseModel):
     nim_model_port: str
     nim_model_id: str
     temp: float
+    top_p: float
+    freq_pen: float
+    pres_pen: float
 
 
 class DocumentSearch(BaseModel):
@@ -102,7 +105,10 @@ async def generate_answer(prompt: Prompt) -> StreamingResponse:
                                                prompt.nim_model_ip,
                                                prompt.nim_model_port,
                                                prompt.nim_model_id,
-                                               prompt.temp)
+                                               prompt.temp,
+                                               prompt.top_p,
+                                               prompt.freq_pen,
+                                               prompt.pres_pen)
         return StreamingResponse(generator, media_type="text/event-stream")
 
     generator = chains.llm_chain_streaming(prompt.context,
@@ -114,7 +120,10 @@ async def generate_answer(prompt: Prompt) -> StreamingResponse:
                                            prompt.nim_model_ip,
                                            prompt.nim_model_port,
                                            prompt.nim_model_id,
-                                           prompt.temp)
+                                           prompt.temp,
+                                           prompt.top_p,
+                                           prompt.freq_pen,
+                                           prompt.pres_pen)
     return StreamingResponse(generator, media_type="text/event-stream")
 
 

code/chatui/chat_client.py

Lines changed: 7 additions & 1 deletion
@@ -64,6 +64,9 @@ def predict(
         nim_model_port: str,
         nim_model_id: str,
         temp_slider: float,
+        top_p_slider: float,
+        freq_pen_slider: float,
+        pres_pen_slider: float,
         use_knowledge_base: bool,
         num_tokens: int
     ) -> typing.Generator[str, None, None]:
@@ -80,6 +83,9 @@ def predict(
             "nim_model_port": nim_model_port,
             "nim_model_id": nim_model_id,
             "temp": temp_slider,
+            "top_p": top_p_slider,
+            "freq_pen": freq_pen_slider,
+            "pres_pen": pres_pen_slider,
         }
         url = f"{self.server_url}/generate"
         _LOGGER.info(
@@ -110,5 +116,5 @@ def upload_documents(self, file_paths: typing.List[str]) -> None:
         )
 
         _ = requests.post(
-            url, headers=headers, files=files, verify=False, timeout=90 # type: ignore [arg-type]
+            url, headers=headers, files=files, verify=False, timeout=120 # type: ignore [arg-type]
         ) # nosec # verify=false is intentional for now
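For reference, a trimmed-down sketch of how the extended request body could reach the chain server's `/generate` endpoint. Only the keys visible in `predict` above are shown; the real payload also carries the existing prompt/model/knowledge-base fields, the slider values here are illustrative, and the JSON encoding and timeout are assumptions rather than what the client necessarily does.

```python
import requests


def post_generate(server_url: str, payload: dict, timeout: int = 120) -> requests.Response:
    """POST a streaming /generate request with the new output-settings keys merged in."""
    body = dict(payload)
    body.update({
        "temp": 0.7,      # temperature slider
        "top_p": 1.0,     # top_p slider
        "freq_pen": 0.0,  # frequency penalty slider
        "pres_pen": 0.0,  # presence penalty slider
    })
    return requests.post(f"{server_url}/generate", json=body, stream=True, timeout=timeout)
```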
