Skip to content

Commit 252d963

Browse files
authored
Merge pull request #116 from sandialabs:copilot/fix-rag-discovery-error
Skip RAG discovery when FEATURE_RAG_ENABLED is false
2 parents ed4ccb8 + 1da4229 commit 252d963

File tree

2 files changed

+81
-36
lines changed

2 files changed

+81
-36
lines changed

backend/routes/config_routes.py

Lines changed: 38 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -74,42 +74,44 @@ async def get_config(
7474
# Get RAG data sources for the user (feature-gated MCP-backed discovery)
7575
rag_data_sources = []
7676
rag_servers = []
77-
try:
78-
if app_settings.feature_rag_mcp_enabled:
79-
rag_mcp = app_factory.get_rag_mcp_service()
80-
rag_data_sources = await rag_mcp.discover_data_sources(
81-
current_user, user_compliance_level=compliance_level
82-
)
83-
rag_servers = await rag_mcp.discover_servers(
84-
current_user, user_compliance_level=compliance_level
85-
)
86-
else:
87-
rag_client = app_factory.get_rag_client()
88-
# rag_client.discover_data_sources now returns List[DataSource] objects
89-
data_source_objects = await rag_client.discover_data_sources(current_user)
90-
# Convert to list of names (strings) for the 'data_sources' field (backward compatibility)
91-
rag_data_sources = [ds.name for ds in data_source_objects]
92-
# Populate rag_servers with the mock data in the expected format for the UI
93-
rag_servers = [
94-
{
95-
"server": "rag_mock",
96-
"displayName": "RAG Mock Data",
97-
"icon": "database",
98-
"complianceLevel": "Public", # Default compliance for the mock server itself
99-
"sources": [
100-
{
101-
"id": ds.name,
102-
"name": ds.name,
103-
"authRequired": True,
104-
"selected": False,
105-
"complianceLevel": ds.compliance_level,
106-
}
107-
for ds in data_source_objects
108-
],
109-
}
110-
]
111-
except Exception as e:
112-
logger.warning(f"Error resolving RAG data sources: {e}")
77+
# Only attempt RAG discovery if RAG feature is enabled
78+
if app_settings.feature_rag_enabled:
79+
try:
80+
if app_settings.feature_rag_mcp_enabled:
81+
rag_mcp = app_factory.get_rag_mcp_service()
82+
rag_data_sources = await rag_mcp.discover_data_sources(
83+
current_user, user_compliance_level=compliance_level
84+
)
85+
rag_servers = await rag_mcp.discover_servers(
86+
current_user, user_compliance_level=compliance_level
87+
)
88+
else:
89+
rag_client = app_factory.get_rag_client()
90+
# rag_client.discover_data_sources now returns List[DataSource] objects
91+
data_source_objects = await rag_client.discover_data_sources(current_user)
92+
# Convert to list of names (strings) for the 'data_sources' field (backward compatibility)
93+
rag_data_sources = [ds.name for ds in data_source_objects]
94+
# Populate rag_servers with the mock data in the expected format for the UI
95+
rag_servers = [
96+
{
97+
"server": "rag_mock",
98+
"displayName": "RAG Mock Data",
99+
"icon": "database",
100+
"complianceLevel": "Public", # Default compliance for the mock server itself
101+
"sources": [
102+
{
103+
"id": ds.name,
104+
"name": ds.name,
105+
"authRequired": True,
106+
"selected": False,
107+
"complianceLevel": ds.compliance_level,
108+
}
109+
for ds in data_source_objects
110+
],
111+
}
112+
]
113+
except Exception as e:
114+
logger.warning(f"Error resolving RAG data sources: {e}")
113115

114116
# Check if tools are enabled
115117
tools_info = []

backend/tests/test_routes_config_smoke.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
from unittest.mock import patch, MagicMock, AsyncMock
2+
13
from starlette.testclient import TestClient
24

5+
from infrastructure.app_factory import app_factory
36
from main import app
47

58

@@ -14,3 +17,43 @@ def test_config_endpoint_smoke(monkeypatch):
1417
assert "tools" in data
1518
assert "prompts" in data
1619
assert "data_sources" in data
20+
21+
22+
def test_rag_discovery_skipped_when_feature_disabled(monkeypatch):
23+
"""Verify RAG discovery is not attempted when feature_rag_enabled is False."""
24+
# Create mock rag_client to track if discover_data_sources is called
25+
mock_rag_client = MagicMock()
26+
mock_rag_client.discover_data_sources = AsyncMock(return_value=[])
27+
28+
# Create mock rag_mcp_service
29+
mock_rag_mcp = MagicMock()
30+
mock_rag_mcp.discover_data_sources = AsyncMock(return_value=[])
31+
mock_rag_mcp.discover_servers = AsyncMock(return_value=[])
32+
33+
# Patch the app_factory methods
34+
with patch.object(app_factory, 'get_rag_client', return_value=mock_rag_client):
35+
with patch.object(app_factory, 'get_rag_mcp_service', return_value=mock_rag_mcp):
36+
# Ensure RAG feature is disabled
37+
config_manager = app_factory.get_config_manager()
38+
original_setting = config_manager.app_settings.feature_rag_enabled
39+
config_manager.app_settings.feature_rag_enabled = False
40+
41+
try:
42+
client = TestClient(app)
43+
resp = client.get("/api/config", headers={"X-User-Email": "[email protected]"})
44+
assert resp.status_code == 200
45+
46+
# Verify RAG discovery was NOT called when feature is disabled
47+
mock_rag_client.discover_data_sources.assert_not_called()
48+
mock_rag_mcp.discover_data_sources.assert_not_called()
49+
mock_rag_mcp.discover_servers.assert_not_called()
50+
51+
# Verify response still has data_sources field (just empty)
52+
data = resp.json()
53+
assert "data_sources" in data
54+
assert data["data_sources"] == []
55+
assert "rag_servers" in data
56+
assert data["rag_servers"] == []
57+
finally:
58+
# Restore original setting
59+
config_manager.app_settings.feature_rag_enabled = original_setting

0 commit comments

Comments
 (0)