From 0c70fce23243156bd5152fd5b84e4153a92ddfb0 Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Wed, 19 Feb 2025 08:57:16 +0100 Subject: [PATCH] Aggregated request support in RBAC endpoint utils and unified all request endpoints --- src/zenml/constants.py | 1 - src/zenml/zen_server/rbac/endpoint_utils.py | 148 ++++++++++++------ src/zenml/zen_server/rbac/models.py | 4 +- src/zenml/zen_server/rbac/utils.py | 142 +++++++++++++---- .../zen_server/routers/actions_endpoints.py | 1 - .../zen_server/routers/artifact_endpoint.py | 1 - .../routers/artifact_version_endpoints.py | 2 - .../routers/code_repositories_endpoints.py | 1 - .../routers/event_source_endpoints.py | 1 - .../zen_server/routers/flavors_endpoints.py | 1 - .../routers/model_versions_endpoints.py | 21 +-- .../zen_server/routers/models_endpoints.py | 1 - .../routers/pipeline_builds_endpoints.py | 1 - .../routers/pipeline_deployments_endpoints.py | 1 - .../zen_server/routers/pipelines_endpoints.py | 23 +-- .../routers/run_metadata_endpoints.py | 8 +- .../routers/run_templates_endpoints.py | 1 - .../zen_server/routers/runs_endpoints.py | 30 +--- .../zen_server/routers/schedule_endpoints.py | 12 +- .../zen_server/routers/secrets_endpoints.py | 1 - .../routers/service_accounts_endpoints.py | 20 ++- .../routers/service_connectors_endpoints.py | 12 +- .../zen_server/routers/service_endpoints.py | 1 - .../routers/stack_components_endpoints.py | 1 - .../zen_server/routers/stacks_endpoints.py | 2 +- .../zen_server/routers/steps_endpoints.py | 17 +- .../zen_server/routers/tags_endpoints.py | 1 - .../zen_server/routers/triggers_endpoints.py | 1 - .../zen_server/routers/users_endpoints.py | 1 - 29 files changed, 276 insertions(+), 181 deletions(-) diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 18929c25865..81b59c8219a 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -331,7 +331,6 @@ def handle_int_env_var(var: str, default: int = 0) -> int: DEFAULT_ZENML_SERVER_MAX_REQUEST_BODY_SIZE_IN_BYTES = 256 * 1024 * 1024 DEFAULT_REPORTABLE_RESOURCES = ["pipeline", "pipeline_run", "model"] -REQUIRES_CUSTOM_RESOURCE_REPORTING = ["pipeline", "pipeline_run"] # API Endpoint paths: ACTIVATE = "/activate" diff --git a/src/zenml/zen_server/rbac/endpoint_utils.py b/src/zenml/zen_server/rbac/endpoint_utils.py index d71074bff9b..a03400b8902 100644 --- a/src/zenml/zen_server/rbac/endpoint_utils.py +++ b/src/zenml/zen_server/rbac/endpoint_utils.py @@ -13,14 +13,11 @@ # permissions and limitations under the License. """High-level helper functions to write endpoints with RBAC.""" -from typing import Any, Callable, List, Optional, TypeVar, Union +from typing import Any, Callable, List, Optional, Tuple, TypeVar, Union from uuid import UUID from pydantic import BaseModel -from zenml.constants import ( - REQUIRES_CUSTOM_RESOURCE_REPORTING, -) from zenml.models import ( BaseFilter, BaseIdentifiedResponse, @@ -29,7 +26,6 @@ Page, UserScopedRequest, WorkspaceScopedFilter, - WorkspaceScopedRequest, ) from zenml.zen_server.auth import get_auth_context from zenml.zen_server.feature_gate.endpoint_utils import ( @@ -38,9 +34,12 @@ ) from zenml.zen_server.rbac.models import Action, ResourceType from zenml.zen_server.rbac.utils import ( + batch_verify_permissions_for_models, dehydrate_page, dehydrate_response_model, + dehydrate_response_model_batch, get_allowed_resource_ids, + get_resource_type_for_model, verify_permission, verify_permission_for_model, ) @@ -48,6 +47,7 @@ AnyRequest = TypeVar("AnyRequest", bound=BaseRequest) AnyResponse = TypeVar("AnyResponse", bound=BaseIdentifiedResponse) # type: ignore[type-arg] +AnyOtherResponse = TypeVar("AnyOtherResponse", bound=BaseIdentifiedResponse) # type: ignore[type-arg] AnyFilter = TypeVar("AnyFilter", bound=BaseFilter) AnyUpdate = TypeVar("AnyUpdate", bound=BaseModel) UUIDOrStr = TypeVar("UUIDOrStr", UUID, Union[UUID, str]) @@ -55,15 +55,20 @@ def verify_permissions_and_create_entity( request_model: AnyRequest, - resource_type: ResourceType, create_method: Callable[[AnyRequest], AnyResponse], + surrogate_models: Optional[List[AnyOtherResponse]] = None, + skip_entitlements: bool = False, ) -> AnyResponse: """Verify permissions and create the entity if authorized. Args: request_model: The entity request model. - resource_type: The resource type of the entity to create. create_method: The method to create the entity. + surrogate_models: Optional list of surrogate models to verify + UPDATE permissions for instead of verifying CREATE permissions for + the request model. + skip_entitlements: Whether to skip the entitlement check and usage + increment. Returns: A model of the created entity. @@ -73,45 +78,43 @@ def verify_permissions_and_create_entity( assert auth_context # Ignore the user field set in the request model, if any, and set it to - # the current user's ID instead. This prevents the current user from - # being able to create entities on behalf of other users. + # the current user's ID instead. This is just a precaution, given that + # the SQLZenStore also does this same validation on all request models. request_model.user = auth_context.user.id - if isinstance(request_model, WorkspaceScopedRequest): - # A workspace scoped request is always scoped to a specific workspace - workspace_id = request_model.workspace + if surrogate_models: + batch_verify_permissions_for_models( + models=surrogate_models, action=Action.UPDATE + ) + else: + verify_permission_for_model(model=request_model, action=Action.CREATE) - verify_permission( - resource_type=resource_type, - action=Action.CREATE, - workspace_id=workspace_id, - ) + resource_type = get_resource_type_for_model(request_model) - needs_usage_increment = ( - resource_type in server_config().reportable_resources - and resource_type not in REQUIRES_CUSTOM_RESOURCE_REPORTING - ) - if needs_usage_increment: - check_entitlement(resource_type) + if resource_type: + needs_usage_increment = ( + not skip_entitlements + and resource_type in server_config().reportable_resources + ) + if needs_usage_increment: + check_entitlement(resource_type) created = create_method(request_model) - if needs_usage_increment: + if resource_type and needs_usage_increment: report_usage(resource_type, resource_id=created.id) - return created + return dehydrate_response_model(created) def verify_permissions_and_batch_create_entity( batch: List[AnyRequest], - resource_type: ResourceType, create_method: Callable[[List[AnyRequest]], List[AnyResponse]], ) -> List[AnyResponse]: """Verify permissions and create a batch of entities if authorized. Args: batch: The batch to create. - resource_type: The resource type of the entities to create. create_method: The method to create the entities. Raises: @@ -123,33 +126,73 @@ def verify_permissions_and_batch_create_entity( auth_context = get_auth_context() assert auth_context - workspace_ids = set() + resource_types = set() for request_model in batch: + resource_type = get_resource_type_for_model(request_model) + if resource_type: + resource_types.add(resource_type) + if isinstance(request_model, UserScopedRequest): - # Ignore the user field set in the request model, if any, and set it to - # the current user's ID instead. + # Ignore the user field set in the request model, if any, and set it + # to the current user's ID instead. This is just a precaution, given + # that the SQLZenStore also does this same validation on all request + # models. request_model.user = auth_context.user.id - if isinstance(request_model, WorkspaceScopedRequest): - # A workspace scoped request is always scoped to a specific workspace - workspace_ids.add(request_model.workspace) - else: - workspace_ids.add(None) - - for workspace_id in workspace_ids: - verify_permission( - resource_type=resource_type, - action=Action.CREATE, - workspace_id=workspace_id, - ) + batch_verify_permissions_for_models(models=batch, action=Action.CREATE) - if resource_type in server_config().reportable_resources: + if resource_types & set(server_config().reportable_resources): raise RuntimeError( - "Batch requests are currently not possible with usage-tracked features." + "Batch requests are currently not possible with usage-tracked " + "features." ) created = create_method(batch) - return created + return dehydrate_response_model_batch(created) + + +def verify_permissions_and_get_or_create_entity( + request_model: AnyRequest, + get_or_create_method: Callable[ + [AnyRequest, Optional[Callable[[], None]]], Tuple[AnyResponse, bool] + ], +) -> Tuple[AnyResponse, bool]: + """Verify permissions and create the entity if authorized. + + Args: + request_model: The entity request model. + get_or_create_method: The method to get or create the entity. + + Returns: + The entity and a boolean indicating whether the entity was created. + """ + if isinstance(request_model, UserScopedRequest): + auth_context = get_auth_context() + assert auth_context + + # Ignore the user field set in the request model, if any, and set it to + # the current user's ID instead. This is just a precaution, given that + # the SQLZenStore also does this same validation on all request models. + request_model.user = auth_context.user.id + + resource_type = get_resource_type_for_model(request_model) + needs_usage_increment = ( + resource_type and resource_type in server_config().reportable_resources + ) + + def _pre_creation_hook() -> None: + verify_permission_for_model(model=request_model, action=Action.CREATE) + if resource_type and needs_usage_increment: + check_entitlement(resource_type=resource_type) + + model, created = get_or_create_method(request_model, _pre_creation_hook) + + if not created: + verify_permission_for_model(model=model, action=Action.READ) + elif resource_type and needs_usage_increment: + report_usage(resource_type, resource_id=model.id) + + return dehydrate_response_model(model), created def verify_permissions_and_get_entity( @@ -198,16 +241,19 @@ def verify_permissions_and_list_entities( workspace_id: Optional[UUID] = None if isinstance(filter_model, WorkspaceScopedFilter): # A workspace scoped request is always scoped to a specific workspace - workspace_id = filter_model.workspace + workspace_id = filter_model.scope_workspace if workspace_id is None: - raise ValueError("Workspace ID is required for workspace-scoped resources.") + raise ValueError( + "Workspace ID is required for workspace-scoped resources." + ) elif isinstance(filter_model, FlexibleScopedFilter): - # A flexible scoped request is always scoped to a specific workspace - workspace_id = filter_model.workspace + # A flexible scoped request may be scoped to a specific workspace + workspace_id = filter_model.scope_workspace - - allowed_ids = get_allowed_resource_ids(resource_type=resource_type) + allowed_ids = get_allowed_resource_ids( + resource_type=resource_type, workspace_id=workspace_id + ) filter_model.configure_rbac( authenticated_user_id=auth_context.user.id, id=allowed_ids ) diff --git a/src/zenml/zen_server/rbac/models.py b/src/zenml/zen_server/rbac/models.py index 0e437d4142c..62e94a1b8e8 100644 --- a/src/zenml/zen_server/rbac/models.py +++ b/src/zenml/zen_server/rbac/models.py @@ -13,14 +13,12 @@ # permissions and limitations under the License. """RBAC model classes.""" -from typing import Any, Dict, Optional +from typing import Optional from uuid import UUID from pydantic import ( BaseModel, ConfigDict, - ValidationInfo, - field_validator, model_validator, ) diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index a9447dd9fa4..eeec59725db 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -34,7 +34,7 @@ Page, UserResponse, UserScopedResponse, - FlexibleScopedResponse, + WorkspaceScopedRequest, WorkspaceScopedResponse, ) from zenml.zen_server.auth import get_auth_context @@ -57,24 +57,39 @@ def dehydrate_page(page: Page[AnyResponse]) -> Page[AnyResponse]: Returns: The page with (potentially) dehydrated items. """ + new_items = dehydrate_response_model_batch(page.items) + return page.model_copy(update={"items": new_items}) + + +def dehydrate_response_model_batch( + batch: List[AnyResponse], +) -> List[AnyResponse]: + """Dehydrate all items of a batch. + + Args: + batch: The batch to dehydrate. + + Returns: + The batch with (potentially) dehydrated items. + """ if not server_config().rbac_enabled: - return page + return batch auth_context = get_auth_context() assert auth_context - resource_list = [get_subresources_for_model(item) for item in page.items] + resource_list = [get_subresources_for_model(item) for item in batch] resources = set.union(*resource_list) if resource_list else set() permissions = rbac().check_permissions( user=auth_context.user, resources=resources, action=Action.READ ) - new_items = [ + new_batch = [ dehydrate_response_model(item, permissions=permissions) - for item in page.items + for item in batch ] - return page.model_copy(update={"items": new_items}) + return new_batch def dehydrate_response_model( @@ -203,7 +218,7 @@ def get_permission_denied_model(model: AnyResponse) -> AnyResponse: def batch_verify_permissions_for_models( - models: Sequence[AnyResponse], + models: Sequence[AnyModel], action: Action, ) -> None: """Batch permission verification for models. @@ -231,7 +246,7 @@ def batch_verify_permissions_for_models( batch_verify_permissions(resources=resources, action=action) -def verify_permission_for_model(model: AnyResponse, action: Action) -> None: +def verify_permission_for_model(model: AnyModel, action: Action) -> None: """Verifies if a user has permission to perform an action on a model. Args: @@ -306,20 +321,37 @@ def verify_permission( def get_allowed_resource_ids( resource_type: str, action: Action = Action.READ, + workspace_id: Optional[UUID] = None, ) -> Optional[Set[UUID]]: """Get all resource IDs of a resource type that a user can access. Args: resource_type: The resource type. action: The action the user wants to perform on the resource. + workspace_id: Optional workspace ID to filter the resources by. + Required for workspace scoped resources. Returns: A list of resource IDs or `None` if the user has full access to the all instances of the resource. + + Raises: + ValueError: If the resource type is workspace scoped and no workspace ID + is provided. """ if not server_config().rbac_enabled: return None + if ResourceType(resource_type).is_workspace_scoped and not workspace_id: + raise ValueError( + "Workspace ID is required to list workspace scoped resources." + ) + elif workspace_id: + raise ValueError( + "Workspace ID is not allowed to list resources that are not " + "workspace scoped." + ) + auth_context = get_auth_context() assert auth_context @@ -338,7 +370,7 @@ def get_allowed_resource_ids( return {UUID(id) for id in allowed_ids} -def get_resource_for_model(model: AnyResponse) -> Optional[Resource]: +def get_resource_for_model(model: AnyModel) -> Optional[Resource]: """Get the resource associated with a model object. Args: @@ -355,14 +387,24 @@ def get_resource_for_model(model: AnyResponse) -> Optional[Resource]: workspace_id: Optional[UUID] = None if isinstance(model, WorkspaceScopedResponse): + # A workspace scoped request is always scoped to a specific workspace workspace_id = model.workspace.id + elif isinstance(model, WorkspaceScopedRequest): + # A workspace scoped request is always scoped to a specific workspace + workspace_id = model.workspace - return Resource(type=resource_type, id=model.id, workspace_id=workspace_id) + resource_id: Optional[UUID] = None + if isinstance(model, BaseIdentifiedResponse): + resource_id = model.id + + return Resource( + type=resource_type, id=resource_id, workspace_id=workspace_id + ) def get_surrogate_permission_model_for_model( - model: AnyResponse, action: str -) -> BaseIdentifiedResponse[Any, Any, Any]: + model: BaseModel, action: str +) -> BaseModel: """Get a surrogate permission model for a model. In some cases a different model instead of the original model is used to @@ -390,7 +432,7 @@ def get_surrogate_permission_model_for_model( def get_resource_type_for_model( - model: AnyResponse, + model: AnyModel, ) -> Optional[ResourceType]: """Get the resource type associated with a model object. @@ -402,27 +444,50 @@ def get_resource_type_for_model( is not associated with any resource type. """ from zenml.models import ( + ActionRequest, ActionResponse, + ArtifactRequest, ArtifactResponse, + ArtifactVersionRequest, ArtifactVersionResponse, + CodeRepositoryRequest, CodeRepositoryResponse, + ComponentRequest, ComponentResponse, + EventSourceRequest, EventSourceResponse, + FlavorRequest, FlavorResponse, + ModelRequest, ModelResponse, + ModelVersionRequest, ModelVersionResponse, + PipelineBuildRequest, PipelineBuildResponse, + PipelineDeploymentRequest, PipelineDeploymentResponse, + PipelineRequest, PipelineResponse, + PipelineRunRequest, PipelineRunResponse, + RunMetadataRequest, + RunTemplateRequest, RunTemplateResponse, + SecretRequest, SecretResponse, + ServiceAccountRequest, ServiceAccountResponse, + ServiceConnectorRequest, ServiceConnectorResponse, + ServiceRequest, ServiceResponse, + StackRequest, StackResponse, + TagRequest, TagResponse, + TriggerExecutionRequest, TriggerExecutionResponse, + TriggerRequest, TriggerResponse, ) @@ -430,30 +495,53 @@ def get_resource_type_for_model( Any, ResourceType, ] = { + ActionRequest: ResourceType.ACTION, ActionResponse: ResourceType.ACTION, + ArtifactRequest: ResourceType.ARTIFACT, + ArtifactResponse: ResourceType.ARTIFACT, + ArtifactVersionRequest: ResourceType.ARTIFACT_VERSION, + ArtifactVersionResponse: ResourceType.ARTIFACT_VERSION, + CodeRepositoryRequest: ResourceType.CODE_REPOSITORY, + CodeRepositoryResponse: ResourceType.CODE_REPOSITORY, + ComponentRequest: ResourceType.STACK_COMPONENT, + ComponentResponse: ResourceType.STACK_COMPONENT, + EventSourceRequest: ResourceType.EVENT_SOURCE, EventSourceResponse: ResourceType.EVENT_SOURCE, + FlavorRequest: ResourceType.FLAVOR, FlavorResponse: ResourceType.FLAVOR, - ServiceConnectorResponse: ResourceType.SERVICE_CONNECTOR, - ComponentResponse: ResourceType.STACK_COMPONENT, - StackResponse: ResourceType.STACK, - PipelineResponse: ResourceType.PIPELINE, - CodeRepositoryResponse: ResourceType.CODE_REPOSITORY, - SecretResponse: ResourceType.SECRET, + ModelRequest: ResourceType.MODEL, ModelResponse: ResourceType.MODEL, + ModelVersionRequest: ResourceType.MODEL_VERSION, ModelVersionResponse: ResourceType.MODEL_VERSION, - ArtifactResponse: ResourceType.ARTIFACT, - ArtifactVersionResponse: ResourceType.ARTIFACT_VERSION, - # WorkspaceResponse: ResourceType.WORKSPACE, - # UserResponse: ResourceType.USER, - PipelineDeploymentResponse: ResourceType.PIPELINE_DEPLOYMENT, + PipelineBuildRequest: ResourceType.PIPELINE_BUILD, PipelineBuildResponse: ResourceType.PIPELINE_BUILD, + PipelineDeploymentRequest: ResourceType.PIPELINE_DEPLOYMENT, + PipelineDeploymentResponse: ResourceType.PIPELINE_DEPLOYMENT, + PipelineRequest: ResourceType.PIPELINE, + PipelineResponse: ResourceType.PIPELINE, + PipelineRunRequest: ResourceType.PIPELINE_RUN, PipelineRunResponse: ResourceType.PIPELINE_RUN, + RunMetadataRequest: ResourceType.RUN_METADATA, + RunTemplateRequest: ResourceType.RUN_TEMPLATE, RunTemplateResponse: ResourceType.RUN_TEMPLATE, + SecretRequest: ResourceType.SECRET, + SecretResponse: ResourceType.SECRET, + ServiceAccountRequest: ResourceType.SERVICE_ACCOUNT, + ServiceAccountResponse: ResourceType.SERVICE_ACCOUNT, + ServiceConnectorRequest: ResourceType.SERVICE_CONNECTOR, + ServiceConnectorResponse: ResourceType.SERVICE_CONNECTOR, + ServiceRequest: ResourceType.SERVICE, + ServiceResponse: ResourceType.SERVICE, + StackRequest: ResourceType.STACK, + StackResponse: ResourceType.STACK, + TagRequest: ResourceType.TAG, TagResponse: ResourceType.TAG, + TriggerRequest: ResourceType.TRIGGER, TriggerResponse: ResourceType.TRIGGER, + TriggerExecutionRequest: ResourceType.TRIGGER_EXECUTION, TriggerExecutionResponse: ResourceType.TRIGGER_EXECUTION, - ServiceAccountResponse: ResourceType.SERVICE_ACCOUNT, - ServiceResponse: ResourceType.SERVICE, + # WorkspaceResponse: ResourceType.WORKSPACE, + # UserResponse: ResourceType.USER, } return mapping.get(type(model)) @@ -477,7 +565,7 @@ def is_resource_type_workspace_scoped(resource_type: ResourceType) -> bool: ] -def is_owned_by_authenticated_user(model: AnyResponse) -> bool: +def is_owned_by_authenticated_user(model: AnyModel) -> bool: """Returns whether the currently authenticated user owns the model. Args: diff --git a/src/zenml/zen_server/routers/actions_endpoints.py b/src/zenml/zen_server/routers/actions_endpoints.py index 4839f508fb9..e3c0428d603 100644 --- a/src/zenml/zen_server/routers/actions_endpoints.py +++ b/src/zenml/zen_server/routers/actions_endpoints.py @@ -219,7 +219,6 @@ def create_action( return verify_permissions_and_create_entity( request_model=action, - resource_type=ResourceType.ACTION, create_method=action_handler.create_action, ) diff --git a/src/zenml/zen_server/routers/artifact_endpoint.py b/src/zenml/zen_server/routers/artifact_endpoint.py index fa6e1451c0a..cb0a0cae66f 100644 --- a/src/zenml/zen_server/routers/artifact_endpoint.py +++ b/src/zenml/zen_server/routers/artifact_endpoint.py @@ -100,7 +100,6 @@ def create_artifact( """ return verify_permissions_and_create_entity( request_model=artifact, - resource_type=ResourceType.ARTIFACT, create_method=zen_store().create_artifact, ) diff --git a/src/zenml/zen_server/routers/artifact_version_endpoints.py b/src/zenml/zen_server/routers/artifact_version_endpoints.py index eda381c233d..a9946db14f5 100644 --- a/src/zenml/zen_server/routers/artifact_version_endpoints.py +++ b/src/zenml/zen_server/routers/artifact_version_endpoints.py @@ -115,7 +115,6 @@ def create_artifact_version( """ return verify_permissions_and_create_entity( request_model=artifact_version, - resource_type=ResourceType.ARTIFACT_VERSION, create_method=zen_store().create_artifact_version, ) @@ -139,7 +138,6 @@ def batch_create_artifact_version( """ return verify_permissions_and_batch_create_entity( batch=artifact_versions, - resource_type=ResourceType.ARTIFACT_VERSION, create_method=zen_store().batch_create_artifact_versions, ) diff --git a/src/zenml/zen_server/routers/code_repositories_endpoints.py b/src/zenml/zen_server/routers/code_repositories_endpoints.py index 35fa111e84c..d7e2f933e86 100644 --- a/src/zenml/zen_server/routers/code_repositories_endpoints.py +++ b/src/zenml/zen_server/routers/code_repositories_endpoints.py @@ -87,7 +87,6 @@ def create_code_repository( return verify_permissions_and_create_entity( request_model=code_repository, - resource_type=ResourceType.CODE_REPOSITORY, create_method=zen_store().create_code_repository, ) diff --git a/src/zenml/zen_server/routers/event_source_endpoints.py b/src/zenml/zen_server/routers/event_source_endpoints.py index 4c01a1e00cc..100cc520942 100644 --- a/src/zenml/zen_server/routers/event_source_endpoints.py +++ b/src/zenml/zen_server/routers/event_source_endpoints.py @@ -222,7 +222,6 @@ def create_event_source( return verify_permissions_and_create_entity( request_model=event_source, - resource_type=ResourceType.EVENT_SOURCE, create_method=event_source_handler.create_event_source, ) diff --git a/src/zenml/zen_server/routers/flavors_endpoints.py b/src/zenml/zen_server/routers/flavors_endpoints.py index 50a2de4786c..44047142392 100644 --- a/src/zenml/zen_server/routers/flavors_endpoints.py +++ b/src/zenml/zen_server/routers/flavors_endpoints.py @@ -125,7 +125,6 @@ def create_flavor( """ return verify_permissions_and_create_entity( request_model=flavor, - resource_type=ResourceType.FLAVOR, create_method=zen_store().create_flavor, ) diff --git a/src/zenml/zen_server/routers/model_versions_endpoints.py b/src/zenml/zen_server/routers/model_versions_endpoints.py index 8e764a358b9..5270015ba9a 100644 --- a/src/zenml/zen_server/routers/model_versions_endpoints.py +++ b/src/zenml/zen_server/routers/model_versions_endpoints.py @@ -118,7 +118,6 @@ def create_model_version( return verify_permissions_and_create_entity( request_model=model_version, - resource_type=ResourceType.MODEL_VERSION, create_method=zen_store().create_model_version, ) @@ -277,12 +276,14 @@ def create_model_version_artifact_link( model_version = zen_store().get_model_version( model_version_artifact_link.model_version ) - verify_permission_for_model(model_version, action=Action.UPDATE) - mv = zen_store().create_model_version_artifact_link( - model_version_artifact_link + return verify_permissions_and_create_entity( + request_model=model_version_artifact_link, + create_method=zen_store().create_model_version_artifact_link, + # Check for UPDATE permissions on the model version instead of the + # model version artifact link + surrogate_models=[model_version], ) - return mv @model_version_artifacts_router.get( @@ -399,12 +400,14 @@ def create_model_version_pipeline_run_link( model_version = zen_store().get_model_version( model_version_pipeline_run_link.model_version, hydrate=False ) - verify_permission_for_model(model_version, action=Action.UPDATE) - mv = zen_store().create_model_version_pipeline_run_link( - model_version_pipeline_run_link + return verify_permissions_and_create_entity( + request_model=model_version_pipeline_run_link, + create_method=zen_store().create_model_version_pipeline_run_link, + # Check for UPDATE permissions on the model version instead of the + # model version pipeline run link + surrogate_models=[model_version], ) - return mv @model_version_pipeline_runs_router.get( diff --git a/src/zenml/zen_server/routers/models_endpoints.py b/src/zenml/zen_server/routers/models_endpoints.py index cb30ebff8f1..d05308d5ced 100644 --- a/src/zenml/zen_server/routers/models_endpoints.py +++ b/src/zenml/zen_server/routers/models_endpoints.py @@ -104,7 +104,6 @@ def create_model( return verify_permissions_and_create_entity( request_model=model, - resource_type=ResourceType.MODEL, create_method=zen_store().create_model, ) diff --git a/src/zenml/zen_server/routers/pipeline_builds_endpoints.py b/src/zenml/zen_server/routers/pipeline_builds_endpoints.py index c1dff7aec07..cca721d98c7 100644 --- a/src/zenml/zen_server/routers/pipeline_builds_endpoints.py +++ b/src/zenml/zen_server/routers/pipeline_builds_endpoints.py @@ -85,7 +85,6 @@ def create_build( return verify_permissions_and_create_entity( request_model=build, - resource_type=ResourceType.PIPELINE_BUILD, create_method=zen_store().create_build, ) diff --git a/src/zenml/zen_server/routers/pipeline_deployments_endpoints.py b/src/zenml/zen_server/routers/pipeline_deployments_endpoints.py index ea803a219b0..258d5c60685 100644 --- a/src/zenml/zen_server/routers/pipeline_deployments_endpoints.py +++ b/src/zenml/zen_server/routers/pipeline_deployments_endpoints.py @@ -87,7 +87,6 @@ def create_deployment( return verify_permissions_and_create_entity( request_model=deployment, - resource_type=ResourceType.PIPELINE_DEPLOYMENT, create_method=zen_store().create_deployment, ) diff --git a/src/zenml/zen_server/routers/pipelines_endpoints.py b/src/zenml/zen_server/routers/pipelines_endpoints.py index 59a3862cb01..d64c3cf102a 100644 --- a/src/zenml/zen_server/routers/pipelines_endpoints.py +++ b/src/zenml/zen_server/routers/pipelines_endpoints.py @@ -36,9 +36,7 @@ from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response from zenml.zen_server.feature_gate.endpoint_utils import ( - check_entitlement, report_decrement, - report_usage, ) from zenml.zen_server.rbac.endpoint_utils import ( verify_permissions_and_create_entity, @@ -99,29 +97,16 @@ def create_pipeline( pipeline.workspace = workspace.id # We limit pipeline namespaces, not pipeline versions - needs_usage_increment = ( - ResourceType.PIPELINE in server_config().reportable_resources - and zen_store().count_pipelines(PipelineFilter(name=pipeline.name)) - == 0 + skip_entitlements = ( + zen_store().count_pipelines(PipelineFilter(name=pipeline.name)) > 0 ) - if needs_usage_increment: - check_entitlement(ResourceType.PIPELINE) - - pipeline_response = verify_permissions_and_create_entity( + return verify_permissions_and_create_entity( request_model=pipeline, - resource_type=ResourceType.PIPELINE, create_method=zen_store().create_pipeline, + skip_entitlements=skip_entitlements, ) - if needs_usage_increment: - report_usage( - resource_type=ResourceType.PIPELINE, - resource_id=pipeline_response.id, - ) - - return pipeline_response - @router.get( "", diff --git a/src/zenml/zen_server/routers/run_metadata_endpoints.py b/src/zenml/zen_server/routers/run_metadata_endpoints.py index 026e7dc6c29..14bc8997cd9 100644 --- a/src/zenml/zen_server/routers/run_metadata_endpoints.py +++ b/src/zenml/zen_server/routers/run_metadata_endpoints.py @@ -23,10 +23,10 @@ from zenml.models import RunMetadataRequest from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response -from zenml.zen_server.rbac.models import Action, ResourceType +from zenml.zen_server.rbac.models import Action from zenml.zen_server.rbac.utils import ( batch_verify_permissions_for_models, - verify_permission, + verify_permission_for_model, ) from zenml.zen_server.routers.workspaces_endpoints import ( router as workspace_router, @@ -97,8 +97,6 @@ def create_run_metadata( action=Action.UPDATE, ) - verify_permission( - resource_type=ResourceType.RUN_METADATA, action=Action.CREATE - ) + verify_permission_for_model(model=run_metadata, action=Action.CREATE) zen_store().create_run_metadata(run_metadata) diff --git a/src/zenml/zen_server/routers/run_templates_endpoints.py b/src/zenml/zen_server/routers/run_templates_endpoints.py index 0f1bb52bb15..180efc177f0 100644 --- a/src/zenml/zen_server/routers/run_templates_endpoints.py +++ b/src/zenml/zen_server/routers/run_templates_endpoints.py @@ -93,7 +93,6 @@ def create_run_template( return verify_permissions_and_create_entity( request_model=run_template, - resource_type=ResourceType.RUN_TEMPLATE, create_method=zen_store().create_run_template, ) diff --git a/src/zenml/zen_server/routers/runs_endpoints.py b/src/zenml/zen_server/routers/runs_endpoints.py index 06a81093b5a..f6b0537c8cb 100644 --- a/src/zenml/zen_server/routers/runs_endpoints.py +++ b/src/zenml/zen_server/routers/runs_endpoints.py @@ -40,19 +40,15 @@ ) from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response -from zenml.zen_server.feature_gate.endpoint_utils import ( - check_entitlement, - report_usage, -) from zenml.zen_server.rbac.endpoint_utils import ( verify_permissions_and_delete_entity, verify_permissions_and_get_entity, + verify_permissions_and_get_or_create_entity, verify_permissions_and_list_entities, verify_permissions_and_update_entity, ) from zenml.zen_server.rbac.models import Action, ResourceType from zenml.zen_server.rbac.utils import ( - verify_permission, verify_permission_for_model, ) from zenml.zen_server.routers.workspaces_endpoints import ( @@ -92,14 +88,13 @@ def get_or_create_pipeline_run( pipeline_run: PipelineRunRequest, workspace_name_or_id: Optional[Union[str, UUID]] = None, - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> Tuple[PipelineRunResponse, bool]: """Get or create a pipeline run. Args: pipeline_run: Pipeline run to create. workspace_name_or_id: Optional name or ID of the workspace. - auth_context: Authentication context. Returns: The pipeline run and a boolean indicating whether the run was created @@ -109,25 +104,10 @@ def get_or_create_pipeline_run( workspace = zen_store().get_workspace(workspace_name_or_id) pipeline_run.workspace = workspace.id - pipeline_run.user = auth_context.user.id - - def _pre_creation_hook() -> None: - verify_permission( - resource_type=ResourceType.PIPELINE_RUN, action=Action.CREATE - ) - check_entitlement(resource_type=ResourceType.PIPELINE_RUN) - - run, created = zen_store().get_or_create_run( - pipeline_run=pipeline_run, pre_creation_hook=_pre_creation_hook + return verify_permissions_and_get_or_create_entity( + request_model=pipeline_run, + get_or_create_method=zen_store().get_or_create_run, ) - if created: - report_usage( - resource_type=ResourceType.PIPELINE_RUN, resource_id=run.id - ) - else: - verify_permission_for_model(run, action=Action.READ) - - return run, created @router.get( diff --git a/src/zenml/zen_server/routers/schedule_endpoints.py b/src/zenml/zen_server/routers/schedule_endpoints.py index ade186124f7..f538d0f6277 100644 --- a/src/zenml/zen_server/routers/schedule_endpoints.py +++ b/src/zenml/zen_server/routers/schedule_endpoints.py @@ -28,6 +28,9 @@ ) from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response +from zenml.zen_server.rbac.endpoint_utils import ( + verify_permissions_and_create_entity, +) from zenml.zen_server.routers.workspaces_endpoints import ( router as workspace_router, ) @@ -78,9 +81,12 @@ def create_schedule( workspace = zen_store().get_workspace(workspace_name_or_id) schedule.workspace = workspace.id - schedule.user = auth_context.user.id - - return zen_store().create_schedule(schedule=schedule) + # NOTE: no RBAC is enforced currently for schedules, but we're + # keeping the RBAC checks here for consistency + return verify_permissions_and_create_entity( + request_model=schedule, + create_method=zen_store().create_schedule, + ) @router.get( diff --git a/src/zenml/zen_server/routers/secrets_endpoints.py b/src/zenml/zen_server/routers/secrets_endpoints.py index 95ce19f899b..c379c890160 100644 --- a/src/zenml/zen_server/routers/secrets_endpoints.py +++ b/src/zenml/zen_server/routers/secrets_endpoints.py @@ -106,7 +106,6 @@ def create_secret( return verify_permissions_and_create_entity( request_model=secret, - resource_type=ResourceType.SECRET, create_method=zen_store().create_secret, ) diff --git a/src/zenml/zen_server/routers/service_accounts_endpoints.py b/src/zenml/zen_server/routers/service_accounts_endpoints.py index 8dedb498cf5..d6be506e0be 100644 --- a/src/zenml/zen_server/routers/service_accounts_endpoints.py +++ b/src/zenml/zen_server/routers/service_accounts_endpoints.py @@ -89,7 +89,6 @@ def create_service_account( """ return verify_permissions_and_create_entity( request_model=service_account, - resource_type=ResourceType.SERVICE_ACCOUNT, create_method=zen_store().create_service_account, ) @@ -233,13 +232,22 @@ def create_api_key( Returns: The created API key. """ + + def create_api_key_wrapper( + api_key: APIKeyRequest, + ) -> APIKeyResponse: + return zen_store().create_api_key( + service_account_id=service_account_id, + api_key=api_key, + ) + service_account = zen_store().get_service_account(service_account_id) - verify_permission_for_model(service_account, action=Action.UPDATE) - created_api_key = zen_store().create_api_key( - service_account_id=service_account_id, - api_key=api_key, + + return verify_permissions_and_create_entity( + request_model=api_key, + create_method=create_api_key_wrapper, + surrogate_models=[service_account], ) - return created_api_key @router.get( diff --git a/src/zenml/zen_server/routers/service_connectors_endpoints.py b/src/zenml/zen_server/routers/service_connectors_endpoints.py index 68098dd6b22..53e0c37b825 100644 --- a/src/zenml/zen_server/routers/service_connectors_endpoints.py +++ b/src/zenml/zen_server/routers/service_connectors_endpoints.py @@ -45,6 +45,7 @@ from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response from zenml.zen_server.rbac.endpoint_utils import ( + verify_permissions_and_create_entity, verify_permissions_and_delete_entity, verify_permissions_and_list_entities, verify_permissions_and_update_entity, @@ -113,12 +114,11 @@ def create_service_connector( workspace = zen_store().get_workspace(workspace_name_or_id) connector.workspace = workspace.id - verify_permission( - resource_type=ResourceType.SERVICE_CONNECTOR, action=Action.CREATE + return verify_permissions_and_create_entity( + request_model=connector, + create_method=zen_store().create_service_connector, ) - return zen_store().create_service_connector(connector) - @router.get( "", @@ -318,9 +318,7 @@ def validate_and_verify_service_connector_config( The list of resources that the service connector configuration has access to. """ - verify_permission( - resource_type=ResourceType.SERVICE_CONNECTOR, action=Action.CREATE - ) + verify_permission_for_model(model=connector, action=Action.CREATE) return zen_store().verify_service_connector_config( service_connector=connector, diff --git a/src/zenml/zen_server/routers/service_endpoints.py b/src/zenml/zen_server/routers/service_endpoints.py index 281f3afea22..419cfcc08ce 100644 --- a/src/zenml/zen_server/routers/service_endpoints.py +++ b/src/zenml/zen_server/routers/service_endpoints.py @@ -88,7 +88,6 @@ def create_service( return verify_permissions_and_create_entity( request_model=service, create_method=zen_store().create_service, - resource_type=ResourceType.SERVICE, ) diff --git a/src/zenml/zen_server/routers/stack_components_endpoints.py b/src/zenml/zen_server/routers/stack_components_endpoints.py index 3370338b9c1..c430a7d68d3 100644 --- a/src/zenml/zen_server/routers/stack_components_endpoints.py +++ b/src/zenml/zen_server/routers/stack_components_endpoints.py @@ -112,7 +112,6 @@ def create_stack_component( return verify_permissions_and_create_entity( request_model=component, - resource_type=ResourceType.STACK_COMPONENT, create_method=zen_store().create_stack_component, ) diff --git a/src/zenml/zen_server/routers/stacks_endpoints.py b/src/zenml/zen_server/routers/stacks_endpoints.py index 3ff0082dde3..70a14e6fbb8 100644 --- a/src/zenml/zen_server/routers/stacks_endpoints.py +++ b/src/zenml/zen_server/routers/stacks_endpoints.py @@ -129,7 +129,7 @@ def create_stack( ) # Check the stack creation - verify_permission(resource_type=ResourceType.STACK, action=Action.CREATE) + verify_permission_for_model(model=stack, action=Action.CREATE) return zen_store().create_stack(stack) diff --git a/src/zenml/zen_server/routers/steps_endpoints.py b/src/zenml/zen_server/routers/steps_endpoints.py index e4c8795f387..e7fc946fbf7 100644 --- a/src/zenml/zen_server/routers/steps_endpoints.py +++ b/src/zenml/zen_server/routers/steps_endpoints.py @@ -37,6 +37,9 @@ ) from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response +from zenml.zen_server.rbac.endpoint_utils import ( + verify_permissions_and_create_entity, +) from zenml.zen_server.rbac.models import Action, ResourceType from zenml.zen_server.rbac.utils import ( dehydrate_page, @@ -104,24 +107,24 @@ def list_run_steps( @handle_exceptions def create_run_step( step: StepRunRequest, - auth_context: AuthContext = Security(authorize), + _: AuthContext = Security(authorize), ) -> StepRunResponse: """Create a run step. Args: step: The run step to create. - auth_context: Authentication context. + _: Authentication context. Returns: The created run step. """ - step.user = auth_context.user.id - pipeline_run = zen_store().get_run(step.pipeline_run_id) - verify_permission_for_model(pipeline_run, action=Action.UPDATE) - step_response = zen_store().create_run_step(step_run=step) - return dehydrate_response_model(step_response) + return verify_permissions_and_create_entity( + request_model=step, + create_method=zen_store().create_run_step, + surrogate_models=[pipeline_run], + ) @router.get( diff --git a/src/zenml/zen_server/routers/tags_endpoints.py b/src/zenml/zen_server/routers/tags_endpoints.py index 18626ecd18e..54501ae0379 100644 --- a/src/zenml/zen_server/routers/tags_endpoints.py +++ b/src/zenml/zen_server/routers/tags_endpoints.py @@ -77,7 +77,6 @@ def create_tag( """ return verify_permissions_and_create_entity( request_model=tag, - resource_type=ResourceType.TAG, create_method=zen_store().create_tag, ) diff --git a/src/zenml/zen_server/routers/triggers_endpoints.py b/src/zenml/zen_server/routers/triggers_endpoints.py index 82a0e0a0f9d..89016b32f45 100644 --- a/src/zenml/zen_server/routers/triggers_endpoints.py +++ b/src/zenml/zen_server/routers/triggers_endpoints.py @@ -162,7 +162,6 @@ def create_trigger( return verify_permissions_and_create_entity( request_model=trigger, - resource_type=ResourceType.TRIGGER, create_method=zen_store().create_trigger, ) diff --git a/src/zenml/zen_server/routers/users_endpoints.py b/src/zenml/zen_server/routers/users_endpoints.py index a8086800dd2..9c445bf8907 100644 --- a/src/zenml/zen_server/routers/users_endpoints.py +++ b/src/zenml/zen_server/routers/users_endpoints.py @@ -174,7 +174,6 @@ def create_user( # new_user = verify_permissions_and_create_entity( # request_model=user, - # resource_type=ResourceType.USER, # create_method=zen_store().create_user, # ) new_user = zen_store().create_user(user)