Skip to content

Commit

Permalink
Implement RBAC scoping for workspaces
Browse files Browse the repository at this point in the history
  • Loading branch information
stefannica committed Feb 18, 2025
1 parent e3efd14 commit a6fc966
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 10 deletions.
8 changes: 8 additions & 0 deletions src/zenml/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
BaseDatedResponseBody,
)
from zenml.models.v2.base.scoped import (
FlexibleScopedFilter,
FlexibleScopedRequest,
FlexibleScopedResponse,
FlexibleScopedUpdate,
TaggableFilter,
UserScopedRequest,
UserScopedFilter,
Expand Down Expand Up @@ -490,6 +494,10 @@
"BaseDatedResponseBody",
"BaseZenModel",
"BasePluginFlavorResponse",
"FlexibleScopedFilter",
"FlexibleScopedRequest",
"FlexibleScopedResponse",
"FlexibleScopedUpdate",
"UserScopedRequest",
"UserScopedFilter",
"UserScopedResponse",
Expand Down
47 changes: 43 additions & 4 deletions src/zenml/zen_server/rbac/endpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# permissions and limitations under the License.
"""High-level helper functions to write endpoints with RBAC."""

from typing import Any, Callable, List, TypeVar, Union
from typing import Any, Callable, List, Optional, TypeVar, Union
from uuid import UUID

from pydantic import BaseModel
Expand All @@ -25,8 +25,11 @@
BaseFilter,
BaseIdentifiedResponse,
BaseRequest,
FlexibleScopedFilter,
Page,
UserScopedRequest,
WorkspaceScopedFilter,
WorkspaceScopedRequest,
)
from zenml.zen_server.auth import get_auth_context
from zenml.zen_server.feature_gate.endpoint_utils import (
Expand Down Expand Up @@ -70,10 +73,19 @@ def verify_permissions_and_create_entity(
assert auth_context

# Ignore the user field set in the request model, if any, and set it to
# the current user's ID instead.
# the current user's ID instead. This prevents the current user from
# being able to create entities on behalf of other users.
request_model.user = auth_context.user.id

verify_permission(resource_type=resource_type, action=Action.CREATE)
if isinstance(request_model, WorkspaceScopedRequest):
# A workspace scoped request is always scoped to a specific workspace
workspace_id = request_model.workspace

verify_permission(
resource_type=resource_type,
action=Action.CREATE,
workspace_id=workspace_id,
)

needs_usage_increment = (
resource_type in server_config().reportable_resources
Expand Down Expand Up @@ -111,13 +123,25 @@ def verify_permissions_and_batch_create_entity(
auth_context = get_auth_context()
assert auth_context

workspace_ids = set()
for request_model in batch:
if isinstance(request_model, UserScopedRequest):
# Ignore the user field set in the request model, if any, and set it to
# the current user's ID instead.
request_model.user = auth_context.user.id

verify_permission(resource_type=resource_type, action=Action.CREATE)
if isinstance(request_model, WorkspaceScopedRequest):
# A workspace scoped request is always scoped to a specific workspace
workspace_ids.add(request_model.workspace)
else:
workspace_ids.add(None)

for workspace_id in workspace_ids:
verify_permission(
resource_type=resource_type,
action=Action.CREATE,
workspace_id=workspace_id,
)

if resource_type in server_config().reportable_resources:
raise RuntimeError(
Expand Down Expand Up @@ -164,10 +188,25 @@ def verify_permissions_and_list_entities(
Returns:
A page of entity models.
Raises:
ValueError: If the workspace ID is not set for workspace-scoped resources.
"""
auth_context = get_auth_context()
assert auth_context

workspace_id: Optional[UUID] = None
if isinstance(filter_model, WorkspaceScopedFilter):
# A workspace scoped request is always scoped to a specific workspace
workspace_id = filter_model.workspace
if workspace_id is None:
raise ValueError("Workspace ID is required for workspace-scoped resources.")

elif isinstance(filter_model, FlexibleScopedFilter):
# A flexible scoped request is always scoped to a specific workspace
workspace_id = filter_model.workspace


allowed_ids = get_allowed_resource_ids(resource_type=resource_type)
filter_model.configure_rbac(
authenticated_user_id=auth_context.user.id, id=allowed_ids
Expand Down
91 changes: 90 additions & 1 deletion src/zenml/zen_server/rbac/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import Optional
from uuid import UUID

from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator

from zenml.utils.enum_utils import StrEnum

Expand Down Expand Up @@ -73,12 +73,62 @@ class ResourceType(StrEnum):
# USER = "user"
# WORKSPACE = "workspace"

@classmethod
def is_flexible_scoped(cls, resource_type: "ResourceType") -> bool:
"""Check if a resource type may flexibly be scoped to a workspace.
Args:
resource_type: The resource type to check.
Returns:
Whether the resource type may flexibly be scoped to a workspace.
"""
return resource_type in [
cls.FLAVOR,
cls.SECRET,
cls.SERVICE_CONNECTOR,
cls.STACK,
cls.STACK_COMPONENT,
]

@classmethod
def is_workspace_scoped(cls, resource_type: "ResourceType") -> bool:
"""Check if a resource type is workspace scoped.
Args:
resource_type: The resource type to check.
Returns:
Whether the resource type is workspace scoped.
"""
return not cls.is_flexible_scoped(
resource_type
) and not cls.is_unscoped(resource_type)

@classmethod
def is_unscoped(cls, resource_type: "ResourceType") -> bool:
"""Check if a resource type is unscoped.
Args:
resource_type: The resource type to check.
Returns:
Whether the resource type is unscoped.
"""
return resource_type in [
cls.SERVICE_ACCOUNT,
# Deactivated for now
# cls.USER,
# cls.WORKSPACE,
]


class Resource(BaseModel):
"""RBAC resource model."""

type: str
id: Optional[UUID] = None
workspace_id: Optional[UUID] = None

def __str__(self) -> str:
"""Convert to a string.
Expand All @@ -87,9 +137,48 @@ def __str__(self) -> str:
Resource string representation.
"""
representation = self.type
if self.workspace_id:
representation += f"/workspace/{self.workspace_id}"
if self.id:
representation += f"/{self.id}"

return representation

@field_validator("workspace_id")
@classmethod
def validate_workspace_id(
cls, workspace_id: Optional[UUID], info: ValidationInfo
) -> Optional[UUID]:
"""Validate that workspace_id is set in combination with the correct resource types.
Args:
workspace_id: The workspace ID to validate.
info: The validation info containing the model data.
Returns:
The validated workspace ID.
Raises:
ValueError: If workspace_id is not set for a workspace-scoped
resource or set for an unscoped resource.
"""
resource_type = ResourceType(info.data.get("type"))

if (
ResourceType.is_workspace_scoped(resource_type)
and not workspace_id
):
raise ValueError(
"workspace_id must be set for workspace-scoped resource type "
f"{resource_type}"
)

if ResourceType.is_unscoped(resource_type) and workspace_id:
raise ValueError(
"workspace_id must not be set for unscoped resource type "
f"{resource_type}"
)

return workspace_id

model_config = ConfigDict(frozen=True)
33 changes: 31 additions & 2 deletions src/zenml/zen_server/rbac/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
Page,
UserResponse,
UserScopedResponse,
FlexibleScopedResponse,
WorkspaceScopedResponse,
)
from zenml.zen_server.auth import get_auth_context
from zenml.zen_server.rbac.models import Action, Resource, ResourceType
Expand Down Expand Up @@ -283,6 +285,7 @@ def verify_permission(
resource_type: str,
action: Action,
resource_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None,
) -> None:
"""Verifies if a user has permission to perform an action on a resource.
Expand All @@ -291,8 +294,12 @@ def verify_permission(
action on.
action: The action the user wants to perform.
resource_id: ID of the resource the user wants to perform the action on.
workspace_id: ID of the workspace the user wants to perform the action
on. Only used for workspace scoped resources.
"""
resource = Resource(type=resource_type, id=resource_id)
resource = Resource(
type=resource_type, id=resource_id, workspace_id=workspace_id
)
batch_verify_permissions(resources={resource}, action=action)


Expand Down Expand Up @@ -346,7 +353,11 @@ def get_resource_for_model(model: AnyResponse) -> Optional[Resource]:
# This model is not tied to any RBAC resource type
return None

return Resource(type=resource_type, id=model.id)
workspace_id: Optional[UUID] = None
if isinstance(model, WorkspaceScopedResponse):
workspace_id = model.workspace.id

return Resource(type=resource_type, id=model.id, workspace_id=workspace_id)


def get_surrogate_permission_model_for_model(
Expand Down Expand Up @@ -448,6 +459,24 @@ def get_resource_type_for_model(
return mapping.get(type(model))


def is_resource_type_workspace_scoped(resource_type: ResourceType) -> bool:
"""Check if a resource type is workspace scoped.
Args:
resource_type: The resource type to check.
Returns:
Whether the resource type is workspace scoped.
"""
return resource_type in [
ResourceType.STACK,
ResourceType.PIPELINE,
ResourceType.CODE_REPOSITORY,
ResourceType.SECRET,
ResourceType.MODEL,
]


def is_owned_by_authenticated_user(model: AnyResponse) -> bool:
"""Returns whether the currently authenticated user owns the model.
Expand Down
6 changes: 3 additions & 3 deletions src/zenml/zen_stores/sql_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -11372,7 +11372,7 @@ def _attach_tags_to_resource_new(
resource_type = self._get_taggable_resource_type(
resource=resource
)
self.create_tag_resource(
self._create_tag_resource(
TagResourceRequest(
tag_id=tag.id,
resource_id=resource.id,
Expand Down Expand Up @@ -11401,7 +11401,7 @@ def _attach_tags_to_resource(
except KeyError:
tag = self.create_tag(TagRequest(name=tag_name))
try:
self.create_tag_resource(
self._create_tag_resource(
TagResourceRequest(
tag_id=tag.id,
resource_id=resource_id,
Expand Down Expand Up @@ -11583,7 +11583,7 @@ def update_tag(
# Tags <> resources
####################

def create_tag_resource(
def _create_tag_resource(
self, tag_resource: TagResourceRequest
) -> TagResourceResponse:
"""Creates a new tag resource relationship.
Expand Down

0 comments on commit a6fc966

Please sign in to comment.