
Commit 096d054

Fixes with multi project models
1 parent fa6739d commit 096d054

File tree

4 files changed: +593 −62 lines changed


.rules/new_models_best_practice.mdc

Lines changed: 310 additions & 2 deletions
@@ -433,8 +433,6 @@ def analyze_class_distribution(self, labels):
    max_percentage = max(info["percentage"] for info in class_info.values())
    if max_percentage > 80:
        logger.warning(f"Severe class imbalance detected: {max_percentage:.1f}% majority class")
-Right now we need to design architecture and write all points about it down. Your thoughts?
-

    return class_info
```
@@ -605,6 +603,313 @@ def test_extreme_imbalance_handling(self):
    assert metrics["balanced_accuracy"] > 0.5, "Balanced accuracy should exceed random chance"
```

## 3.9. Project-Specific Model Isolation for Multi-Tenancy

Enterprise Label Studio deployments often serve multiple projects simultaneously, each requiring an isolated model to prevent cross-project interference. Proper project isolation ensures data security, performance independence, and scalability.

### Why Project Isolation Matters

**Data Security and Privacy:**
- Project A's sensitive medical data never trains Project B's financial model
- Each project can have a completely different labeling configuration
- Models cannot accidentally predict label types that belong to other projects
- Compliance with data governance and privacy requirements

**Performance Independence:**
- Training on one project does not affect prediction quality for other projects
- Each project's model optimizes specifically for that project's data characteristics
- Poor annotations in one project will not degrade other projects' models
- Independent model performance metrics and monitoring

**Enterprise Scalability:**
- Memory management keeps frequently used models cached
- Models for inactive projects are loaded on demand
- Horizontal scaling across different project workloads

### Implementation Architecture

**Project-Aware Model Storage:**
```python
import logging
import os
from typing import Dict, List, Optional

logger = logging.getLogger(__name__)

# Global model cache - one entry per project
_models: Dict[int, ModelType] = {}


def _get_model(self, n_channels: int, n_labels: int, project_id: Optional[int] = None, blank: bool = False) -> ModelType:
    """Get or create model for specific project."""
    global _models

    # Use default project_id if not provided (backward compatibility)
    if project_id is None:
        project_id = 0
        logger.warning("No project_id provided, using default project_id=0")

    # Check memory cache first
    if project_id in _models and not blank:
        logger.info(f"Using existing model for project {project_id} from memory")
        return _models[project_id]

    # Try loading from project-specific file
    model_path = os.path.join(self.MODEL_DIR, f"model_project_{project_id}.pt")

    if not blank and os.path.exists(model_path):
        logger.info(f"Loading saved model for project {project_id} from {model_path}")
        try:
            model = ModelType.load_model(model_path)
            _models[project_id] = model
            return model
        except Exception as e:
            logger.warning(f"Failed to load model from {model_path}: {e}. Creating new model.")
            # Clean up corrupted file
            os.remove(model_path)

    # Create new model for this project
    logger.info(f"Creating new model for project {project_id}")
    model = self._build_model(n_channels, n_labels)
    _models[project_id] = model

    return model
```

**Project-Specific File Management:**
```python
def _save_model(self, model: ModelType, project_id: Optional[int] = None) -> None:
    """Save model with project-specific naming."""
    if project_id is None:
        project_id = 0
        logger.warning("No project_id provided for model save, using default project_id=0")

    logger.info(f"Saving model for project {project_id} to {self.MODEL_DIR}")
    os.makedirs(self.MODEL_DIR, exist_ok=True)

    # Project-specific file naming
    model_path = os.path.join(self.MODEL_DIR, f"model_project_{project_id}.pt")
    model.save(model_path)
    logger.info(f"Model for project {project_id} saved successfully to {model_path}")


def _clear_project_cache(self, project_id: int) -> None:
    """Clear specific project from memory cache."""
    global _models
    if project_id in _models:
        del _models[project_id]
        logger.info(f"Model cache cleared for project {project_id}")
```

### Project ID Detection and Context Handling

**Automatic Project ID Extraction:**
```python
def _get_project_id_from_context(self, tasks: List[Dict], context: Optional[Dict] = None) -> Optional[int]:
    """Extract project ID from tasks or context for model selection."""
    # Try context first - most reliable source
    if context and "project" in context:
        if isinstance(context["project"], dict) and "id" in context["project"]:
            project_id = context["project"]["id"]
            logger.debug(f"Found project_id {project_id} from context dict")
            return project_id
        elif isinstance(context["project"], (int, str)):
            project_id = int(context["project"])
            logger.debug(f"Found project_id {project_id} from context value")
            return project_id

    # Fall back to task metadata
    for task in tasks:
        if "project" in task:
            project_id = int(task["project"])
            logger.debug(f"Found project_id {project_id} from task")
            return project_id

    logger.debug("No project_id found in tasks or context")
    return None
```

**Training with Project Awareness:**
```python
def fit(self, event, data, **kwargs):
    """Train model with project isolation."""
    logger.info(f"Training event received: {event}")

    # Extract project ID from the training event
    project_id = data["annotation"]["project"]
    logger.info(f"Training triggered for project {project_id}")

    # Resolve labeling params; X and y are assembled from the project's
    # annotated tasks (data preparation omitted here)
    params = self._get_labeling_params()

    # Get project-specific model
    model = self._get_model(
        n_channels=len(params["channels"]),
        n_labels=len(params["all_labels"]),
        project_id=project_id,
        blank=True,  # Create a fresh model for training
    )

    # Train model with project-specific data
    metrics = model.partial_fit(X, y, epochs=self.TRAIN_EPOCHS)

    # Save with project-specific naming
    self._save_model(model, project_id=project_id)

    # Clear cache to force a reload on next prediction
    self._clear_project_cache(project_id)

    return metrics
```

**Prediction with Project Awareness:**
```python
def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
    """Predict using project-specific model."""
    logger.info(f"Starting prediction for {len(tasks)} tasks")

    # Determine which project's model to use
    project_id = self._get_project_id_from_context(tasks, context)
    if project_id is not None:
        logger.info(f"Using model for project {project_id}")
    else:
        logger.info("No project_id found, using default model")

    # Load project-specific model
    params = self._get_labeling_params()
    model = self._get_model(
        n_channels=len(params["channels"]),
        n_labels=len(params["all_labels"]),
        project_id=project_id,
    )

    # Generate predictions with the project-specific model
    predictions = []
    for task in tasks:
        pred = self._predict_task(task, model, params)
        predictions.append(pred)

    return ModelResponse(predictions=predictions, model_version=self.get("model_version"))
```

### Memory Management and Performance

**Efficient Caching Strategy:**
```python
# Cache frequently used models in memory
MAX_CACHED_MODELS = int(os.getenv("MAX_CACHED_MODELS", "5"))


def _manage_model_cache(self, project_id: int, model: ModelType) -> None:
    """Manage memory cache with size-bounded eviction."""
    global _models

    # Add to cache
    _models[project_id] = model

    # Evict when the cache is full; dicts preserve insertion order,
    # so this is simple FIFO eviction rather than true LRU
    if len(_models) > MAX_CACHED_MODELS:
        oldest_project = next(iter(_models))
        del _models[oldest_project]
        logger.info(f"Evicted model for project {oldest_project} from cache")
```
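
The sketch above evicts in dict insertion order (FIFO). If access patterns matter more than insertion order, true LRU is a small step further: refresh an entry's position on every cache hit so the least recently *used* model is evicted first. A minimal sketch reusing the `_models`, `MAX_CACHED_MODELS`, and `logger` names from above; the helper names are illustrative, not part of the example backend:

```python
from collections import OrderedDict

# LRU variant of the cache: entries move to the end on every access,
# so the front of the OrderedDict is always the least recently used.
_models: OrderedDict[int, ModelType] = OrderedDict()


def _cache_model_lru(project_id: int, model: ModelType) -> None:
    """Insert or refresh a model, evicting the least recently used if full."""
    _models[project_id] = model
    _models.move_to_end(project_id)  # mark as most recently used
    if len(_models) > MAX_CACHED_MODELS:
        evicted_project, _ = _models.popitem(last=False)  # pop the LRU entry
        logger.info(f"Evicted model for project {evicted_project} from cache")


def _get_cached_model_lru(project_id: int) -> Optional[ModelType]:
    """Return a cached model (refreshing its recency) or None."""
    model = _models.get(project_id)
    if model is not None:
        _models.move_to_end(project_id)
    return model
```

The only behavioral difference from the FIFO version is the `move_to_end` call on reads.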

**Resource Monitoring:**
```python
def log_model_cache_status(self):
    """Log current model cache status for monitoring."""
    global _models
    cached_projects = list(_models.keys())
    logger.info(f"Model cache status: {len(cached_projects)} projects cached: {cached_projects}")
```
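
Cache membership alone does not reveal memory pressure. A possible extension, assuming the third-party `psutil` package is added to the backend's requirements (it is not used by the example itself), reports process resident memory alongside the cached projects:

```python
import logging

import psutil  # assumed dependency; add to requirements.txt if adopted

logger = logging.getLogger(__name__)


def log_model_cache_memory(self):
    """Log cache contents together with the process's resident memory."""
    global _models
    rss_mb = psutil.Process().memory_info().rss / (1024 * 1024)
    logger.info(
        f"Model cache: {len(_models)} projects cached {list(_models.keys())}, "
        f"process RSS: {rss_mb:.1f} MiB"
    )
```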

### Configuration and Environment Variables

**Project Isolation Settings:**
```python
# Environment variables for project isolation
MAX_CACHED_MODELS = int(os.getenv("MAX_CACHED_MODELS", "5"))
ENABLE_PROJECT_ISOLATION = os.getenv("ENABLE_PROJECT_ISOLATION", "true").lower() == "true"
DEFAULT_PROJECT_ID = int(os.getenv("DEFAULT_PROJECT_ID", "0"))


def setup(self):
    """Setup with project isolation configuration."""
    logger.info(f"Project isolation: enabled={ENABLE_PROJECT_ISOLATION}, "
                f"max_cached_models={MAX_CACHED_MODELS}, "
                f"default_project_id={DEFAULT_PROJECT_ID}")
```
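
`ENABLE_PROJECT_ISOLATION` is declared above but not exercised in these snippets. One plausible wiring, shown here as an assumption rather than the example's actual behavior, is a resolver that collapses every request onto the default project when isolation is switched off:

```python
def _resolve_project_id(self, project_id: Optional[int]) -> int:
    """Map an incoming project ID onto the cache key actually used."""
    if not ENABLE_PROJECT_ISOLATION:
        # Isolation disabled: all projects share one default model
        return DEFAULT_PROJECT_ID
    return project_id if project_id is not None else DEFAULT_PROJECT_ID
```

`_get_model` and `_save_model` could then call this resolver instead of defaulting to `0` inline.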

### Testing Project Isolation

**Comprehensive Project Isolation Tests:**
```python
def test_project_specific_models(self):
    """Test that different projects use separate models and model files."""
    # Create models for different projects
    model_project_1 = segmenter._get_model(n_channels=2, n_labels=3, project_id=1)
    model_project_2 = segmenter._get_model(n_channels=2, n_labels=3, project_id=2)
    model_default = segmenter._get_model(n_channels=2, n_labels=3)  # project_id=0

    # Verify different instances
    assert model_project_1 is not model_project_2
    assert model_project_1 is not model_default
    assert model_project_2 is not model_default

    # Test project-specific file naming
    segmenter._save_model(model_project_1, project_id=1)
    segmenter._save_model(model_project_2, project_id=2)

    assert os.path.exists(os.path.join(temp_dir, "model_project_1.pt"))
    assert os.path.exists(os.path.join(temp_dir, "model_project_2.pt"))

    # Test project ID extraction from context
    context_dict = {"project": {"id": 42}}
    project_id = segmenter._get_project_id_from_context([], context_dict)
    assert project_id == 42

    context_int = {"project": 99}
    project_id = segmenter._get_project_id_from_context([], context_int)
    assert project_id == 99


def test_project_isolation_prevents_cross_contamination(self):
    """Test that training one project doesn't affect another."""
    # Train model for project 1
    task_p1 = create_task_with_project(project_id=1, labels=["ClassA", "ClassB"])
    segmenter.fit("START_TRAINING", {"annotation": {"project": 1}}, tasks=[task_p1])

    # Train different model for project 2
    task_p2 = create_task_with_project(project_id=2, labels=["ClassX", "ClassY"])
    segmenter.fit("START_TRAINING", {"annotation": {"project": 2}}, tasks=[task_p2])

    # Verify predictions use the correct project models
    pred_p1 = segmenter.predict([task_p1], context={"project": 1})
    pred_p2 = segmenter.predict([task_p2], context={"project": 2})

    # Models should predict different label sets
    assert_different_label_predictions(pred_p1, pred_p2)
```
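
The tests above rely on a `segmenter` instance, a `temp_dir` model directory, and helpers such as `create_task_with_project` and `assert_different_label_predictions`; these are test scaffolding, not library API. One way to provide the first two as pytest fixtures, with a hypothetical backend class name:

```python
import pytest


@pytest.fixture
def temp_dir(tmp_path):
    """Fresh, isolated model directory for each test."""
    return str(tmp_path)


@pytest.fixture
def segmenter(temp_dir, monkeypatch):
    """Backend instance whose model files land under temp_dir."""
    backend = TimeSeriesSegmenter()  # hypothetical class name for the example backend
    monkeypatch.setattr(backend, "MODEL_DIR", temp_dir, raising=False)
    return backend
```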

### Production Deployment Considerations

**Docker Configuration:**
```yaml
# docker-compose.yml for multi-tenant deployment
services:
  ml-backend:
    environment:
      - MAX_CACHED_MODELS=10  # Adjust based on available memory
      - ENABLE_PROJECT_ISOLATION=true
      - MODEL_DIR=/app/models
    volumes:
      - ./models:/app/models  # Persistent model storage
```

**Monitoring and Alerting:**
```python
def health_check_with_project_info(self):
    """Health check endpoint payload with project isolation status."""
    global _models
    return {
        "status": "healthy",
        "project_isolation_enabled": ENABLE_PROJECT_ISOLATION,
        "cached_projects": list(_models.keys()),
        "max_cache_size": MAX_CACHED_MODELS,
        "cache_utilization": len(_models) / MAX_CACHED_MODELS,
    }
```

This project isolation pattern delivers enterprise-grade multi-tenancy while maintaining backward compatibility and providing the scalability needed for production Label Studio deployments that serve multiple teams or clients.

## 4. Testing

- Tests should be runnable with `pytest` directly from the repository root or inside the example's Docker container.
@@ -892,6 +1197,7 @@ def test_model_training_workflow(self):
- **`label_studio_ml/examples/yolo/`** - Well-structured computer vision backend with good Docker integration
- **`label_studio_ml/examples/timeseries_segmenter/`** - Comprehensive ML backend demonstrating:
  - Advanced imbalanced data handling with class weights and balanced metrics
  - Project-specific model isolation for multi-tenant deployments
  - Proper PyTorch model serialization and loading
  - ML-specific testing patterns with comprehensive test suite
  - Annotation semantics handling (instant vs range annotations)
@@ -904,6 +1210,8 @@ def test_model_training_workflow(self):

**For ML backends with imbalanced data**: Use `timeseries_segmenter/` as a reference for balanced learning approaches, advanced training patterns, and comprehensive testing.

**For enterprise/multi-tenant deployments**: Use `timeseries_segmenter/` as a reference for project-specific model isolation, ensuring proper data security and performance independence across multiple Label Studio projects.

**For any ML backend**: Both examples demonstrate solid project structure, error handling, and documentation practices.

Following these conventions helps maintain consistency across examples and makes it easier for contributors and automation tools to understand each backend.
