Skip to content

Commit 9f69469

Browse files
Model allowlist and blocklists (#446)
* implement model allow/blocklist in UI * skip extension tests due to local flakiness * implement model allow/blocklists in config manager * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 831c42b commit 9f69469

File tree

5 files changed

+253
-50
lines changed

5 files changed

+253
-50
lines changed

packages/jupyter-ai/jupyter_ai/config_manager.py

Lines changed: 90 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
import shutil
55
import time
6-
from typing import Optional, Union
6+
from typing import List, Optional, Union
77

88
from deepmerge import always_merger as Merger
99
from jsonschema import Draft202012Validator as Validator
@@ -12,10 +12,8 @@
1212
AnyProvider,
1313
EmProvidersDict,
1414
LmProvidersDict,
15-
ProviderRestrictions,
1615
get_em_provider,
1716
get_lm_provider,
18-
is_provider_allowed,
1917
)
2018
from jupyter_core.paths import jupyter_data_dir
2119
from traitlets import Integer, Unicode
@@ -57,6 +55,10 @@ class KeyEmptyError(Exception):
5755
pass
5856

5957

58+
class BlockedModelError(Exception):
59+
pass
60+
61+
6062
def _validate_provider_authn(config: GlobalConfig, provider: AnyProvider):
6163
# TODO: handle non-env auth strategies
6264
if not provider.auth_strategy or provider.auth_strategy.type != "env":
@@ -99,27 +101,34 @@ def __init__(
99101
log: Logger,
100102
lm_providers: LmProvidersDict,
101103
em_providers: EmProvidersDict,
102-
restrictions: ProviderRestrictions,
104+
allowed_providers: Optional[List[str]],
105+
blocked_providers: Optional[List[str]],
106+
allowed_models: Optional[List[str]],
107+
blocked_models: Optional[List[str]],
103108
*args,
104109
**kwargs,
105110
):
106111
super().__init__(*args, **kwargs)
107112
self.log = log
108-
"""List of LM providers."""
113+
109114
self._lm_providers = lm_providers
110-
"""List of EM providers."""
115+
"""List of LM providers."""
111116
self._em_providers = em_providers
112-
"""Provider restrictions."""
113-
self._restrictions = restrictions
117+
"""List of EM providers."""
118+
119+
self._allowed_providers = allowed_providers
120+
self._blocked_providers = blocked_providers
121+
self._allowed_models = allowed_models
122+
self._blocked_models = blocked_models
114123

124+
self._last_read: Optional[int] = None
115125
"""When the server last read the config file. If the file was not
116126
modified after this time, then we can return the cached
117127
`self._config`."""
118-
self._last_read: Optional[int] = None
119128

129+
self._config: Optional[GlobalConfig] = None
120130
"""In-memory cache of the `GlobalConfig` object parsed from the config
121131
file."""
122-
self._config: Optional[GlobalConfig] = None
123132

124133
self._init_config_schema()
125134
self._init_validator()
@@ -140,6 +149,26 @@ def _init_config(self):
140149
if os.path.exists(self.config_path):
141150
with open(self.config_path, encoding="utf-8") as f:
142151
config = GlobalConfig(**json.loads(f.read()))
152+
lm_id = config.model_provider_id
153+
em_id = config.embeddings_provider_id
154+
155+
# if the currently selected language or embedding model are
156+
# forbidden, set them to `None` and log a warning.
157+
if lm_id is not None and not self._validate_model(
158+
lm_id, raise_exc=False
159+
):
160+
self.log.warning(
161+
f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None."
162+
)
163+
config.model_provider_id = None
164+
if em_id is not None and not self._validate_model(
165+
em_id, raise_exc=False
166+
):
167+
self.log.warning(
168+
f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None."
169+
)
170+
config.embeddings_provider_id = None
171+
143172
# re-write to the file to validate the config and apply any
144173
# updates to the config file immediately
145174
self._write_config(config)
@@ -181,33 +210,74 @@ def _validate_config(self, config: GlobalConfig):
181210
_, lm_provider = get_lm_provider(
182211
config.model_provider_id, self._lm_providers
183212
)
184-
# do not check config for blocked providers
185-
if not is_provider_allowed(config.model_provider_id, self._restrictions):
186-
assert not lm_provider
187-
return
213+
214+
# verify model is declared by some provider
188215
if not lm_provider:
189216
raise ValueError(
190217
f"No language model is associated with '{config.model_provider_id}'."
191218
)
219+
220+
# verify model is not blocked
221+
self._validate_model(config.model_provider_id)
222+
223+
# verify model is authenticated
192224
_validate_provider_authn(config, lm_provider)
193225

194226
# validate embedding model config
195227
if config.embeddings_provider_id:
196228
_, em_provider = get_em_provider(
197229
config.embeddings_provider_id, self._em_providers
198230
)
199-
# do not check config for blocked providers
200-
if not is_provider_allowed(
201-
config.embeddings_provider_id, self._restrictions
202-
):
203-
assert not em_provider
204-
return
231+
232+
# verify model is declared by some provider
205233
if not em_provider:
206234
raise ValueError(
207235
f"No embedding model is associated with '{config.embeddings_provider_id}'."
208236
)
237+
238+
# verify model is not blocked
239+
self._validate_model(config.embeddings_provider_id)
240+
241+
# verify model is authenticated
209242
_validate_provider_authn(config, em_provider)
210243

244+
def _validate_model(self, model_id: str, raise_exc=True):
245+
"""
246+
Validates a model against the set of allow/blocklists specified by the
247+
traitlets configuration, returning `True` if the model is allowed, and
248+
raising a `BlockedModelError` otherwise. If `raise_exc=False`, this
249+
function returns `False` if the model is not allowed.
250+
"""
251+
252+
assert model_id is not None
253+
components = model_id.split(":", 1)
254+
assert len(components) == 2
255+
provider_id, _ = components
256+
257+
try:
258+
if self._allowed_providers and provider_id not in self._allowed_providers:
259+
raise BlockedModelError(
260+
"Model provider not included in the provider allowlist."
261+
)
262+
263+
if self._blocked_providers and provider_id in self._blocked_providers:
264+
raise BlockedModelError(
265+
"Model provider included in the provider blocklist."
266+
)
267+
268+
if self._allowed_models and model_id not in self._allowed_models:
269+
raise BlockedModelError("Model not included in the model allowlist.")
270+
271+
if self._blocked_models and model_id in self._blocked_models:
272+
raise BlockedModelError("Model included in the model blocklist.")
273+
except BlockedModelError as e:
274+
if raise_exc:
275+
raise e
276+
else:
277+
return False
278+
279+
return True
280+
211281
def _write_config(self, new_config: GlobalConfig):
212282
"""Updates configuration and persists it to disk. This accepts a
213283
complete `GlobalConfig` object, and should not be called publicly."""

packages/jupyter-ai/jupyter_ai/extension.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,50 @@ class AiExtension(ExtensionApp):
5353
config=True,
5454
)
5555

56+
allowed_models = List(
57+
Unicode(),
58+
default_value=None,
59+
help="""
60+
Language models to allow, as a list of global model IDs in the format
61+
`<provider>:<local-model-id>`. If `None`, all are allowed. Defaults to
62+
`None`.
63+
64+
Note: Currently, if `allowed_providers` is also set, then this field is
65+
ignored. This is subject to change in a future non-major release. Using
66+
both traits is considered to be undefined behavior at this time.
67+
""",
68+
allow_none=True,
69+
config=True,
70+
)
71+
72+
blocked_models = List(
73+
Unicode(),
74+
default_value=None,
75+
help="""
76+
Language models to block, as a list of global model IDs in the format
77+
`<provider>:<local-model-id>`. If `None`, none are blocked. Defaults to
78+
`None`.
79+
""",
80+
allow_none=True,
81+
config=True,
82+
)
83+
5684
def initialize_settings(self):
5785
start = time.time()
86+
87+
# Read from allowlist and blocklist
5888
restrictions = {
5989
"allowed_providers": self.allowed_providers,
6090
"blocked_providers": self.blocked_providers,
6191
}
62-
92+
self.settings["allowed_models"] = self.allowed_models
93+
self.settings["blocked_models"] = self.blocked_models
94+
self.log.info(f"Configured provider allowlist: {self.allowed_providers}")
95+
self.log.info(f"Configured provider blocklist: {self.blocked_providers}")
96+
self.log.info(f"Configured model allowlist: {self.allowed_models}")
97+
self.log.info(f"Configured model blocklist: {self.blocked_models}")
98+
99+
# Fetch LM & EM providers
63100
self.settings["lm_providers"] = get_lm_providers(
64101
log=self.log, restrictions=restrictions
65102
)
@@ -73,7 +110,10 @@ def initialize_settings(self):
73110
log=self.log,
74111
lm_providers=self.settings["lm_providers"],
75112
em_providers=self.settings["em_providers"],
76-
restrictions=restrictions,
113+
allowed_providers=self.allowed_providers,
114+
blocked_providers=self.blocked_providers,
115+
allowed_models=self.allowed_models,
116+
blocked_models=self.blocked_models,
77117
)
78118

79119
self.log.info("Registered providers.")

packages/jupyter-ai/jupyter_ai/handlers.py

Lines changed: 57 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import uuid
55
from asyncio import AbstractEventLoop
66
from dataclasses import asdict
7-
from typing import TYPE_CHECKING, Dict, List
7+
from typing import TYPE_CHECKING, Dict, List, Optional
88

99
import tornado
1010
from jupyter_ai.chat_handlers import BaseChatHandler
@@ -240,14 +240,58 @@ def on_close(self):
240240
self.log.debug("Chat clients: %s", self.root_chat_handlers.keys())
241241

242242

243-
class ModelProviderHandler(BaseAPIHandler):
243+
class ProviderHandler(BaseAPIHandler):
244+
"""
245+
Helper base class used for HTTP handlers hosting endpoints relating to
246+
providers. Wrapper around BaseAPIHandler.
247+
"""
248+
244249
@property
245250
def lm_providers(self) -> Dict[str, "BaseProvider"]:
246251
return self.settings["lm_providers"]
247252

253+
@property
254+
def em_providers(self) -> Dict[str, "BaseEmbeddingsProvider"]:
255+
return self.settings["em_providers"]
256+
257+
@property
258+
def allowed_models(self) -> Optional[List[str]]:
259+
return self.settings["allowed_models"]
260+
261+
@property
262+
def blocked_models(self) -> Optional[List[str]]:
263+
return self.settings["blocked_models"]
264+
265+
def _filter_blocked_models(self, providers: List[ListProvidersEntry]):
266+
"""
267+
Satisfy the model-level allow/blocklist by filtering models accordingly.
268+
The provider-level allow/blocklist is already handled in
269+
`AiExtension.initialize_settings()`.
270+
"""
271+
if self.blocked_models is None and self.allowed_models is None:
272+
return providers
273+
274+
def filter_predicate(local_model_id: str):
275+
model_id = provider.id + ":" + local_model_id
276+
if self.blocked_models:
277+
return model_id not in self.blocked_models
278+
else:
279+
return model_id in self.allowed_models
280+
281+
# filter out every model w/ model ID according to allow/blocklist
282+
for provider in providers:
283+
provider.models = list(filter(filter_predicate, provider.models))
284+
285+
# filter out every provider with no models which satisfy the allow/blocklist, then return
286+
return filter((lambda p: len(p.models) > 0), providers)
287+
288+
289+
class ModelProviderHandler(ProviderHandler):
248290
@web.authenticated
249291
def get(self):
250292
providers = []
293+
294+
# Step 1: gather providers
251295
for provider in self.lm_providers.values():
252296
# skip old legacy OpenAI chat provider used only in magics
253297
if provider.id == "openai-chat":
@@ -270,17 +314,16 @@ def get(self):
270314
)
271315
)
272316

273-
response = ListProvidersResponse(
274-
providers=sorted(providers, key=lambda p: p.name)
275-
)
276-
self.finish(response.json())
317+
# Step 2: sort & filter providers
318+
providers = self._filter_blocked_models(providers)
319+
providers = sorted(providers, key=lambda p: p.name)
277320

321+
# Finally, yield response.
322+
response = ListProvidersResponse(providers=providers)
323+
self.finish(response.json())
278324

279-
class EmbeddingsModelProviderHandler(BaseAPIHandler):
280-
@property
281-
def em_providers(self) -> Dict[str, "BaseEmbeddingsProvider"]:
282-
return self.settings["em_providers"]
283325

326+
class EmbeddingsModelProviderHandler(ProviderHandler):
284327
@web.authenticated
285328
def get(self):
286329
providers = []
@@ -296,9 +339,10 @@ def get(self):
296339
)
297340
)
298341

299-
response = ListProvidersResponse(
300-
providers=sorted(providers, key=lambda p: p.name)
301-
)
342+
providers = self._filter_blocked_models(providers)
343+
providers = sorted(providers, key=lambda p: p.name)
344+
345+
response = ListProvidersResponse(providers=providers)
302346
self.finish(response.json())
303347

304348

0 commit comments

Comments
 (0)