Skip to content

Commit f5a655f

Browse files
authored
feat: Retrieval augmented generation for chat (#2886)
Resolves #2876 Depends on #2832 Dev doc: instructlab/dev-docs#161 **Checklist:** - [X] **Commit Message Formatting**: Commit titles and messages follow guidelines in the [conventional commits](https://www.conventionalcommits.org/en/v1.0.0/#summary). - [x] [Changelog](https://github.com/instructlab/instructlab/blob/main/CHANGELOG.md) updated with breaking and/or notable changes for the next minor release. - [x] Documentation has been updated, if necessary. - [x] Unit tests have been added, if necessary. - [x] Functional tests have been added, if necessary. - [x] E2E Workflow tests have been added, if necessary. Approved-by: cdoern Approved-by: nathan-weinberg
2 parents 065c786 + e44cb78 commit f5a655f

File tree

7 files changed

+185
-49
lines changed

7 files changed

+185
-49
lines changed

src/instructlab/cli/model/chat.py

+50
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# First Party
1111
from instructlab import clickext
1212
from instructlab import configuration as cfg
13+
from instructlab.defaults import DEFAULTS
1314
from instructlab.model.chat import chat_model
1415

1516
logger = logging.getLogger(__name__)
@@ -101,6 +102,45 @@
101102
"--temperature",
102103
cls=clickext.ConfigOption,
103104
)
105+
@click.option(
106+
"--rag",
107+
"rag_enabled",
108+
default=False,
109+
is_flag=True,
110+
)
111+
@click.option(
112+
"--document-store-uri",
113+
"uri",
114+
type=click.STRING,
115+
cls=clickext.ConfigOption,
116+
config_class="rag",
117+
config_sections="document_store",
118+
)
119+
@click.option(
120+
"--document-store-collection-name",
121+
"collection_name",
122+
type=click.STRING,
123+
cls=clickext.ConfigOption,
124+
config_class="rag",
125+
config_sections="document_store",
126+
)
127+
@click.option(
128+
"--retriever-embedding-model-name",
129+
"embedding_model_name",
130+
type=click.STRING,
131+
cls=clickext.ConfigOption,
132+
config_class="rag",
133+
config_sections="embedding_model",
134+
)
135+
@click.option(
136+
"--retriever-top-k",
137+
"top_k",
138+
type=click.INT,
139+
default=DEFAULTS.RETRIEVER_TOP_K,
140+
cls=clickext.ConfigOption,
141+
config_class="rag",
142+
config_sections="retriever",
143+
)
104144
@click.pass_context
105145
@clickext.display_params
106146
def chat(
@@ -120,6 +160,11 @@ def chat(
120160
model_family,
121161
serving_log_file,
122162
temperature,
163+
rag_enabled,
164+
uri,
165+
collection_name,
166+
embedding_model_name,
167+
top_k,
123168
):
124169
"""Runs a chat using the modified model"""
125170
chat_model(
@@ -138,6 +183,11 @@ def chat(
138183
model_family,
139184
serving_log_file,
140185
temperature,
186+
rag_enabled,
187+
uri,
188+
collection_name,
189+
embedding_model_name,
190+
top_k,
141191
backend_type=ctx.obj.config.serve.server.backend_type,
142192
host=ctx.obj.config.serve.server.host,
143193
port=ctx.obj.config.serve.server.port,

src/instructlab/configuration.py

+49-48
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,31 @@ def after_debug_level(self):
127127
return self
128128

129129

130+
class _document_store(BaseModel):
131+
"""Class describing configuration of document store backend for RAG."""
132+
133+
uri: str = Field(
134+
default_factory=lambda: DEFAULTS.DEFAULT_DOCUMENT_STORE_PATH,
135+
description="Document store service URI.",
136+
)
137+
collection_name: str = Field(
138+
default=DEFAULTS.DOCUMENT_STORE_COLLECTION_NAME,
139+
description="Document store collection name.",
140+
)
141+
142+
143+
class _embedding_model(BaseModel):
144+
"""Class describing configuration of embedding parameters for RAG."""
145+
146+
# model configuration
147+
model_config = ConfigDict(extra="ignore", protected_namespaces=())
148+
149+
embedding_model_name: StrictStr = Field(
150+
default_factory=lambda: DEFAULTS.DEFAULT_EMBEDDING_MODEL,
151+
description="Embedding model to use for RAG.",
152+
)
153+
154+
130155
class _chat(BaseModel):
131156
"""Class describing configuration of the 'chat' sub-command."""
132157

@@ -285,9 +310,33 @@ class _convert(BaseModel):
285310
)
286311

287312

313+
class _retriever(BaseModel):
314+
"""Class describing configuration of retrieval parameters for RAG."""
315+
316+
top_k: int = Field(
317+
default=DEFAULTS.RETRIEVER_TOP_K,
318+
description="The maximum number of documents to retrieve.",
319+
)
320+
321+
288322
class _rag(BaseModel):
289323
"""Class describing configuration of the 'ilab rag' command."""
290324

325+
enabled: bool = Field(
326+
default=False, description="Flag for enabling RAG functionality."
327+
)
328+
document_store: _document_store = Field(
329+
default_factory=_document_store,
330+
description="Document store configuration for RAG.",
331+
)
332+
embedding_model: _embedding_model = Field(
333+
default_factory=_embedding_model,
334+
description="Embedding model configuration for RAG",
335+
)
336+
retriever: _retriever = Field(
337+
default_factory=_retriever,
338+
description="Retrieval configuration parameters for RAG",
339+
)
291340
convert: _convert = Field(
292341
default_factory=_convert, description="RAG convert configuration section."
293342
)
@@ -597,54 +646,6 @@ class _train(BaseModel):
597646
)
598647

599648

600-
class _document_store(BaseModel):
601-
"""Class describing configuration of document store backend for RAG."""
602-
603-
uri: str = Field(default="embeddings.db", description="Document store service URI.")
604-
collection_name: str = Field(
605-
default="ilab", description="Document store collection name."
606-
)
607-
608-
609-
class _embedding_model(BaseModel):
610-
"""Class describing configuration of embedding parameters for RAG."""
611-
612-
# model configuration
613-
model_config = ConfigDict(extra="ignore", protected_namespaces=())
614-
615-
model_dir: str = Field(
616-
default=DEFAULTS.MODELS_DIR,
617-
description="The default system model location store, located in the data directory.",
618-
)
619-
model_name: str = Field(
620-
default_factory=lambda: DEFAULTS.DEFAULT_EMBEDDING_MODEL,
621-
description="Embedding model to use for RAG.",
622-
)
623-
624-
def local_model_path(self) -> str:
625-
if self.model_dir is None:
626-
click.secho(f"Missing value for field model_dir in {vars(self)}")
627-
raise click.exceptions.Exit(1)
628-
629-
if self.model_name is None:
630-
click.secho(f"Missing value for field model_name in {vars(self)}")
631-
raise click.exceptions.Exit(1)
632-
633-
return os.path.join(self.model_dir, self.model_name)
634-
635-
636-
class _retriever(BaseModel):
637-
"""Class describing configuration of retrieval parameters for RAG."""
638-
639-
top_k: int = Field(
640-
default=20, description="The maximum number of documents to retrieve."
641-
)
642-
embedding_model: _embedding_model = Field(
643-
default=_embedding_model(),
644-
description="Embedding parameters for retrieval.",
645-
)
646-
647-
648649
class _metadata(BaseModel):
649650
# model configuration
650651
model_config = ConfigDict(extra="ignore")

src/instructlab/defaults.py

+7
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ class _InstructlabDefaults:
8686
MISTRAL_GGUF_REPO = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
8787
GRANITE_GGUF_MODEL_NAME = "granite-7b-lab-Q4_K_M.gguf"
8888
GRANITE_EMBEDDING_MODEL_NAME = "ibm-granite/granite-embedding-125m-english"
89+
DOCUMENT_STORE_NAME = "embeddings.db"
90+
DOCUMENT_STORE_COLLECTION_NAME = "ilab"
91+
RETRIEVER_TOP_K = 20
8992
MERLINITE_GGUF_MODEL_NAME = "merlinite-7b-lab-Q4_K_M.gguf"
9093
MISTRAL_GGUF_MODEL_NAME = "mistral-7b-instruct-v0.2.Q4_K_M.gguf"
9194
MODEL_REPO = "instructlab/granite-7b-lab"
@@ -174,6 +177,10 @@ def DEFAULT_CHAT_MODEL(self) -> str:
174177
def DEFAULT_EMBEDDING_MODEL(self) -> str:
175178
return path.join(self.MODELS_DIR, self.GRANITE_EMBEDDING_MODEL_NAME)
176179

180+
@property
181+
def DEFAULT_DOCUMENT_STORE_PATH(self) -> str:
182+
return path.join(self._data_dir, self.DOCUMENT_STORE_NAME)
183+
177184
@property
178185
def DEFAULT_TEACHER_MODEL(self) -> str:
179186
return path.join(self.MODELS_DIR, self.MISTRAL_GGUF_MODEL_NAME)

src/instructlab/model/chat.py

+40
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333

3434
# Local
3535
from ..client_utils import http_client
36+
from ..rag.document_store import DocumentStoreRetriever
37+
from ..rag.document_store_factory import create_document_retriever
3638
from ..utils import get_cli_helper_sysprompt, get_model_arch, get_sysprompt
3739
from .backends import backends
3840

@@ -87,6 +89,7 @@ def __init__(
8789
self,
8890
model,
8991
client,
92+
retriever=None,
9093
vi_mode=False,
9194
prompt=True,
9295
vertical_overflow="ellipsis",
@@ -98,6 +101,7 @@ def __init__(
98101
backend_type="",
99102
):
100103
self.client = client
104+
self.retriever: DocumentStoreRetriever | None = retriever
101105
self.model = model
102106
self.vi_mode = vi_mode
103107
self.vertical_overflow = vertical_overflow
@@ -395,6 +399,13 @@ def start_prompt(
395399

396400
self.log_message(PROMPT_PREFIX + content + "\n\n")
397401

402+
# if RAG is enabled, fetch context and insert into session
403+
# TODO: what if context is already too long? note that current retriever implementation concatenates all docs
404+
# TODO: better way to check whether we should perform retrieval?
405+
if self.retriever is not None:
406+
context = self.retriever.augmented_context(user_query=content)
407+
self._update_conversation(context, "assistant")
408+
398409
# Update message history and token counters
399410
self._update_conversation(content, "user")
400411

@@ -552,6 +563,11 @@ def chat_model(
552563
model_family,
553564
serving_log_file,
554565
temperature,
566+
rag_enabled,
567+
document_store_uri,
568+
collection_name,
569+
embedding_model,
570+
top_k,
555571
backend_type,
556572
host,
557573
port,
@@ -693,6 +709,11 @@ def chat_model(
693709
max_tokens=max_tokens,
694710
max_ctx_size=max_ctx_size,
695711
temperature=temperature,
712+
rag_enabled=rag_enabled,
713+
document_store_uri=document_store_uri,
714+
collection_name=collection_name,
715+
embedding_model=embedding_model,
716+
top_k=top_k,
696717
backend_type=backend_type,
697718
params=params,
698719
)
@@ -715,6 +736,11 @@ def chat_cli(
715736
max_ctx_size,
716737
temperature,
717738
backend_type,
739+
rag_enabled,
740+
document_store_uri,
741+
collection_name,
742+
embedding_model,
743+
top_k,
718744
logs_dir,
719745
vi_mode,
720746
visible_overflow,
@@ -756,6 +782,19 @@ def chat_cli(
756782
sys_prompt = CONTEXTS.get(context, "default")(get_model_arch(pathlib.Path(model)))
757783
loaded["messages"] = [{"role": "system", "content": sys_prompt}]
758784

785+
# Instantiate retriever if RAG is enabled
786+
if rag_enabled:
787+
logger.debug("RAG enabled for chat; initializing retriever")
788+
retriever: DocumentStoreRetriever | None = create_document_retriever(
789+
document_store_uri=document_store_uri,
790+
document_store_collection_name=collection_name,
791+
top_k=top_k,
792+
embedding_model_path=embedding_model,
793+
)
794+
else:
795+
logger.debug("RAG not enabled for chat; skipping retrieval setup")
796+
retriever: DocumentStoreRetriever | None = None
797+
759798
# Session from CLI
760799
if session is not None:
761800
loaded["name"] = os.path.basename(session.name).strip(".json")
@@ -778,6 +817,7 @@ def chat_cli(
778817
ccb = ConsoleChatBot(
779818
model if model is None else model,
780819
client=client,
820+
retriever=retriever,
781821
vi_mode=vi_mode,
782822
log_file=log_file,
783823
prompt=not qq,

tests/test_lab.py

+1
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def has_debug_params(self) -> bool:
115115
Command(("config", "show")),
116116
Command(("model",), needs_config=False, should_fail=False),
117117
Command(("model", "chat")),
118+
Command(("model", "chat"), ("--rag",)),
118119
Command(("model", "convert"), ("--model-dir", "test")),
119120
Command(("model", "download")),
120121
Command(("model", "evaluate"), ("--benchmark", "mmlu")),

tests/test_model_chat.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
# Standard
22
from unittest.mock import MagicMock
33
import contextlib
4+
import logging
45
import re
56

67
# Third Party
78
from rich.console import Console
89
import pytest
910

1011
# First Party
11-
from instructlab.model.chat import ConsoleChatBot
12+
from instructlab.model.chat import ChatException, ConsoleChatBot
13+
14+
logger = logging.getLogger(__name__)
1215

1316

1417
@pytest.mark.parametrize(
@@ -24,6 +27,19 @@ def test_model_name(model_path, expected_name):
2427
assert chatbot.model_name == expected_name
2528

2629

30+
def test_retriever_is_called_when_present():
31+
retriever = MagicMock()
32+
chatbot = ConsoleChatBot(
33+
model="/var/model/file", client=None, retriever=retriever, loaded={}
34+
)
35+
assert chatbot.retriever == retriever
36+
user_query = "test"
37+
with pytest.raises(ChatException) as exc_info:
38+
chatbot.start_prompt(content=user_query, logger=logger)
39+
logger.info(exc_info)
40+
retriever.augmented_context.assert_called_with(user_query=user_query)
41+
42+
2743
def handle_output(output):
2844
return re.sub(r"\s+", " ", output).strip()
2945

tests/testdata/default_config.yaml

+21
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,27 @@ rag:
239239
# Directory where taxonomy is stored and accessed from.
240240
# Default: /data/instructlab/taxonomy
241241
taxonomy_path: /data/instructlab/taxonomy
242+
# Document store configuration for RAG.
243+
document_store:
244+
# Document store collection name.
245+
# Default: ilab
246+
collection_name: ilab
247+
# Document store service URI.
248+
# Default: /data/instructlab/embeddings.db
249+
uri: /data/instructlab/embeddings.db
250+
# Embedding model configuration for RAG
251+
embedding_model:
252+
# Embedding model to use for RAG.
253+
# Default: /cache/instructlab/models/ibm-granite/granite-embedding-125m-english
254+
embedding_model_name: /cache/instructlab/models/ibm-granite/granite-embedding-125m-english
255+
# Flag for enabling RAG functionality.
256+
# Default: False
257+
enabled: false
258+
# Retrieval configuration parameters for RAG
259+
retriever:
260+
# The maximum number of documents to retrieve.
261+
# Default: 20
262+
top_k: 20
242263
# Serve configuration section.
243264
serve:
244265
# Serving backend to use to host the model.

0 commit comments

Comments
 (0)