Skip to content

Commit

Permalink
Aggregated request support in RBAC endpoint utils and unified all req…
Browse files Browse the repository at this point in the history
…uest endpoints
  • Loading branch information
stefannica committed Feb 19, 2025
1 parent cbc0ba7 commit 0c70fce
Show file tree
Hide file tree
Showing 29 changed files with 276 additions and 181 deletions.
1 change: 0 additions & 1 deletion src/zenml/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,6 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
DEFAULT_ZENML_SERVER_MAX_REQUEST_BODY_SIZE_IN_BYTES = 256 * 1024 * 1024

DEFAULT_REPORTABLE_RESOURCES = ["pipeline", "pipeline_run", "model"]
REQUIRES_CUSTOM_RESOURCE_REPORTING = ["pipeline", "pipeline_run"]

# API Endpoint paths:
ACTIVATE = "/activate"
Expand Down
148 changes: 97 additions & 51 deletions src/zenml/zen_server/rbac/endpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,11 @@
# permissions and limitations under the License.
"""High-level helper functions to write endpoints with RBAC."""

from typing import Any, Callable, List, Optional, TypeVar, Union
from typing import Any, Callable, List, Optional, Tuple, TypeVar, Union
from uuid import UUID

from pydantic import BaseModel

from zenml.constants import (
REQUIRES_CUSTOM_RESOURCE_REPORTING,
)
from zenml.models import (
BaseFilter,
BaseIdentifiedResponse,
Expand All @@ -29,7 +26,6 @@
Page,
UserScopedRequest,
WorkspaceScopedFilter,
WorkspaceScopedRequest,
)
from zenml.zen_server.auth import get_auth_context
from zenml.zen_server.feature_gate.endpoint_utils import (
Expand All @@ -38,32 +34,41 @@
)
from zenml.zen_server.rbac.models import Action, ResourceType
from zenml.zen_server.rbac.utils import (
batch_verify_permissions_for_models,
dehydrate_page,
dehydrate_response_model,
dehydrate_response_model_batch,
get_allowed_resource_ids,
get_resource_type_for_model,
verify_permission,
verify_permission_for_model,
)
from zenml.zen_server.utils import server_config

AnyRequest = TypeVar("AnyRequest", bound=BaseRequest)
AnyResponse = TypeVar("AnyResponse", bound=BaseIdentifiedResponse) # type: ignore[type-arg]
AnyOtherResponse = TypeVar("AnyOtherResponse", bound=BaseIdentifiedResponse) # type: ignore[type-arg]
AnyFilter = TypeVar("AnyFilter", bound=BaseFilter)
AnyUpdate = TypeVar("AnyUpdate", bound=BaseModel)
UUIDOrStr = TypeVar("UUIDOrStr", UUID, Union[UUID, str])


def verify_permissions_and_create_entity(
request_model: AnyRequest,
resource_type: ResourceType,
create_method: Callable[[AnyRequest], AnyResponse],
surrogate_models: Optional[List[AnyOtherResponse]] = None,
skip_entitlements: bool = False,
) -> AnyResponse:
"""Verify permissions and create the entity if authorized.
Args:
request_model: The entity request model.
resource_type: The resource type of the entity to create.
create_method: The method to create the entity.
surrogate_models: Optional list of surrogate models to verify
UPDATE permissions for instead of verifying CREATE permissions for
the request model.
skip_entitlements: Whether to skip the entitlement check and usage
increment.
Returns:
A model of the created entity.
Expand All @@ -73,45 +78,43 @@ def verify_permissions_and_create_entity(
assert auth_context

# Ignore the user field set in the request model, if any, and set it to
# the current user's ID instead. This prevents the current user from
# being able to create entities on behalf of other users.
# the current user's ID instead. This is just a precaution, given that
# the SQLZenStore also does this same validation on all request models.
request_model.user = auth_context.user.id

if isinstance(request_model, WorkspaceScopedRequest):
# A workspace scoped request is always scoped to a specific workspace
workspace_id = request_model.workspace
if surrogate_models:
batch_verify_permissions_for_models(
models=surrogate_models, action=Action.UPDATE
)
else:
verify_permission_for_model(model=request_model, action=Action.CREATE)

verify_permission(
resource_type=resource_type,
action=Action.CREATE,
workspace_id=workspace_id,
)
resource_type = get_resource_type_for_model(request_model)

needs_usage_increment = (
resource_type in server_config().reportable_resources
and resource_type not in REQUIRES_CUSTOM_RESOURCE_REPORTING
)
if needs_usage_increment:
check_entitlement(resource_type)
if resource_type:
needs_usage_increment = (
not skip_entitlements
and resource_type in server_config().reportable_resources
)
if needs_usage_increment:
check_entitlement(resource_type)

created = create_method(request_model)

if needs_usage_increment:
if resource_type and needs_usage_increment:
report_usage(resource_type, resource_id=created.id)

return created
return dehydrate_response_model(created)


def verify_permissions_and_batch_create_entity(
batch: List[AnyRequest],
resource_type: ResourceType,
create_method: Callable[[List[AnyRequest]], List[AnyResponse]],
) -> List[AnyResponse]:
"""Verify permissions and create a batch of entities if authorized.
Args:
batch: The batch to create.
resource_type: The resource type of the entities to create.
create_method: The method to create the entities.
Raises:
Expand All @@ -123,33 +126,73 @@ def verify_permissions_and_batch_create_entity(
auth_context = get_auth_context()
assert auth_context

workspace_ids = set()
resource_types = set()
for request_model in batch:
resource_type = get_resource_type_for_model(request_model)
if resource_type:
resource_types.add(resource_type)

if isinstance(request_model, UserScopedRequest):
# Ignore the user field set in the request model, if any, and set it to
# the current user's ID instead.
# Ignore the user field set in the request model, if any, and set it
# to the current user's ID instead. This is just a precaution, given
# that the SQLZenStore also does this same validation on all request
# models.
request_model.user = auth_context.user.id

if isinstance(request_model, WorkspaceScopedRequest):
# A workspace scoped request is always scoped to a specific workspace
workspace_ids.add(request_model.workspace)
else:
workspace_ids.add(None)

for workspace_id in workspace_ids:
verify_permission(
resource_type=resource_type,
action=Action.CREATE,
workspace_id=workspace_id,
)
batch_verify_permissions_for_models(models=batch, action=Action.CREATE)

if resource_type in server_config().reportable_resources:
if resource_types & set(server_config().reportable_resources):
raise RuntimeError(
"Batch requests are currently not possible with usage-tracked features."
"Batch requests are currently not possible with usage-tracked "
"features."
)

created = create_method(batch)
return created
return dehydrate_response_model_batch(created)


def verify_permissions_and_get_or_create_entity(
request_model: AnyRequest,
get_or_create_method: Callable[
[AnyRequest, Optional[Callable[[], None]]], Tuple[AnyResponse, bool]
],
) -> Tuple[AnyResponse, bool]:
"""Verify permissions and create the entity if authorized.
Args:
request_model: The entity request model.
get_or_create_method: The method to get or create the entity.
Returns:
The entity and a boolean indicating whether the entity was created.
"""
if isinstance(request_model, UserScopedRequest):
auth_context = get_auth_context()
assert auth_context

# Ignore the user field set in the request model, if any, and set it to
# the current user's ID instead. This is just a precaution, given that
# the SQLZenStore also does this same validation on all request models.
request_model.user = auth_context.user.id

resource_type = get_resource_type_for_model(request_model)
needs_usage_increment = (
resource_type and resource_type in server_config().reportable_resources
)

def _pre_creation_hook() -> None:
verify_permission_for_model(model=request_model, action=Action.CREATE)
if resource_type and needs_usage_increment:
check_entitlement(resource_type=resource_type)

model, created = get_or_create_method(request_model, _pre_creation_hook)

if not created:
verify_permission_for_model(model=model, action=Action.READ)
elif resource_type and needs_usage_increment:
report_usage(resource_type, resource_id=model.id)

return dehydrate_response_model(model), created


def verify_permissions_and_get_entity(
Expand Down Expand Up @@ -198,16 +241,19 @@ def verify_permissions_and_list_entities(
workspace_id: Optional[UUID] = None
if isinstance(filter_model, WorkspaceScopedFilter):
# A workspace scoped request is always scoped to a specific workspace
workspace_id = filter_model.workspace
workspace_id = filter_model.scope_workspace
if workspace_id is None:
raise ValueError("Workspace ID is required for workspace-scoped resources.")
raise ValueError(
"Workspace ID is required for workspace-scoped resources."
)

elif isinstance(filter_model, FlexibleScopedFilter):
# A flexible scoped request is always scoped to a specific workspace
workspace_id = filter_model.workspace
# A flexible scoped request may be scoped to a specific workspace
workspace_id = filter_model.scope_workspace


allowed_ids = get_allowed_resource_ids(resource_type=resource_type)
allowed_ids = get_allowed_resource_ids(
resource_type=resource_type, workspace_id=workspace_id
)
filter_model.configure_rbac(
authenticated_user_id=auth_context.user.id, id=allowed_ids
)
Expand Down
4 changes: 1 addition & 3 deletions src/zenml/zen_server/rbac/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@
# permissions and limitations under the License.
"""RBAC model classes."""

from typing import Any, Dict, Optional
from typing import Optional
from uuid import UUID

from pydantic import (
BaseModel,
ConfigDict,
ValidationInfo,
field_validator,
model_validator,
)

Expand Down
Loading

0 comments on commit 0c70fce

Please sign in to comment.