@@ -4,6 +4,8 @@ use std::collections::HashSet;
4
4
use crate :: core:: simulation:: SimulationResponse ;
5
5
use crate :: error:: Result ;
6
6
use crate :: models:: prelude:: * ;
7
+ use crate :: models:: train_schedule:: TrainSchedule ;
8
+ use crate :: models:: Infra ;
7
9
use crate :: views:: projection:: ProjectPathTrainResult ;
8
10
use crate :: views:: ListId ;
9
11
use axum:: extract:: Json ;
@@ -15,13 +17,16 @@ use editoast_authz::BuiltinRole;
15
17
use editoast_derive:: EditoastError ;
16
18
use editoast_models:: DbConnectionPoolV2 ;
17
19
use editoast_schemas:: paced_train:: PacedTrainBase ;
20
+ use itertools:: Itertools ;
18
21
use serde:: { Deserialize , Serialize } ;
19
22
use thiserror:: Error ;
20
23
use utoipa:: IntoParams ;
21
24
use utoipa:: ToSchema ;
22
25
23
26
use super :: path:: pathfinding:: PathfindingResult ;
24
27
use super :: projection:: ProjectPathForm ;
28
+ use super :: train_schedule:: simulation_response;
29
+ use super :: train_schedule:: train_simulation_batch;
25
30
use super :: AppState ;
26
31
use super :: AuthenticationExt ;
27
32
use super :: InfraIdQueryParam ;
@@ -54,14 +59,20 @@ enum PacedTrainError {
54
59
#[ error( "{count} paced train(s) could not be found" ) ]
55
60
#[ editoast_error( status = 404 ) ]
56
61
BatchPacedTrainNotFound { count : usize } ,
62
+ #[ error( "Paced train '{paced_train_id}', could not be found" ) ]
63
+ #[ editoast_error( status = 404 ) ]
64
+ NotFound { paced_train_id : i64 } ,
65
+ #[ error( "Infra '{infra_id}', could not be found" ) ]
66
+ #[ editoast_error( status = 404 ) ]
67
+ InfraNotFound { infra_id : i64 } ,
57
68
#[ error( transparent) ]
58
69
#[ editoast_error( status = 500 ) ]
59
70
Database ( #[ from] editoast_models:: model:: Error ) ,
60
71
}
61
72
62
73
#[ derive( Debug , Clone , Serialize , Deserialize , ToSchema ) ]
63
74
pub struct PacedTrainForm {
64
- /// Timetable attached to the train schedule
75
+ /// Timetable attached to the paced train
65
76
pub timetable_id : Option < i64 > ,
66
77
#[ serde( flatten) ]
67
78
pub paced_train_base : PacedTrainBase ,
@@ -181,19 +192,61 @@ struct SimulationBatchForm {
181
192
) ]
182
193
async fn simulation_summary (
183
194
State ( AppState {
184
- db_pool : _db_pool ,
185
- valkey : _valkey_client ,
186
- core_client : _core ,
195
+ db_pool,
196
+ valkey : valkey_client ,
197
+ core_client : core ,
187
198
..
188
199
} ) : State < AppState > ,
189
- Extension ( _auth ) : AuthenticationExt ,
200
+ Extension ( auth ) : AuthenticationExt ,
190
201
Json ( SimulationBatchForm {
191
- infra_id : _infra_id ,
192
- electrical_profile_set_id : _electrical_profile_set_id ,
193
- ids : _paced_train_ids ,
202
+ infra_id,
203
+ electrical_profile_set_id,
204
+ ids : paced_train_ids ,
194
205
} ) : Json < SimulationBatchForm > ,
195
206
) -> Result < Json < HashMap < i64 , SimulationSummaryResult > > > {
196
- todo ! ( ) ;
207
+ let authorized = auth
208
+ . check_roles ( [ BuiltinRole :: InfraRead , BuiltinRole :: TimetableRead ] . into ( ) )
209
+ . await
210
+ . map_err ( AuthorizationError :: AuthError ) ?;
211
+ if !authorized {
212
+ return Err ( AuthorizationError :: Forbidden . into ( ) ) ;
213
+ }
214
+
215
+ let conn = & mut db_pool. get ( ) . await ?;
216
+
217
+ let infra = Infra :: retrieve_or_fail ( conn, infra_id, || PacedTrainError :: InfraNotFound {
218
+ infra_id,
219
+ } )
220
+ . await ?;
221
+
222
+ let paced_trains: Vec < PacedTrain > =
223
+ PacedTrain :: retrieve_batch_or_fail ( conn, paced_train_ids, |missing| {
224
+ PacedTrainError :: BatchPacedTrainNotFound {
225
+ count : missing. len ( ) ,
226
+ }
227
+ } )
228
+ . await ?;
229
+ let paced_trains_to_ts: Vec < TrainSchedule > = paced_trains. clone ( ) . into_iter ( ) . map_into ( ) . collect ( ) ;
230
+
231
+ let simulations = train_simulation_batch (
232
+ conn,
233
+ valkey_client,
234
+ core,
235
+ & paced_trains_to_ts,
236
+ & infra,
237
+ electrical_profile_set_id,
238
+ )
239
+ . await ?;
240
+
241
+ // Transform simulations to simulation summary
242
+ let mut simulation_summaries = HashMap :: new ( ) ;
243
+ for ( paced_train, sim) in paced_trains. into_iter ( ) . zip ( simulations) {
244
+ let ( sim, _) = sim;
245
+ let simulation_summary_result = simulation_response ( sim) ;
246
+ simulation_summaries. insert ( paced_train. id , simulation_summary_result) ;
247
+ }
248
+
249
+ Ok ( Json ( simulation_summaries) )
197
250
}
198
251
199
252
/// Get a path from a paced train given an infrastructure id and a paced train id
@@ -241,23 +294,53 @@ pub struct ElectricalProfileSetIdQueryParam {
241
294
) ]
242
295
async fn simulation (
243
296
State ( AppState {
244
- valkey : _valkey_client ,
245
- core_client : _core_client ,
246
- db_pool : _db_pool ,
297
+ valkey : valkey_client ,
298
+ core_client,
299
+ db_pool,
247
300
..
248
301
} ) : State < AppState > ,
249
- Extension ( _auth) : AuthenticationExt ,
250
- Path ( PacedTrainIdParam {
251
- id : _paced_train_id,
252
- } ) : Path < PacedTrainIdParam > ,
253
- Query ( InfraIdQueryParam {
254
- infra_id : _infra_id,
255
- } ) : Query < InfraIdQueryParam > ,
302
+ Extension ( auth) : AuthenticationExt ,
303
+ Path ( PacedTrainIdParam { id : paced_train_id } ) : Path < PacedTrainIdParam > ,
304
+ Query ( InfraIdQueryParam { infra_id } ) : Query < InfraIdQueryParam > ,
256
305
Query ( ElectricalProfileSetIdQueryParam {
257
- electrical_profile_set_id : _electrical_profile_set_id ,
306
+ electrical_profile_set_id,
258
307
} ) : Query < ElectricalProfileSetIdQueryParam > ,
259
308
) -> Result < Json < SimulationResponse > > {
260
- todo ! ( ) ;
309
+ let authorized = auth
310
+ . check_roles ( [ BuiltinRole :: InfraRead , BuiltinRole :: TimetableRead ] . into ( ) )
311
+ . await
312
+ . map_err ( AuthorizationError :: AuthError ) ?;
313
+ if !authorized {
314
+ return Err ( AuthorizationError :: Forbidden . into ( ) ) ;
315
+ }
316
+
317
+ // Retrieve infra or fail
318
+ let infra = Infra :: retrieve_or_fail ( & mut db_pool. get ( ) . await ?, infra_id, || {
319
+ PacedTrainError :: InfraNotFound { infra_id }
320
+ } )
321
+ . await ?;
322
+
323
+ // Retrieve paced_train or fail
324
+ let paced_train =
325
+ PacedTrain :: retrieve_or_fail ( & mut db_pool. get ( ) . await ?, paced_train_id, || {
326
+ PacedTrainError :: NotFound { paced_train_id }
327
+ } )
328
+ . await ?;
329
+
330
+ // Compute simulation of a paced_train
331
+ let ( simulation, _) = train_simulation_batch (
332
+ & mut db_pool. get ( ) . await ?,
333
+ valkey_client,
334
+ core_client,
335
+ & [ paced_train. into ( ) ] ,
336
+ & infra,
337
+ electrical_profile_set_id,
338
+ )
339
+ . await ?
340
+ . pop ( )
341
+ . unwrap ( ) ;
342
+
343
+ Ok ( Json ( simulation) )
261
344
}
262
345
263
346
/// Projects the space time curves and paths of a number of paced trains onto a given path
@@ -294,17 +377,26 @@ async fn project_path(
294
377
295
378
#[ cfg( test) ]
296
379
mod tests {
380
+ use chrono:: Duration ;
381
+ use editoast_models:: DbConnectionPoolV2 ;
382
+ use editoast_schemas:: paced_train:: { Paced , PacedTrainBase } ;
383
+ use editoast_schemas:: train_schedule:: TrainScheduleBase ;
297
384
use rstest:: rstest;
298
385
use serde_json:: json;
299
386
387
+ use crate :: models:: paced_train:: PacedTrainChangeset ;
300
388
use crate :: models:: prelude:: * ;
389
+ use crate :: views:: test_app:: TestApp ;
390
+ use crate :: views:: train_schedule:: tests:: mocked_core_pathfinding_sim_and_proj;
301
391
use crate :: {
302
392
models:: {
303
393
fixtures:: { create_simple_paced_train, create_timetable, simple_paced_train_base} ,
304
394
paced_train:: PacedTrain ,
305
395
} ,
306
396
views:: { paced_train:: PacedTrainResult , test_app:: TestAppBuilder } ,
307
397
} ;
398
+ use crate :: models:: fixtures:: create_small_infra;
399
+ use crate :: models:: fixtures:: create_fast_rolling_stock;
308
400
use axum:: http:: StatusCode ;
309
401
310
402
#[ rstest]
@@ -344,4 +436,47 @@ mod tests {
344
436
345
437
assert ! ( !exists) ;
346
438
}
439
+
440
+ async fn app_infra_id_paced_train_id_for_simulation_tests ( ) -> ( TestApp , i64 , i64 ) {
441
+ let db_pool = DbConnectionPoolV2 :: for_tests ( ) ;
442
+ let small_infra = create_small_infra ( & mut db_pool. get_ok ( ) ) . await ;
443
+ let rolling_stock =
444
+ create_fast_rolling_stock ( & mut db_pool. get_ok ( ) , "simulation_rolling_stock" ) . await ;
445
+ let paced_train_base: PacedTrainBase = PacedTrainBase {
446
+ train_schedule_base : TrainScheduleBase {
447
+ rolling_stock_name : rolling_stock. name . clone ( ) ,
448
+ ..serde_json:: from_str ( include_str ! ( "../tests/train_schedules/simple.json" ) )
449
+ . expect ( "Unable to parse" )
450
+ } ,
451
+ paced : Paced {
452
+ duration : Duration :: hours ( 1 ) . try_into ( ) . unwrap ( ) ,
453
+ step : Duration :: minutes ( 15 ) . try_into ( ) . unwrap ( ) ,
454
+ } ,
455
+ } ;
456
+ let paced_train: PacedTrainChangeset = paced_train_base. into ( ) ;
457
+ let paced_train = paced_train
458
+ . create ( & mut db_pool. get_ok ( ) )
459
+ . await
460
+ . expect ( "Failed to create paced train" ) ;
461
+ let core = mocked_core_pathfinding_sim_and_proj ( paced_train. id ) ;
462
+ let app = TestAppBuilder :: new ( )
463
+ . db_pool ( db_pool. clone ( ) )
464
+ . core_client ( core. into ( ) )
465
+ . build ( ) ;
466
+ ( app, small_infra. id , paced_train. id )
467
+ }
468
+
469
+ #[ rstest]
470
+ async fn paced_train_simulation ( ) {
471
+ let ( app, infra_id, train_schedule_id) =
472
+ app_infra_id_paced_train_id_for_simulation_tests ( ) . await ;
473
+ let request = app. get (
474
+ format ! (
475
+ "/paced_train/{}/simulation/?infra_id={}" ,
476
+ train_schedule_id, infra_id
477
+ )
478
+ . as_str ( ) ,
479
+ ) ;
480
+ app. fetch ( request) . assert_status ( StatusCode :: OK ) ;
481
+ }
347
482
}
0 commit comments