@@ -127,6 +127,31 @@ def after_debug_level(self):
127
127
return self
128
128
129
129
130
+ class _document_store (BaseModel ):
131
+ """Class describing configuration of document store backend for RAG."""
132
+
133
+ uri : str = Field (
134
+ default_factory = lambda : DEFAULTS .DEFAULT_DOCUMENT_STORE_PATH ,
135
+ description = "Document store service URI." ,
136
+ )
137
+ collection_name : str = Field (
138
+ default = DEFAULTS .DOCUMENT_STORE_COLLECTION_NAME ,
139
+ description = "Document store collection name." ,
140
+ )
141
+
142
+
143
+ class _embedding_model (BaseModel ):
144
+ """Class describing configuration of embedding parameters for RAG."""
145
+
146
+ # model configuration
147
+ model_config = ConfigDict (extra = "ignore" , protected_namespaces = ())
148
+
149
+ embedding_model_name : StrictStr = Field (
150
+ default_factory = lambda : DEFAULTS .DEFAULT_EMBEDDING_MODEL ,
151
+ description = "Embedding model to use for RAG." ,
152
+ )
153
+
154
+
130
155
class _chat (BaseModel ):
131
156
"""Class describing configuration of the 'chat' sub-command."""
132
157
@@ -285,9 +310,33 @@ class _convert(BaseModel):
285
310
)
286
311
287
312
313
+ class _retriever (BaseModel ):
314
+ """Class describing configuration of retrieval parameters for RAG."""
315
+
316
+ top_k : int = Field (
317
+ default = DEFAULTS .RETRIEVER_TOP_K ,
318
+ description = "The maximum number of documents to retrieve." ,
319
+ )
320
+
321
+
288
322
class _rag (BaseModel ):
289
323
"""Class describing configuration of the 'ilab rag' command."""
290
324
325
+ enabled : bool = Field (
326
+ default = False , description = "Flag for enabling RAG functionality."
327
+ )
328
+ document_store : _document_store = Field (
329
+ default_factory = _document_store ,
330
+ description = "Document store configuration for RAG." ,
331
+ )
332
+ embedding_model : _embedding_model = Field (
333
+ default_factory = _embedding_model ,
334
+ description = "Embedding model configuration for RAG" ,
335
+ )
336
+ retriever : _retriever = Field (
337
+ default_factory = _retriever ,
338
+ description = "Retrieval configuration parameters for RAG" ,
339
+ )
291
340
convert : _convert = Field (
292
341
default_factory = _convert , description = "RAG convert configuration section."
293
342
)
@@ -597,54 +646,6 @@ class _train(BaseModel):
597
646
)
598
647
599
648
600
- class _document_store (BaseModel ):
601
- """Class describing configuration of document store backend for RAG."""
602
-
603
- uri : str = Field (default = "embeddings.db" , description = "Document store service URI." )
604
- collection_name : str = Field (
605
- default = "ilab" , description = "Document store collection name."
606
- )
607
-
608
-
609
- class _embedding_model (BaseModel ):
610
- """Class describing configuration of embedding parameters for RAG."""
611
-
612
- # model configuration
613
- model_config = ConfigDict (extra = "ignore" , protected_namespaces = ())
614
-
615
- model_dir : str = Field (
616
- default = DEFAULTS .MODELS_DIR ,
617
- description = "The default system model location store, located in the data directory." ,
618
- )
619
- model_name : str = Field (
620
- default_factory = lambda : DEFAULTS .DEFAULT_EMBEDDING_MODEL ,
621
- description = "Embedding model to use for RAG." ,
622
- )
623
-
624
- def local_model_path (self ) -> str :
625
- if self .model_dir is None :
626
- click .secho (f"Missing value for field model_dir in { vars (self )} " )
627
- raise click .exceptions .Exit (1 )
628
-
629
- if self .model_name is None :
630
- click .secho (f"Missing value for field model_name in { vars (self )} " )
631
- raise click .exceptions .Exit (1 )
632
-
633
- return os .path .join (self .model_dir , self .model_name )
634
-
635
-
636
- class _retriever (BaseModel ):
637
- """Class describing configuration of retrieval parameters for RAG."""
638
-
639
- top_k : int = Field (
640
- default = 20 , description = "The maximum number of documents to retrieve."
641
- )
642
- embedding_model : _embedding_model = Field (
643
- default = _embedding_model (),
644
- description = "Embedding parameters for retrieval." ,
645
- )
646
-
647
-
648
649
class _metadata (BaseModel ):
649
650
# model configuration
650
651
model_config = ConfigDict (extra = "ignore" )
0 commit comments