Skip to content

Commit 24d4b86

Browse files
authored
feat: add doc search endpoint (#185)
Related changes:
- nextcloud/context_chat#129
- nextcloud/assistant#241

---------

Signed-off-by: Anupam Kumar <[email protected]>
1 parent 9d73058 commit 24d4b86

File tree

3 files changed

+76
-2
lines changed

3 files changed

+76
-2
lines changed

context_chat_backend/chain/context.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66

77
from langchain.schema import Document
88

9+
from ..dyn_loader import VectorDBLoader
910
from ..vectordb.base import BaseVectorDB
10-
from .types import ContextException, ScopeType
11+
from .types import ContextException, ScopeType, SearchResult
1112

1213
logger = logging.getLogger('ccb.chain')
1314

@@ -39,3 +40,53 @@ def get_context_chunks(context_docs: list[Document]) -> list[str]:
3940
context_chunks.append(doc.page_content)
4041

4142
return context_chunks
43+
44+
45+
def do_doc_search(
    user_id: str,
    query: str,
    vectordb_loader: VectorDBLoader,
    ctx_limit: int = 20,
    scope_type: ScopeType | None = None,
    scope_list: list[str] | None = None,
) -> list[SearchResult]:
    """
    Return up to ``ctx_limit`` search results, one per distinct source,
    built from the documents that ``get_context_docs`` retrieves for the query.

    Documents whose metadata has no 'source' entry are skipped with a warning.
    An empty list is returned when nothing is retrieved.

    Raises
    ------
    ContextException
        If the scope type is provided but the scope list is empty or not provided
    """
    vectordb = vectordb_loader.load()
    # several retrieved documents can share one source, so ask for twice
    # as many as requested and de-duplicate below
    retrieved = get_context_docs(user_id, query, vectordb, ctx_limit * 2, scope_type, scope_list)
    num_retrieved = len(retrieved)
    if num_retrieved == 0:
        logger.warning('No documents retrieved, please index a few documents first')
        return []

    seen_sources = set()
    hits: list[SearchResult] = []
    for document in retrieved:
        src = document.metadata.get('source')
        if not src:
            logger.warning('Document without source id encountered in doc search, skipping', extra={
                'doc': document,
            })
            continue
        if src in seen_sources:
            # this source has already produced a result
            continue
        if len(hits) >= ctx_limit:
            break

        seen_sources.add(src)
        hits.append(SearchResult(
            source_id=src,
            title=document.metadata.get('title', ''),
        ))

    logger.debug('do_doc_search', extra={
        'len(docs)': num_retrieved,
        'len(results)': len(hits),
        'scope_type': scope_type,
        'scope_list': scope_list,
    })
    return hits

context_chat_backend/chain/types.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,9 @@ class ContextException(Exception):
3636
class LLMOutput(TypedDict):
    """Typed dict holding an ``output`` string and a list of ``sources`` strings."""

    output: str
    sources: list[str]
    # todo: add "titles" field
40+
41+
42+
class SearchResult(TypedDict):
    """Typed dict for one document search hit: a ``source_id`` and its ``title``."""

    source_id: str
    title: str

context_chat_backend/controller.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#
55

66
# isort: off
7-
from .chain.types import ContextException, LLMOutput, ScopeType
7+
from .chain.types import ContextException, LLMOutput, ScopeType, SearchResult
88
from .types import LoaderException, EmbeddingException
99
from .vectordb.types import DbException, SafeDbException, UpdateAccessOp
1010
# isort: on
@@ -26,6 +26,7 @@
2626
from nc_py_api.ex_app import persistent_storage, set_handlers
2727
from pydantic import BaseModel, ValidationInfo, field_validator
2828

29+
from .chain.context import do_doc_search
2930
from .chain.ingest.injest import embed_sources
3031
from .chain.one_shot import process_context_query, process_query
3132
from .config_parser import get_config
@@ -315,12 +316,14 @@ def _(userId: str = Body(embed=True)):
315316

316317
return JSONResponse('User deleted')
317318

319+
318320
@app.post('/countIndexedDocuments')
@enabled_guard(app)
def _():
    # run count_documents_by_provider through exec_in_proc and hand its
    # result back to the client as JSON
    return JSONResponse(
        exec_in_proc(target=count_documents_by_provider, args=(vectordb_loader,))
    )
323325

326+
324327
@app.put('/loadSources')
325328
@enabled_guard(app)
326329
def _(sources: list[UploadFile]):
@@ -467,3 +470,17 @@ def _(query: Query) -> LLMOutput:
467470

468471
with llm_lock:
469472
return execute_query(query, in_proc=False)
473+
474+
475+
@app.post('/docSearch')
@enabled_guard(app)
def _(query: Query) -> list[SearchResult]:
    # useContext from Query is not used here
    # positional arguments, in the order do_doc_search declares them
    search_args = (
        query.userId,
        query.query,
        vectordb_loader,
        query.ctxLimit,
        query.scopeType,
        query.scopeList,
    )
    return exec_in_proc(target=do_doc_search, args=search_args)

0 commit comments

Comments
 (0)