Skip to content

Commit

Permalink
editoast: add simulation summary and simulation endpoints for paced t…
Browse files Browse the repository at this point in the history
…rain

Signed-off-by: Youness CHRIFI ALAOUI <[email protected]>
  • Loading branch information
younesschrifi committed Mar 4, 2025
1 parent b7a0b08 commit be809ed
Show file tree
Hide file tree
Showing 9 changed files with 283 additions and 67 deletions.
52 changes: 51 additions & 1 deletion editoast/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5139,6 +5139,8 @@ components:
- $ref: '#/components/schemas/EditoastOperationErrorObjectNotFound'
- $ref: '#/components/schemas/EditoastPacedTrainErrorBatchPacedTrainNotFound'
- $ref: '#/components/schemas/EditoastPacedTrainErrorDatabase'
- $ref: '#/components/schemas/EditoastPacedTrainErrorInfraNotFound'
- $ref: '#/components/schemas/EditoastPacedTrainErrorNotFound'
- $ref: '#/components/schemas/EditoastPaginationErrorInvalidPage'
- $ref: '#/components/schemas/EditoastPaginationErrorInvalidPageSize'
- $ref: '#/components/schemas/EditoastPathfindingErrorInfraNotFound'
Expand Down Expand Up @@ -5676,6 +5678,54 @@ components:
type: string
enum:
- editoast:paced_train:Database
EditoastPacedTrainErrorInfraNotFound:
type: object
required:
- type
- status
- message
properties:
context:
type: object
required:
- infra_id
properties:
infra_id:
type: integer
message:
type: string
status:
type: integer
enum:
- 404
type:
type: string
enum:
- editoast:paced_train:InfraNotFound
EditoastPacedTrainErrorNotFound:
type: object
required:
- type
- status
- message
properties:
context:
type: object
required:
- paced_train_id
properties:
paced_train_id:
type: integer
message:
type: string
status:
type: integer
enum:
- 404
type:
type: string
enum:
- editoast:paced_train:NotFound
EditoastPaginationErrorInvalidPage:
type: object
required:
Expand Down Expand Up @@ -8542,7 +8592,7 @@ components:
timetable_id:
type: integer
format: int64
description: Timetable attached to the train schedule
description: Timetable attached to the paced train
nullable: true
PacedTrainResult:
allOf:
Expand Down
23 changes: 23 additions & 0 deletions editoast/src/models/paced_train.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::models::train_schedule::TrainSchedule;
use chrono::DateTime;
use chrono::Duration as ChronoDuration;
use chrono::Utc;
Expand Down Expand Up @@ -101,3 +102,25 @@ impl From<PacedTrain> for PacedTrainBase {
}
}
}

impl From<PacedTrain> for TrainSchedule {
fn from(paced_train: PacedTrain) -> Self {
Self {
id: paced_train.id,
train_name: paced_train.train_name,
labels: paced_train.labels.into(),
rolling_stock_name: paced_train.rolling_stock_name,
timetable_id: paced_train.timetable_id,
path: paced_train.path,
start_time: paced_train.start_time,
schedule: paced_train.schedule,
margins: paced_train.margins,
initial_speed: paced_train.initial_speed,
comfort: paced_train.comfort,
constraint_distribution: paced_train.constraint_distribution,
speed_limit_tag: paced_train.speed_limit_tag,
power_restrictions: paced_train.power_restrictions,
options: paced_train.options,
}
}
}
2 changes: 1 addition & 1 deletion editoast/src/views/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ pub struct InfraIdQueryParam {

#[derive(Debug, Serialize, ToSchema)]
#[serde(tag = "status", rename_all = "snake_case")]
enum SimulationSummaryResult {
pub enum SimulationSummaryResult {
/// Minimal information on a simulation's result
Success {
/// Length of a path in mm
Expand Down
177 changes: 156 additions & 21 deletions editoast/src/views/paced_train.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use std::collections::HashSet;
use crate::core::simulation::SimulationResponse;
use crate::error::Result;
use crate::models::prelude::*;
use crate::models::train_schedule::TrainSchedule;
use crate::models::Infra;
use crate::views::projection::ProjectPathTrainResult;
use crate::views::ListId;
use axum::extract::Json;
Expand All @@ -15,13 +17,16 @@ use editoast_authz::BuiltinRole;
use editoast_derive::EditoastError;
use editoast_models::DbConnectionPoolV2;
use editoast_schemas::paced_train::PacedTrainBase;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use utoipa::IntoParams;
use utoipa::ToSchema;

use super::path::pathfinding::PathfindingResult;
use super::projection::ProjectPathForm;
use super::train_schedule::simulation_response;
use super::train_schedule::train_simulation_batch;
use super::AppState;
use super::AuthenticationExt;
use super::InfraIdQueryParam;
Expand Down Expand Up @@ -54,14 +59,20 @@ enum PacedTrainError {
#[error("{count} paced train(s) could not be found")]
#[editoast_error(status = 404)]
BatchPacedTrainNotFound { count: usize },
#[error("Paced train '{paced_train_id}', could not be found")]
#[editoast_error(status = 404)]
NotFound { paced_train_id: i64 },
#[error("Infra '{infra_id}', could not be found")]
#[editoast_error(status = 404)]
InfraNotFound { infra_id: i64 },
#[error(transparent)]
#[editoast_error(status = 500)]
Database(#[from] editoast_models::model::Error),
}

#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct PacedTrainForm {
/// Timetable attached to the train schedule
/// Timetable attached to the paced train
pub timetable_id: Option<i64>,
#[serde(flatten)]
pub paced_train_base: PacedTrainBase,
Expand Down Expand Up @@ -181,19 +192,61 @@ struct SimulationBatchForm {
)]
async fn simulation_summary(
State(AppState {
db_pool: _db_pool,
valkey: _valkey_client,
core_client: _core,
db_pool,
valkey: valkey_client,
core_client: core,
..
}): State<AppState>,
Extension(_auth): AuthenticationExt,
Extension(auth): AuthenticationExt,
Json(SimulationBatchForm {
infra_id: _infra_id,
electrical_profile_set_id: _electrical_profile_set_id,
ids: _paced_train_ids,
infra_id,
electrical_profile_set_id,
ids: paced_train_ids,
}): Json<SimulationBatchForm>,
) -> Result<Json<HashMap<i64, SimulationSummaryResult>>> {
todo!();
let authorized = auth
.check_roles([BuiltinRole::InfraRead, BuiltinRole::TimetableRead].into())
.await
.map_err(AuthorizationError::AuthError)?;
if !authorized {
return Err(AuthorizationError::Forbidden.into());
}

let conn = &mut db_pool.get().await?;

let infra = Infra::retrieve_or_fail(conn, infra_id, || PacedTrainError::InfraNotFound {
infra_id,
})
.await?;

let paced_trains: Vec<PacedTrain> =
PacedTrain::retrieve_batch_or_fail(conn, paced_train_ids, |missing| {
PacedTrainError::BatchPacedTrainNotFound {
count: missing.len(),
}
})
.await?;
let paced_trains_to_ts: Vec<TrainSchedule> = paced_trains.clone().into_iter().map_into().collect();

let simulations = train_simulation_batch(
conn,
valkey_client,
core,
&paced_trains_to_ts,
&infra,
electrical_profile_set_id,
)
.await?;

// Transform simulations to simulation summary
let mut simulation_summaries = HashMap::new();
for (paced_train, sim) in paced_trains.into_iter().zip(simulations) {
let (sim, _) = sim;
let simulation_summary_result = simulation_response(sim);
simulation_summaries.insert(paced_train.id, simulation_summary_result);
}

Ok(Json(simulation_summaries))
}

/// Get a path from a paced train given an infrastructure id and a paced train id
Expand Down Expand Up @@ -241,23 +294,53 @@ pub struct ElectricalProfileSetIdQueryParam {
)]
async fn simulation(
State(AppState {
valkey: _valkey_client,
core_client: _core_client,
db_pool: _db_pool,
valkey: valkey_client,
core_client,
db_pool,
..
}): State<AppState>,
Extension(_auth): AuthenticationExt,
Path(PacedTrainIdParam {
id: _paced_train_id,
}): Path<PacedTrainIdParam>,
Query(InfraIdQueryParam {
infra_id: _infra_id,
}): Query<InfraIdQueryParam>,
Extension(auth): AuthenticationExt,
Path(PacedTrainIdParam { id: paced_train_id }): Path<PacedTrainIdParam>,
Query(InfraIdQueryParam { infra_id }): Query<InfraIdQueryParam>,
Query(ElectricalProfileSetIdQueryParam {
electrical_profile_set_id: _electrical_profile_set_id,
electrical_profile_set_id,
}): Query<ElectricalProfileSetIdQueryParam>,
) -> Result<Json<SimulationResponse>> {
todo!();
let authorized = auth
.check_roles([BuiltinRole::InfraRead, BuiltinRole::TimetableRead].into())
.await
.map_err(AuthorizationError::AuthError)?;
if !authorized {
return Err(AuthorizationError::Forbidden.into());
}

// Retrieve infra or fail
let infra = Infra::retrieve_or_fail(&mut db_pool.get().await?, infra_id, || {
PacedTrainError::InfraNotFound { infra_id }
})
.await?;

// Retrieve paced_train or fail
let paced_train =
PacedTrain::retrieve_or_fail(&mut db_pool.get().await?, paced_train_id, || {
PacedTrainError::NotFound { paced_train_id }
})
.await?;

// Compute simulation of a paced_train
let (simulation, _) = train_simulation_batch(
&mut db_pool.get().await?,
valkey_client,
core_client,
&[paced_train.into()],
&infra,
electrical_profile_set_id,
)
.await?
.pop()
.unwrap();

Ok(Json(simulation))
}

/// Projects the space time curves and paths of a number of paced trains onto a given path
Expand Down Expand Up @@ -294,17 +377,26 @@ async fn project_path(

#[cfg(test)]
mod tests {
use chrono::Duration;
use editoast_models::DbConnectionPoolV2;
use editoast_schemas::paced_train::{Paced, PacedTrainBase};
use editoast_schemas::train_schedule::TrainScheduleBase;
use rstest::rstest;
use serde_json::json;

use crate::models::paced_train::PacedTrainChangeset;
use crate::models::prelude::*;
use crate::views::test_app::TestApp;
use crate::views::train_schedule::tests::mocked_core_pathfinding_sim_and_proj;
use crate::{
models::{
fixtures::{create_simple_paced_train, create_timetable, simple_paced_train_base},
paced_train::PacedTrain,
},
views::{paced_train::PacedTrainResult, test_app::TestAppBuilder},
};
use crate::models::fixtures::create_small_infra;
use crate::models::fixtures::create_fast_rolling_stock;
use axum::http::StatusCode;

#[rstest]
Expand Down Expand Up @@ -344,4 +436,47 @@ mod tests {

assert!(!exists);
}

async fn app_infra_id_paced_train_id_for_simulation_tests() -> (TestApp, i64, i64) {
let db_pool = DbConnectionPoolV2::for_tests();
let small_infra = create_small_infra(&mut db_pool.get_ok()).await;
let rolling_stock =
create_fast_rolling_stock(&mut db_pool.get_ok(), "simulation_rolling_stock").await;
let paced_train_base: PacedTrainBase = PacedTrainBase {
train_schedule_base: TrainScheduleBase {
rolling_stock_name: rolling_stock.name.clone(),
..serde_json::from_str(include_str!("../tests/train_schedules/simple.json"))
.expect("Unable to parse")
},
paced: Paced {
duration: Duration::hours(1).try_into().unwrap(),
step: Duration::minutes(15).try_into().unwrap(),
},
};
let paced_train: PacedTrainChangeset = paced_train_base.into();
let paced_train = paced_train
.create(&mut db_pool.get_ok())
.await
.expect("Failed to create paced train");
let core = mocked_core_pathfinding_sim_and_proj(paced_train.id);
let app = TestAppBuilder::new()
.db_pool(db_pool.clone())
.core_client(core.into())
.build();
(app, small_infra.id, paced_train.id)
}

#[rstest]
async fn paced_train_simulation() {
let (app, infra_id, train_schedule_id) =
app_infra_id_paced_train_id_for_simulation_tests().await;
let request = app.get(
format!(
"/paced_train/{}/simulation/?infra_id={}",
train_schedule_id, infra_id
)
.as_str(),
);
app.fetch(request).assert_status(StatusCode::OK);
}
}
2 changes: 1 addition & 1 deletion editoast/src/views/timetable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ async fn paced_train(
let changesets = paced_trains
.into_iter()
.map(PacedTrainChangeset::from)
.map(|cs| cs.timetable_id(timetable_id))
.map(|cs: PacedTrainChangeset| cs.timetable_id(timetable_id))
.collect::<Vec<_>>();

// Create a batch of paced trains
Expand Down
Loading

0 comments on commit be809ed

Please sign in to comment.