3
3
import os
4
4
import shutil
5
5
import time
6
- from typing import Optional , Union
6
+ from typing import List , Optional , Union
7
7
8
8
from deepmerge import always_merger as Merger
9
9
from jsonschema import Draft202012Validator as Validator
12
12
AnyProvider ,
13
13
EmProvidersDict ,
14
14
LmProvidersDict ,
15
- ProviderRestrictions ,
16
15
get_em_provider ,
17
16
get_lm_provider ,
18
- is_provider_allowed ,
19
17
)
20
18
from jupyter_core .paths import jupyter_data_dir
21
19
from traitlets import Integer , Unicode
@@ -57,6 +55,10 @@ class KeyEmptyError(Exception):
57
55
pass
58
56
59
57
58
+ class BlockedModelError (Exception ):
59
+ pass
60
+
61
+
60
62
def _validate_provider_authn (config : GlobalConfig , provider : AnyProvider ):
61
63
# TODO: handle non-env auth strategies
62
64
if not provider .auth_strategy or provider .auth_strategy .type != "env" :
@@ -99,27 +101,34 @@ def __init__(
99
101
log : Logger ,
100
102
lm_providers : LmProvidersDict ,
101
103
em_providers : EmProvidersDict ,
102
- restrictions : ProviderRestrictions ,
104
+ allowed_providers : Optional [List [str ]],
105
+ blocked_providers : Optional [List [str ]],
106
+ allowed_models : Optional [List [str ]],
107
+ blocked_models : Optional [List [str ]],
103
108
* args ,
104
109
** kwargs ,
105
110
):
106
111
super ().__init__ (* args , ** kwargs )
107
112
self .log = log
108
- """List of LM providers."""
113
+
109
114
self ._lm_providers = lm_providers
110
- """List of EM providers."""
115
+ """List of LM providers."""
111
116
self ._em_providers = em_providers
112
- """Provider restrictions."""
113
- self ._restrictions = restrictions
117
+ """List of EM providers."""
118
+
119
+ self ._allowed_providers = allowed_providers
120
+ self ._blocked_providers = blocked_providers
121
+ self ._allowed_models = allowed_models
122
+ self ._blocked_models = blocked_models
114
123
124
+ self ._last_read : Optional [int ] = None
115
125
"""When the server last read the config file. If the file was not
116
126
modified after this time, then we can return the cached
117
127
`self._config`."""
118
- self ._last_read : Optional [int ] = None
119
128
129
+ self ._config : Optional [GlobalConfig ] = None
120
130
"""In-memory cache of the `GlobalConfig` object parsed from the config
121
131
file."""
122
- self ._config : Optional [GlobalConfig ] = None
123
132
124
133
self ._init_config_schema ()
125
134
self ._init_validator ()
@@ -140,6 +149,26 @@ def _init_config(self):
140
149
if os .path .exists (self .config_path ):
141
150
with open (self .config_path , encoding = "utf-8" ) as f :
142
151
config = GlobalConfig (** json .loads (f .read ()))
152
+ lm_id = config .model_provider_id
153
+ em_id = config .embeddings_provider_id
154
+
155
+ # if the currently selected language or embedding model are
156
+ # forbidden, set them to `None` and log a warning.
157
+ if lm_id is not None and not self ._validate_model (
158
+ lm_id , raise_exc = False
159
+ ):
160
+ self .log .warning (
161
+ f"Language model { lm_id } is forbidden by current allow/blocklists. Setting to None."
162
+ )
163
+ config .model_provider_id = None
164
+ if em_id is not None and not self ._validate_model (
165
+ em_id , raise_exc = False
166
+ ):
167
+ self .log .warning (
168
+ f"Embedding model { em_id } is forbidden by current allow/blocklists. Setting to None."
169
+ )
170
+ config .embeddings_provider_id = None
171
+
143
172
# re-write to the file to validate the config and apply any
144
173
# updates to the config file immediately
145
174
self ._write_config (config )
@@ -181,33 +210,74 @@ def _validate_config(self, config: GlobalConfig):
181
210
_ , lm_provider = get_lm_provider (
182
211
config .model_provider_id , self ._lm_providers
183
212
)
184
- # do not check config for blocked providers
185
- if not is_provider_allowed (config .model_provider_id , self ._restrictions ):
186
- assert not lm_provider
187
- return
213
+
214
+ # verify model is declared by some provider
188
215
if not lm_provider :
189
216
raise ValueError (
190
217
f"No language model is associated with '{ config .model_provider_id } '."
191
218
)
219
+
220
+ # verify model is not blocked
221
+ self ._validate_model (config .model_provider_id )
222
+
223
+ # verify model is authenticated
192
224
_validate_provider_authn (config , lm_provider )
193
225
194
226
# validate embedding model config
195
227
if config .embeddings_provider_id :
196
228
_ , em_provider = get_em_provider (
197
229
config .embeddings_provider_id , self ._em_providers
198
230
)
199
- # do not check config for blocked providers
200
- if not is_provider_allowed (
201
- config .embeddings_provider_id , self ._restrictions
202
- ):
203
- assert not em_provider
204
- return
231
+
232
+ # verify model is declared by some provider
205
233
if not em_provider :
206
234
raise ValueError (
207
235
f"No embedding model is associated with '{ config .embeddings_provider_id } '."
208
236
)
237
+
238
+ # verify model is not blocked
239
+ self ._validate_model (config .embeddings_provider_id )
240
+
241
+ # verify model is authenticated
209
242
_validate_provider_authn (config , em_provider )
210
243
244
+ def _validate_model (self , model_id : str , raise_exc = True ):
245
+ """
246
+ Validates a model against the set of allow/blocklists specified by the
247
+ traitlets configuration, returning `True` if the model is allowed, and
248
+ raising a `BlockedModelError` otherwise. If `raise_exc=False`, this
249
+ function returns `False` if the model is not allowed.
250
+ """
251
+
252
+ assert model_id is not None
253
+ components = model_id .split (":" , 1 )
254
+ assert len (components ) == 2
255
+ provider_id , _ = components
256
+
257
+ try :
258
+ if self ._allowed_providers and provider_id not in self ._allowed_providers :
259
+ raise BlockedModelError (
260
+ "Model provider not included in the provider allowlist."
261
+ )
262
+
263
+ if self ._blocked_providers and provider_id in self ._blocked_providers :
264
+ raise BlockedModelError (
265
+ "Model provider included in the provider blocklist."
266
+ )
267
+
268
+ if self ._allowed_models and model_id not in self ._allowed_models :
269
+ raise BlockedModelError ("Model not included in the model allowlist." )
270
+
271
+ if self ._blocked_models and model_id in self ._blocked_models :
272
+ raise BlockedModelError ("Model included in the model blocklist." )
273
+ except BlockedModelError as e :
274
+ if raise_exc :
275
+ raise e
276
+ else :
277
+ return False
278
+
279
+ return True
280
+
211
281
def _write_config (self , new_config : GlobalConfig ):
212
282
"""Updates configuration and persists it to disk. This accepts a
213
283
complete `GlobalConfig` object, and should not be called publicly."""
0 commit comments