|
| 1 | +# path: policylens/apps/claims/management/commands/train_completeness_model.py |
| 2 | +""" |
| 3 | +Train the completeness classifier and save a versioned model bundle. |
| 4 | +
|
| 5 | +This command: |
| 6 | +- Reads a synthetic dataset CSV (or any CSV matching the contract) |
| 7 | +- Trains a lightweight model |
| 8 | +- Saves artefacts to artifacts/ml/<version>/ |
| 9 | +""" |
| 10 | + |
| 11 | +from __future__ import annotations |
| 12 | + |
| 13 | +from pathlib import Path |
| 14 | + |
| 15 | +from django.core.management.base import BaseCommand |
| 16 | + |
| 17 | +from apps.claims.ml.train import train_from_csv |
| 18 | + |
| 19 | + |
| 20 | +class Command(BaseCommand): |
| 21 | + """Train the completeness model.""" |
| 22 | + |
| 23 | + help = "Train completeness classifier from a contract-aligned CSV dataset." |
| 24 | + |
| 25 | + def add_arguments(self, parser) -> None: |
| 26 | + parser.add_argument("--csv", required=True, help="Input dataset path.") |
| 27 | + parser.add_argument("--version", required=True, help="Model version folder name, e.g. v1_2026_01_13.") |
| 28 | + parser.add_argument("--threshold", type=float, default=0.6, help="Threshold for likely incomplete.") |
| 29 | + parser.add_argument("--seed", type=int, default=42, help="Deterministic training seed.") |
| 30 | + |
| 31 | + def handle(self, *args, **options) -> None: |
| 32 | + csv_path = Path(options["csv"]) |
| 33 | + version = str(options["version"]) |
| 34 | + threshold = float(options["threshold"]) |
| 35 | + seed = int(options["seed"]) |
| 36 | + |
| 37 | + result = train_from_csv(csv_path=csv_path, model_version=version, threshold=threshold, random_seed=seed) |
| 38 | + self.stdout.write(self.style.SUCCESS(f"Trained model {result.model_version}")) |
| 39 | + self.stdout.write(self.style.SUCCESS(f"Metrics: {result.metrics}")) |
0 commit comments