Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions application/backend/app/api/routers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import zipfile
from typing import Annotated

from fastapi import APIRouter, Body, Depends, HTTPException, status
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
from fastapi.openapi.models import Example
from fastapi.responses import StreamingResponse

from app.api.dependencies import get_model_service, get_project
from app.api.schemas import ModelView, ProjectView
from app.api.validators import ModelID
from app.api.validators import DatasetRevisionID, ModelID
from app.services import ModelService, ResourceInUseError, ResourceNotFoundError

router = APIRouter(prefix="/api/projects/{project_id}/models", tags=["Models"])
Expand All @@ -29,10 +29,17 @@
def list_models(
project: Annotated[ProjectView, Depends(get_project)],
model_service: Annotated[ModelService, Depends(get_model_service)],
dataset_revision_id: Annotated[
DatasetRevisionID | None,
Query(description="Dataset revision id for optional filtering"),
] = None,
) -> list[ModelView]:
"""Get all models in a project."""
"""Get all models in a project, optionally filtered by dataset revision."""
try:
return [ModelView.model_validate(obj, from_attributes=True) for obj in model_service.list_models(project.id)]
return [
ModelView.model_validate(obj, from_attributes=True)
for obj in model_service.list_models(project_id=project.id, dataset_revision_id=dataset_revision_id)
]
except ResourceNotFoundError:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Project not found")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,21 @@ def __init__(self, project_id: str, db: Session):
super().__init__(db, ModelRevisionDB)
self.project_id = project_id

def list_all(self) -> Sequence[ModelRevisionDB]:
def list_all(self, training_dataset_id: str | None = None) -> Sequence[ModelRevisionDB]:
"""
List all model revisions for a given project.

Optionally the model revisions can be filtered on training dataset id

Args:
training_dataset_id (str): Optional unique id of the training dataset to filter on

Returns:
Sequence[ModelRevisionDB]: A list of model revisions associated with the project.
"""
stmt = select(ModelRevisionDB).where(ModelRevisionDB.project_id == self.project_id)
if training_dataset_id is not None:
stmt = stmt.where(ModelRevisionDB.training_dataset_id == training_dataset_id)
return self.db.execute(stmt).scalars().all()

def get_by_id(self, obj_id: str) -> ModelRevisionDB | None:
Expand Down
12 changes: 9 additions & 3 deletions application/backend/app/services/model_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def delete_model(self, project_id: UUID, model_id: UUID) -> None:
except IntegrityError:
raise ResourceInUseError(ResourceType.MODEL, str(model_id))

def list_models(self, project_id: UUID) -> list[ModelRevision]:
def list_models(self, project_id: UUID, dataset_revision_id: UUID | None = None) -> list[ModelRevision]:
"""
Get information about all available model revisions in a project.

Expand All @@ -126,13 +126,19 @@ def list_models(self, project_id: UUID) -> list[ModelRevision]:

Args:
project_id (UUID): The unique identifier of the project whose models to list.
dataset_revision_id (UUID | None, optional): The unique identifier of the dataset revision to filter on.

Returns:
list[ModelRevision]: A list of model revision objects representing all model
revisions in the project. Returns an empty list if the project has no models.
revisions in the project, optionally filtered on dataset revision.
Returns an empty list if the project has no models.
"""
model_rev_repo = ModelRevisionRepository(project_id=str(project_id), db=self.db_session)
return [ModelRevision.model_validate(model_rev_db) for model_rev_db in model_rev_repo.list_all()]
training_dataset_id = str(dataset_revision_id) if dataset_revision_id is not None else None
return [
ModelRevision.model_validate(model_rev_db)
for model_rev_db in model_rev_repo.list_all(training_dataset_id=training_dataset_id)
]

def create_revision(self, metadata: ModelRevisionMetadata) -> None:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest
from sqlalchemy.orm import Session

from app.db.schema import ModelRevisionDB, ProjectDB
from app.db.schema import DatasetRevisionDB, ModelRevisionDB, ProjectDB
from app.services import ModelService, ResourceNotFoundError, ResourceType
from tests.integration.project_factory import ProjectTestDataFactory

Expand Down Expand Up @@ -49,6 +49,56 @@ def test_list_models(
for idx in range(len(model_ids)):
assert fxt_db_models[idx].id in model_ids

def test_list_models_with_dataset_revision(
self,
request: pytest.FixtureRequest,
db_session: Session,
fxt_project_id: UUID,
fxt_model_service: ModelService,
):
# Create a dataset revision
dataset_revision_id = uuid4()
dataset_revision = DatasetRevisionDB(
id=str(dataset_revision_id),
project_id=str(fxt_project_id),
files_deleted=False,
)
db_session.add(dataset_revision)
db_session.flush()

# Add an extra model that is linked to the dataset revision
model_id = uuid4()
model = ModelRevisionDB(
id=str(model_id),
name="TestModel",
project_id=str(fxt_project_id),
architecture="TestArch",
parent_revision=None,
training_status="NOT_STARTED",
training_configuration={},
training_dataset_id=str(dataset_revision_id),
label_schema_revision={},
)
db_session.add(model)
db_session.flush()

# Add finalizer to cleanup test data
def cleanup():
db_session.delete(model)
db_session.delete(dataset_revision)
db_session.flush()

request.addfinalizer(cleanup)

# Call list_models with dataset_revision_id and without
dataset_models = fxt_model_service.list_models(fxt_project_id, dataset_revision_id=dataset_revision_id)
all_models = fxt_model_service.list_models(fxt_project_id)

# Only the model linked to the dataset revision should be returned
assert len(all_models) == 3
assert len(dataset_models) == 1
assert str(dataset_models[0].id) == str(model_id)

def test_get_model(self, fxt_project_id: UUID, fxt_model_id: UUID, fxt_model_service: ModelService):
"""Test retrieving a model by ID."""
model = fxt_model_service.get_model(fxt_project_id, fxt_model_id)
Expand Down
Loading