Skip to content

Commit 194a23b

Browse files
authored
Merge pull request #121 from sandialabs/copilot/add-doe-check-feature-flag
Add generic email domain whitelist middleware with configuration file
2 parents 0f1fba9 + 669044f commit 194a23b

File tree

8 files changed

+777
-0
lines changed

8 files changed

+777
-0
lines changed

.env.example

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ FEATURE_FILES_PANEL_ENABLED=true # Uploaded/session files panel
7777
FEATURE_CHAT_HISTORY_ENABLED=false # Previous chat history list
7878
FEATURE_COMPLIANCE_LEVELS_ENABLED=false # Compliance level filtering for MCP servers and data sources
7979
FEATURE_SPLASH_SCREEN_ENABLED=false # Startup splash screen for displaying policies and information
80+
FEATURE_DOMAIN_WHITELIST_ENABLED=false # Restrict access to whitelisted email domains (config/defaults/domain-whitelist.json)
8081

8182
# (Adjust above to stage rollouts. For a bare-bones chat set them all to false.)
8283

backend/core/domain_whitelist.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
"""
2+
Domain whitelist management for email access control.
3+
4+
Loads domain whitelist definitions from domain-whitelist.json and provides
5+
validation for user email domains.
6+
"""
7+
8+
import json
9+
import logging
10+
from pathlib import Path
11+
from typing import Optional, Set
12+
from dataclasses import dataclass
13+
14+
logger = logging.getLogger(__name__)
15+
16+
17+
@dataclass
class DomainWhitelistConfig:
    """Settings that drive the email-domain whitelist check.

    Attributes:
        enabled: Whether the whitelist is enforced; when False every
            email is allowed through.
        domains: Allowed email domains. The JSON loader stores them
            lower-cased so they compare equal to a lower-cased email domain.
        subdomain_matching: When True, an email whose domain is a
            subdomain of a listed domain is also accepted.
        version: Version string copied from the config file.
        description: Free-form description copied from the config file.
    """

    enabled: bool
    domains: Set[str]
    subdomain_matching: bool
    version: str
    description: str
25+
26+
27+
class DomainWhitelistManager:
    """Manages domain whitelist configuration and validation.

    The configuration is read once, at construction time, from a
    ``domain-whitelist.json`` file. If no file is found, or the file cannot
    be loaded, the manager falls back to a *disabled* configuration, which
    means every email is allowed.
    """

    def __init__(self, config_path: Optional[Path] = None):
        """Initialize the domain whitelist manager.

        Args:
            config_path: Path to domain-whitelist.json. If None, the first
                existing file among the standard locations is used
                (overrides before defaults, then the legacy backend dirs).
        """
        self.config: Optional[DomainWhitelistConfig] = None

        if config_path is None:
            # Try to find config in standard locations, highest priority first.
            backend_root = Path(__file__).parent.parent
            project_root = backend_root.parent

            search_paths = [
                project_root / "config" / "overrides" / "domain-whitelist.json",
                project_root / "config" / "defaults" / "domain-whitelist.json",
                backend_root / "configfilesadmin" / "domain-whitelist.json",
                backend_root / "configfiles" / "domain-whitelist.json",
            ]
            config_path = next((p for p in search_paths if p.exists()), None)
        else:
            # Accept plain strings as well as Path objects.
            config_path = Path(config_path)

        if config_path is not None and config_path.exists():
            self._load_config(config_path)
        else:
            logger.warning("No domain-whitelist.json found, domain whitelist disabled")
            self.config = self._disabled_config("No config loaded")

    @staticmethod
    def _disabled_config(description: str) -> "DomainWhitelistConfig":
        """Build the fallback configuration used when nothing could be loaded."""
        return DomainWhitelistConfig(
            enabled=False,
            domains=set(),
            subdomain_matching=True,
            version="1.0",
            description=description,
        )

    @staticmethod
    def _normalize_domains(entries) -> Set[str]:
        """Turn the raw ``domains`` JSON value into a clean set of domains.

        Each entry may be a string or an object with a ``domain`` key.
        Values are stripped and lower-cased. Entries that yield no usable
        domain (missing key, null, non-string, blank) are skipped instead of
        adding an empty string to the whitelist, which would otherwise match
        emails with an empty domain part.

        Args:
            entries: The decoded ``domains`` value from the config file.

        Returns:
            Set of non-empty, lower-cased domain strings.
        """
        if entries is None:
            return set()
        if isinstance(entries, (str, dict)):
            # A single bare entry; do not iterate a string char by char.
            entries = [entries]

        domains: Set[str] = set()
        for entry in entries:
            raw = entry.get('domain') if isinstance(entry, dict) else entry
            if not isinstance(raw, str):
                continue
            domain = raw.strip().lower()
            if domain:
                domains.add(domain)
        return domains

    def _load_config(self, config_path: Path):
        """Load domain whitelist configuration from JSON file.

        On any failure the error is logged and a disabled configuration is
        installed instead of raising.

        Args:
            config_path: Location of the JSON file to read.
        """
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                config_data = json.load(f)

            if not isinstance(config_data, dict):
                raise ValueError("top-level JSON value must be an object")

            self.config = DomainWhitelistConfig(
                enabled=config_data.get('enabled', False),
                domains=self._normalize_domains(config_data.get('domains', [])),
                subdomain_matching=config_data.get('subdomain_matching', True),
                version=config_data.get('version', '1.0'),
                description=config_data.get('description', ''),
            )

            logger.info(f"Loaded {len(self.config.domains)} domains from {config_path}")
            logger.debug(f"Domain whitelist enabled: {self.config.enabled}")

        except Exception as e:
            # Deliberate best-effort: a broken config file must not crash
            # construction.
            # NOTE(review): this fails OPEN - a malformed file disables the
            # whitelist entirely. Confirm that is the intended policy.
            logger.error(f"Error loading domain-whitelist.json: {e}")
            self.config = self._disabled_config("Error loading config")

    def is_enabled(self) -> bool:
        """Check if domain whitelist is enabled.

        Returns:
            True if enabled, False otherwise
        """
        return self.config is not None and self.config.enabled

    def is_domain_allowed(self, email: str) -> bool:
        """Check if an email address is from an allowed domain.

        Args:
            email: Email address to validate

        Returns:
            True if the whitelist is disabled, or if the email's domain (or,
            with subdomain matching, one of its parent domains) is listed.
            False for missing/malformed addresses and unlisted domains.
        """
        if not self.config or not self.config.enabled:
            # If not enabled or no config, allow all
            return True

        if not isinstance(email, str) or "@" not in email:
            return False

        # Everything after the FIRST "@". An address with several "@" keeps
        # an "@" in its domain part and therefore never matches an entry.
        domain = email.split("@", 1)[1].lower()
        labels = domain.split(".")

        # Reject empty domains ("user@") and empty labels ("user@evil.com.",
        # "user@.sandia.gov"): they are malformed and could otherwise match a
        # blank whitelist entry or sneak through the parent-domain walk.
        if not all(labels):
            return False

        allowed = self.config.domains

        # Check if domain is in whitelist (O(1) lookup)
        if domain in allowed:
            return True

        # Check each parent level, e.g. for "mail.dept.sandia.gov" check
        # "dept.sandia.gov" and "sandia.gov" (and finally "gov").
        if self.config.subdomain_matching:
            return any(
                ".".join(labels[i:]) in allowed
                for i in range(1, len(labels))
            )

        return False

    def get_domains(self) -> Set[str]:
        """Get the set of whitelisted domains.

        Returns:
            Set of allowed domains
        """
        return self.config.domains if self.config else set()
backend/core/domain_whitelist_middleware.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""Email domain whitelist validation middleware.
2+
3+
This middleware enforces that users must have email addresses from whitelisted
4+
domains. Configuration is loaded from domain-whitelist.json and can be
5+
enabled/disabled via the FEATURE_DOMAIN_WHITELIST_ENABLED feature flag.
6+
"""
7+
8+
import logging
9+
from fastapi import Request
10+
from starlette.middleware.base import BaseHTTPMiddleware
11+
from starlette.responses import JSONResponse, RedirectResponse, Response
12+
13+
from core.domain_whitelist import DomainWhitelistManager
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
class DomainWhitelistMiddleware(BaseHTTPMiddleware):
    """Middleware to enforce email domain whitelist restrictions.

    Reads the user's email from ``request.state.user_email`` and rejects the
    request when the whitelist is enabled and the email's domain is not
    allowed. The health endpoint and the auth redirect endpoint are exempt.
    """

    def __init__(self, app, auth_redirect_url: str = "/auth"):
        """Initialize domain whitelist middleware.

        Args:
            app: ASGI application
            auth_redirect_url: URL to redirect to on auth failure (default: /auth)
        """
        super().__init__(app)
        self.auth_redirect_url = auth_redirect_url
        # The whitelist file is located and parsed once, here.
        self.whitelist_manager = DomainWhitelistManager()

        if self.whitelist_manager.is_enabled():
            logger.info(f"Domain whitelist enabled with {len(self.whitelist_manager.get_domains())} domains")
        else:
            logger.info("Domain whitelist disabled")

    async def dispatch(self, request: Request, call_next) -> Response:
        """Check if user email is from a whitelisted domain.

        Args:
            request: Incoming HTTP request
            call_next: Next middleware/handler in chain

        Returns:
            Response from next handler if authorized, or 403/redirect if not
        """
        # Skip check for health endpoint and auth redirect endpoint
        if request.url.path in ('/api/health', self.auth_redirect_url):
            return await call_next(request)

        # If whitelist is not enabled in config, allow all
        if not self.whitelist_manager.is_enabled():
            return await call_next(request)

        # Get email from request state (set by AuthMiddleware)
        email = getattr(request.state, "user_email", None)

        # isinstance guard: a truthy non-string value would make the "@"
        # membership test raise and surface as a 500 instead of a denial.
        if not isinstance(email, str) or "@" not in email:
            logger.warning("Domain whitelist check failed: missing or invalid email")
            return self._unauthorized_response(request, "User email required")

        # Check if domain is allowed
        if not self.whitelist_manager.is_domain_allowed(email):
            domain = email.split("@", 1)[1].lower()
            # %r escapes control characters (CR/LF etc.) so a crafted email
            # cannot forge extra log lines; the value is user-controlled.
            logger.warning(
                "Domain whitelist check failed: unauthorized domain %r", domain
            )
            return self._unauthorized_response(
                request,
                "Access restricted to whitelisted domains"
            )

        return await call_next(request)

    def _unauthorized_response(self, request: Request, detail: str) -> Response:
        """Return appropriate unauthorized response based on endpoint type.

        Args:
            request: Incoming HTTP request
            detail: Error detail message

        Returns:
            JSONResponse (403) for paths under /api/, otherwise a 302
            RedirectResponse to the auth redirect URL
        """
        if request.url.path.startswith('/api/'):
            return JSONResponse(
                status_code=403,
                content={"detail": detail}
            )
        return RedirectResponse(url=self.auth_redirect_url, status_code=302)

backend/main.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from core.middleware import AuthMiddleware
2828
from core.rate_limit_middleware import RateLimitMiddleware
2929
from core.security_headers_middleware import SecurityHeadersMiddleware
30+
from core.domain_whitelist_middleware import DomainWhitelistMiddleware
3031
from core.otel_config import setup_opentelemetry
3132
from core.utils import sanitize_for_logging
3233
from core.auth import get_user_from_header
@@ -132,6 +133,12 @@ async def lifespan(app: FastAPI):
132133
"""
133134
app.add_middleware(SecurityHeadersMiddleware)
134135
app.add_middleware(RateLimitMiddleware)
136+
# Domain whitelist check (if enabled) - add before Auth so it runs after
137+
if config.app_settings.feature_domain_whitelist_enabled:
138+
app.add_middleware(
139+
DomainWhitelistMiddleware,
140+
auth_redirect_url=config.app_settings.auth_redirect_url
141+
)
135142
app.add_middleware(
136143
AuthMiddleware,
137144
debug_mode=config.app_settings.debug_mode,

backend/modules/config/config_manager.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,12 @@ def agent_mode_available(self) -> bool:
285285
description="Enable compliance level filtering for MCP servers and data sources",
286286
validation_alias=AliasChoices("FEATURE_COMPLIANCE_LEVELS_ENABLED"),
287287
)
288+
# Email domain whitelist feature gate
289+
feature_domain_whitelist_enabled: bool = Field(
290+
False,
291+
description="Enable email domain whitelist restriction (configured in domain-whitelist.json)",
292+
validation_alias=AliasChoices("FEATURE_DOMAIN_WHITELIST_ENABLED", "FEATURE_DOE_LAB_CHECK_ENABLED"),
293+
)
288294

289295
# Capability tokens (for headless access to downloads/iframes)
290296
capability_token_secret: str = ""

0 commit comments

Comments
 (0)