Skip to content

Commit d720744

Browse files
Litellm dev 02 06 2025 p3 (#8343)
* feat(handle_jwt.py): initial commit to allow scope based model access * feat(handle_jwt.py): allow model access based on token scopes allow admin to control model access from IDP * test(test_jwt.py): add unit testing for scope based model access * docs(token_auth.md): add scope based model access to docs * docs(token_auth.md): update docs * docs(token_auth.md): update docs * build: add gemini commercial rate limits * fix: fix linting error
1 parent f87ab25 commit d720744

File tree

6 files changed

+238
-7
lines changed

6 files changed

+238
-7
lines changed

Diff for: docs/my-website/docs/proxy/token_auth.md

+65-1
Original file line numberDiff line numberDiff line change
@@ -370,4 +370,68 @@ Supported internal roles:
370370
- `internal_user`: User object will be used for RBAC spend tracking. Use this for tracking spend for an 'individual user'.
371371
- `proxy_admin`: Proxy admin will be used for RBAC spend tracking. Use this for granting admin access to a token.
372372
373-
### [Architecture Diagram (Control Model Access)](./jwt_auth_arch)
373+
### [Architecture Diagram (Control Model Access)](./jwt_auth_arch)
374+
375+
## [BETA] Control Model Access with Scopes
376+
377+
Control which models a JWT can access. Set `enforce_scope_based_access: true` to enforce scope-based access control.
378+
379+
### 1. Setup config.yaml with scope mappings.
380+
381+
382+
```yaml
383+
model_list:
384+
- model_name: anthropic-claude
385+
litellm_params:
386+
model: anthropic/claude-3-5-sonnet
387+
api_key: os.environ/ANTHROPIC_API_KEY
388+
- model_name: gpt-3.5-turbo-testing
389+
litellm_params:
390+
model: gpt-3.5-turbo
391+
api_key: os.environ/OPENAI_API_KEY
392+
393+
general_settings:
394+
enable_jwt_auth: True
395+
litellm_jwtauth:
396+
team_id_jwt_field: "client_id" # 👈 set the field in the JWT token that contains the team id
397+
team_id_upsert: true # 👈 upsert the team to db, if team id is not found in db
398+
scope_mappings:
399+
- scope: litellm.api.consumer
400+
models: ["anthropic-claude"]
401+
- scope: litellm.api.gpt_3_5_turbo
402+
models: ["gpt-3.5-turbo-testing"]
403+
enforce_scope_based_access: true # 👈 enforce scope-based access control
404+
enforce_rbac: true # 👈 enforces only a Team/User/ProxyAdmin can access the proxy.
405+
```
406+
407+
#### Scope Mapping Spec
408+
409+
- `scope`: The scope to be used for the JWT token.
410+
- `models`: The models that the JWT token can access. Value is the `model_name` in `model_list`. Note: Wildcard routes are not currently supported.
411+
412+
### 2. Create a JWT with the correct scopes.
413+
414+
Expected Token:
415+
416+
```
417+
{
418+
"scope": ["litellm.api.consumer", "litellm.api.gpt_3_5_turbo"]
419+
}
420+
```
421+
422+
### 3. Test the flow.
423+
424+
```bash
425+
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
426+
-H 'Content-Type: application/json' \
427+
-H 'Authorization: Bearer eyJhbGci...' \
428+
-d '{
429+
"model": "gpt-3.5-turbo-testing",
430+
"messages": [
431+
{
432+
"role": "user",
433+
"content": "Hey, how'\''s it going 1234?"
434+
}
435+
]
436+
}'
437+
```

Diff for: litellm/proxy/_new_secret_config.yaml

+14
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,17 @@ model_list:
2727
model: openai/fake
2828
api_key: fake-key
2929
api_base: https://exampleopenaiendpoint-production.up.railway.app/
30+
31+
general_settings:
32+
enable_jwt_auth: True
33+
litellm_jwtauth:
34+
team_id_jwt_field: "client_id"
35+
team_id_upsert: true
36+
scope_mappings:
37+
- scope: litellm.api.consumer
38+
models: ["anthropic-claude"]
39+
routes: ["/v1/chat/completions"]
40+
- scope: litellm.api.gpt_3_5_turbo
41+
models: ["gpt-3.5-turbo-testing"]
42+
enforce_scope_based_access: true
43+
enforce_rbac: true

Diff for: litellm/proxy/_types.py

+30-2
Original file line numberDiff line numberDiff line change
@@ -1040,6 +1040,13 @@ def set_model_info(cls, values):
10401040
"model_max_budget",
10411041
"model_aliases",
10421042
]
1043+
1044+
if (
1045+
isinstance(values.get("members_with_roles"), dict)
1046+
and not values["members_with_roles"]
1047+
):
1048+
values["members_with_roles"] = []
1049+
10431050
for field in dict_fields:
10441051
value = values.get(field)
10451052
if value is not None and isinstance(value, str):
@@ -2279,11 +2286,14 @@ class ClientSideFallbackModel(TypedDict, total=False):
22792286
]
22802287

22812288

2282-
class RoleBasedPermissions(LiteLLMPydanticObjectBase):
2283-
role: RBAC_ROLES
2289+
class OIDCPermissions(LiteLLMPydanticObjectBase):
22842290
models: Optional[List[str]] = None
22852291
routes: Optional[List[str]] = None
22862292

2293+
2294+
class RoleBasedPermissions(OIDCPermissions):
2295+
role: RBAC_ROLES
2296+
22872297
model_config = {
22882298
"extra": "forbid",
22892299
}
@@ -2294,6 +2304,14 @@ class RoleMapping(BaseModel):
22942304
internal_role: RBAC_ROLES
22952305

22962306

2307+
class ScopeMapping(OIDCPermissions):
2308+
scope: str
2309+
2310+
model_config = {
2311+
"extra": "forbid",
2312+
}
2313+
2314+
22972315
class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
22982316
"""
22992317
A class to define the roles and permissions for a LiteLLM Proxy w/ JWT Auth.
@@ -2323,6 +2341,7 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
23232341
"info_routes",
23242342
]
23252343
team_id_jwt_field: Optional[str] = None
2344+
team_id_upsert: bool = False
23262345
team_ids_jwt_field: Optional[str] = None
23272346
upsert_sso_user_to_team: bool = False
23282347
team_allowed_routes: List[
@@ -2351,6 +2370,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
23512370
object_id_jwt_field: Optional[str] = (
23522371
None # can be either user / team, inferred from the role mapping
23532372
)
2373+
scope_mappings: Optional[List[ScopeMapping]] = None
2374+
enforce_scope_based_access: bool = False
23542375

23552376
def __init__(self, **kwargs: Any) -> None:
23562377
# get the attribute names for this Pydantic model
@@ -2361,6 +2382,8 @@ def __init__(self, **kwargs: Any) -> None:
23612382
user_allowed_roles = kwargs.get("user_allowed_roles")
23622383
object_id_jwt_field = kwargs.get("object_id_jwt_field")
23632384
role_mappings = kwargs.get("role_mappings")
2385+
scope_mappings = kwargs.get("scope_mappings")
2386+
enforce_scope_based_access = kwargs.get("enforce_scope_based_access")
23642387

23652388
if invalid_keys:
23662389
raise ValueError(
@@ -2378,4 +2401,9 @@ def __init__(self, **kwargs: Any) -> None:
23782401
"if object_id_jwt_field is set, role_mappings must also be set. Needed to infer if the caller is a user or team."
23792402
)
23802403

2404+
if scope_mappings is not None and not enforce_scope_based_access:
2405+
raise ValueError(
2406+
"scope_mappings must be set if enforce_scope_based_access is true."
2407+
)
2408+
23812409
super().__init__(**kwargs)

Diff for: litellm/proxy/auth/auth_checks.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -655,11 +655,20 @@ async def _delete_cache_key_object(
655655

656656

657657
@log_db_metrics
658-
async def _get_team_db_check(team_id: str, prisma_client: PrismaClient):
659-
return await prisma_client.db.litellm_teamtable.find_unique(
658+
async def _get_team_db_check(
659+
team_id: str, prisma_client: PrismaClient, team_id_upsert: Optional[bool] = None
660+
):
661+
response = await prisma_client.db.litellm_teamtable.find_unique(
660662
where={"team_id": team_id}
661663
)
662664

665+
if response is None and team_id_upsert:
666+
response = await prisma_client.db.litellm_teamtable.create(
667+
data={"team_id": team_id}
668+
)
669+
670+
return response
671+
663672

664673
async def _get_team_object_from_db(team_id: str, prisma_client: PrismaClient):
665674
return await prisma_client.db.litellm_teamtable.find_unique(
@@ -675,6 +684,7 @@ async def _get_team_object_from_user_api_key_cache(
675684
db_cache_expiry: int,
676685
proxy_logging_obj: Optional[ProxyLogging],
677686
key: str,
687+
team_id_upsert: Optional[bool] = None,
678688
) -> LiteLLM_TeamTableCachedObj:
679689
db_access_time_key = key
680690
should_check_db = _should_check_db(
@@ -684,7 +694,7 @@ async def _get_team_object_from_user_api_key_cache(
684694
)
685695
if should_check_db:
686696
response = await _get_team_db_check(
687-
team_id=team_id, prisma_client=prisma_client
697+
team_id=team_id, prisma_client=prisma_client, team_id_upsert=team_id_upsert
688698
)
689699
else:
690700
response = None
@@ -752,6 +762,7 @@ async def get_team_object(
752762
proxy_logging_obj: Optional[ProxyLogging] = None,
753763
check_cache_only: Optional[bool] = None,
754764
check_db_only: Optional[bool] = None,
765+
team_id_upsert: Optional[bool] = None,
755766
) -> LiteLLM_TeamTableCachedObj:
756767
"""
757768
- Check if team id in proxy Team Table
@@ -795,6 +806,7 @@ async def get_team_object(
795806
last_db_access_time=last_db_access_time,
796807
db_cache_expiry=db_cache_expiry,
797808
key=key,
809+
team_id_upsert=team_id_upsert,
798810
)
799811
except Exception:
800812
raise Exception(

Diff for: litellm/proxy/auth/handle_jwt.py

+50-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
LiteLLM_TeamTable,
3131
LiteLLM_UserTable,
3232
LitellmUserRoles,
33+
ScopeMapping,
3334
Span,
3435
)
3536
from litellm.proxy.utils import PrismaClient, ProxyLogging
@@ -318,7 +319,7 @@ def get_org_id(self, token: dict, default_value: Optional[str]) -> Optional[str]
318319
org_id = default_value
319320
return org_id
320321

321-
def get_scopes(self, token: dict) -> list:
322+
def get_scopes(self, token: dict) -> List[str]:
322323
try:
323324
if isinstance(token["scope"], str):
324325
# Assuming the scopes are stored in 'scope' claim and are space-separated
@@ -543,6 +544,40 @@ def can_rbac_role_call_model(
543544

544545
return True
545546

547+
@staticmethod
548+
def check_scope_based_access(
549+
scope_mappings: List[ScopeMapping],
550+
scopes: List[str],
551+
request_data: dict,
552+
general_settings: dict,
553+
) -> None:
554+
"""
555+
Check if scope allows access to the requested model
556+
"""
557+
if not scope_mappings:
558+
return None
559+
560+
allowed_models = []
561+
for sm in scope_mappings:
562+
if sm.scope in scopes and sm.models:
563+
allowed_models.extend(sm.models)
564+
565+
requested_model = request_data.get("model")
566+
567+
if not requested_model:
568+
return None
569+
570+
if requested_model not in allowed_models:
571+
raise HTTPException(
572+
status_code=403,
573+
detail={
574+
"error": "model={} not allowed. Allowed_models={}".format(
575+
requested_model, allowed_models
576+
)
577+
},
578+
)
579+
return None
580+
546581
@staticmethod
547582
async def check_rbac_role(
548583
jwt_handler: JWTHandler,
@@ -636,6 +671,7 @@ async def find_and_validate_specific_team_id(
636671
user_api_key_cache=user_api_key_cache,
637672
parent_otel_span=parent_otel_span,
638673
proxy_logging_obj=proxy_logging_obj,
674+
team_id_upsert=jwt_handler.litellm_jwtauth.team_id_upsert,
639675
)
640676

641677
return individual_team_id, team_object
@@ -829,6 +865,19 @@ async def auth_builder(
829865
rbac_role,
830866
)
831867

868+
# Check Scope Based Access
869+
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
870+
if (
871+
jwt_handler.litellm_jwtauth.enforce_scope_based_access
872+
and jwt_handler.litellm_jwtauth.scope_mappings
873+
):
874+
JWTAuthManager.check_scope_based_access(
875+
scope_mappings=jwt_handler.litellm_jwtauth.scope_mappings,
876+
scopes=scopes,
877+
request_data=request_data,
878+
general_settings=general_settings,
879+
)
880+
832881
object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None)
833882

834883
# Get basic user info

Diff for: tests/proxy_unit_tests/test_jwt.py

+64
Original file line numberDiff line numberDiff line change
@@ -1183,3 +1183,67 @@ def test_can_rbac_role_call_route():
11831183
},
11841184
route="/v1/embeddings",
11851185
)
1186+
1187+
1188+
@pytest.mark.parametrize(
1189+
"requested_model, should_work",
1190+
[
1191+
("gpt-3.5-turbo-testing", True),
1192+
("gpt-4o", False),
1193+
],
1194+
)
1195+
def test_check_scope_based_access(requested_model, should_work):
1196+
from litellm.proxy.auth.handle_jwt import JWTAuthManager
1197+
from litellm.proxy._types import ScopeMapping
1198+
1199+
args = {
1200+
"scope_mappings": [
1201+
ScopeMapping(
1202+
models=["anthropic-claude"],
1203+
routes=["/v1/chat/completions"],
1204+
scope="litellm.api.consumer",
1205+
),
1206+
ScopeMapping(
1207+
models=["gpt-3.5-turbo-testing"],
1208+
routes=None,
1209+
scope="litellm.api.gpt_3_5_turbo",
1210+
),
1211+
],
1212+
"scopes": [
1213+
"profile",
1214+
"groups-scope",
1215+
"email",
1216+
"litellm.api.gpt_3_5_turbo",
1217+
"litellm.api.consumer",
1218+
],
1219+
"request_data": {
1220+
"model": requested_model,
1221+
"messages": [{"role": "user", "content": "Hey, how's it going 1234?"}],
1222+
},
1223+
"general_settings": {
1224+
"enable_jwt_auth": True,
1225+
"litellm_jwtauth": {
1226+
"team_id_jwt_field": "client_id",
1227+
"team_id_upsert": True,
1228+
"scope_mappings": [
1229+
{
1230+
"scope": "litellm.api.consumer",
1231+
"models": ["anthropic-claude"],
1232+
"routes": ["/v1/chat/completions"],
1233+
},
1234+
{
1235+
"scope": "litellm.api.gpt_3_5_turbo",
1236+
"models": ["gpt-3.5-turbo-testing"],
1237+
},
1238+
],
1239+
"enforce_scope_based_access": True,
1240+
"enforce_rbac": True,
1241+
},
1242+
},
1243+
}
1244+
1245+
if should_work:
1246+
JWTAuthManager.check_scope_based_access(**args)
1247+
else:
1248+
with pytest.raises(HTTPException):
1249+
JWTAuthManager.check_scope_based_access(**args)

0 commit comments

Comments
 (0)