Skip to content

Commit eaf9496

Browse files
authored
Merge pull request #708
Feat/add relevant chunk support to colang 2
2 parents 942abd2 + 9c93d1b commit eaf9496

File tree

5 files changed

+122
-0
lines changed

5 files changed

+122
-0
lines changed

Diff for: nemoguardrails/colang/v2_x/library/core.co

+15
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@ flow _user_said $text -> $event
1111
else
1212
match UtteranceUserAction.Finished() as $event
1313

14+
$text = $event.final_transcript
15+
1416
flow _user_saying $text -> $event
1517
"""The internal flow for all semantic 'user saying' flows."""
18+
1619
if $text
1720
if is_regex($text)
1821
match UtteranceUserAction.TranscriptUpdated(interim_transcript=$text) as $event
@@ -22,10 +25,16 @@ flow _user_saying $text -> $event
2225
else
2326
match UtteranceUserAction.TranscriptUpdated() as $event
2427

28+
$text = $event.interim_transcript
29+
2530
flow _user_said_something_unexpected -> $event
2631
"""The internal flow for all semantic 'user said something unexpected' flows."""
32+
global $last_user_message
2733
match UnhandledEvent(event="UtteranceUserActionFinished", loop_ids={$self.loop_id}) as $event
2834

35+
$text = $event.final_transcript
36+
$last_user_message = $text
37+
2938
@meta(user_action='user said "{$transcript}"')
3039
flow user said $text -> $transcript
3140
"""Wait for a user to have said given text."""
@@ -73,6 +82,12 @@ flow user said something unexpected -> $event, $transcript
7382

7483
flow _bot_say $text -> $action
7584
"""The internal flow for all semantic level bot utterance flows."""
85+
global $bot_message
86+
global $last_bot_message
87+
88+
$bot_message = $text
89+
$last_bot_message = $text
90+
7691
await UtteranceBotAction(script=$text) as $action
7792

7893
@meta(bot_action=True)

Diff for: nemoguardrails/colang/v2_x/library/guardrails.co

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ flow _user_saying $text -> $event
4343
flow _user_said_something_unexpected -> $event
4444
"""Override core flow for when the user said something unexpected."""
4545
global $user_message
46+
global $last_user_message
4647
match UnhandledEvent(event="UtteranceUserActionFinished", loop_ids={$self.loop_id}) as $event
4748

4849
$text = $event.final_transcript

Diff for: nemoguardrails/colang/v2_x/library/llm.co

+14
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,13 @@ flow continuation on unhandled user utterance
9595

9696
log 'start generating user intent and bot intent/action...'
9797
$action = 'user said "{$event.final_transcript}"'
98+
99+
100+
# retrieve relevant chunks from KB if user_message is not empty
101+
102+
await RetrieveRelevantChunksAction()
103+
104+
98105
#await GenerateUserIntentAction(user_action=$action, max_example_flows=20) as $action_ref
99106
#$user_intent = $action_ref.return_value
100107
await GenerateUserIntentAndBotAction(user_action=$action, max_example_flows=20) as $action_ref
@@ -205,11 +212,18 @@ flow llm generate interaction continuation flow -> $flow_name
205212
activate polling llm request response
206213
# Generate continuation based current interaction history
207214

215+
216+
# retrieve relevant chunks from KB if user_message is not empty
217+
await RetrieveRelevantChunksAction()
218+
219+
208220
log 'start generating flow continuation...'
221+
209222
$flow_info = await GenerateFlowContinuationAction(temperature=0.1)
210223
log "generated flow continuation: `{$flow_info}`"
211224
$exists = await CheckValidFlowExistsAction(flow_id=$flow_info.name)
212225

226+
213227
if $exists == False
214228
$flows = await AddFlowsAction(config=$flow_info.body)
215229
if len($flows) == 0

Diff for: nemoguardrails/llm/prompts/general.yml

+15
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,13 @@ prompts:
146146
# These are the most likely user intents:
147147
{{ examples }}
148148
149+
{% if context.relevant_chunks %}
150+
# This is some additional context:
151+
```markdown
152+
{{ context.relevant_chunks }}
153+
```
154+
{% endif %}
155+
149156
# This is the current conversation between the user and the bot:
150157
{{ history | colang }}
151158
@@ -205,6 +212,14 @@ prompts:
205212
# This is the current conversation between the user and the bot:
206213
{{ history | colang }}
207214
215+
{% if context.relevant_chunks %}
216+
# This is some additional context:
217+
```markdown
218+
{{ context.relevant_chunks }}
219+
```
220+
{% endif %}
221+
222+
208223
bot intent:
209224
210225
- task: generate_flow_continuation_from_flow_nld

Diff for: tests/test_retrieve_relevant_chunks.py

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
from unittest.mock import MagicMock
16+
17+
import pytest
18+
19+
from nemoguardrails import LLMRails, RailsConfig
20+
from nemoguardrails.kb.kb import KnowledgeBase
21+
from tests.utils import TestChat
22+
23+
config = RailsConfig.from_content(
24+
"""
25+
import llm
26+
import core
27+
28+
flow main
29+
activate llm continuation
30+
31+
flow user express greeting
32+
user said "hello"
33+
or user said "hi"
34+
or user said "how are you"
35+
36+
flow bot express greeting
37+
bot say "Hey!"
38+
39+
flow greeting
40+
user express greeting
41+
bot express greeting
42+
""",
43+
yaml_content="""
44+
colang_version: 2.x
45+
models: []
46+
""",
47+
)
48+
49+
50+
def test_relevant_chunk_inserted_in_prompt():
51+
mock_kb = MagicMock(spec=KnowledgeBase)
52+
53+
mock_kb.search_relevant_chunks.return_value = [
54+
{"title": "Test Title", "body": "Test Body"}
55+
]
56+
57+
chat = TestChat(
58+
config,
59+
llm_completions=[
60+
" user express greeting",
61+
' bot respond to aditional context\nbot action: "Hello is there anything else" ',
62+
],
63+
)
64+
65+
rails = chat.app
66+
67+
rails.runtime.register_action_param("kb", mock_kb)
68+
69+
messages = [
70+
{"role": "user", "content": "Hi!"},
71+
]
72+
73+
new_message = rails.generate(messages=messages)
74+
75+
info = rails.explain()
76+
assert len(info.llm_calls) == 2
77+
assert "Test Body" in info.llm_calls[1].prompt

0 commit comments

Comments
 (0)