11import  pandas  as  pd 
22
3- from  sampo .generator .environment  import  ContractorGenerationMethod 
3+ from  sampo .generator .environment  import  ContractorGenerationMethod ,  get_contractor_by_wg 
44from  sampo .pipeline .base  import  InputPipeline , SchedulePipeline 
55from  sampo .pipeline .delegating  import  DelegatingScheduler 
66from  sampo .pipeline .lag_optimization  import  LagOptimizationStrategy 
@@ -55,11 +55,11 @@ class DefaultInputPipeline(InputPipeline):
5555
5656    def  __init__ (self ):
5757        self ._wg : WorkGraph  |  pd .DataFrame  |  str  |  None  =  None 
58-         self ._contractors : list [Contractor ] |  pd .DataFrame  |  str  |  tuple [ContractorGenerationMethod , int ] |  None  \
59-             =  ContractorGenerationMethod .AVG , 1 
58+         self ._contractors : list [Contractor ] |  pd .DataFrame  |  str  |  tuple [ContractorGenerationMethod , int ,  int ] |  None  \
59+             =  ContractorGenerationMethod .AVG , 1 ,  1 
6060        self ._work_estimator : WorkTimeEstimator  =  DefaultWorkEstimator ()
6161        self ._node_orders : list [list [GraphNode ]] |  None  =  None 
62-         self ._lag_optimize : LagOptimizationStrategy  =  LagOptimizationStrategy .NONE 
62+         self ._lag_optimize : LagOptimizationStrategy  =  LagOptimizationStrategy .FALSE 
6363        self ._spec : ScheduleSpec  |  None  =  ScheduleSpec ()
6464        self ._assigned_parent_time : Time  |  None  =  Time (0 )
6565        self ._local_optimize_stack : ApplyQueue  =  ApplyQueue ()
@@ -98,7 +98,7 @@ def wg(self,
9898        self .sep_wg  =  sep 
9999        return  self 
100100
101-     def  contractors (self , contractors : list [Contractor ] |  pd .DataFrame  |  str  |  tuple [ContractorGenerationMethod , int ]) \
101+     def  contractors (self , contractors : list [Contractor ] |  pd .DataFrame  |  str  |  tuple [ContractorGenerationMethod , int ,  float ]) \
102102            ->  'InputPipeline' :
103103        """ 
104104        Mandatory argument. 
@@ -206,6 +206,15 @@ def schedule(self, scheduler: Scheduler, validate: bool = False) -> 'SchedulePip
206206
207207        check_and_correct_priorities (self ._wg )
208208
209+         if  not  isinstance (self ._contractors , list ):
210+             generation_method , contractors_number , scaler  =  self ._contractors 
211+             self ._contractors  =  [get_contractor_by_wg (self ._wg ,
212+                                                       method = generation_method ,
213+                                                       contractor_id = str (i ),
214+                                                       contractor_name = 'Contractor'  +  ' '  +  str (i  +  1 ),
215+                                                       scaler = scaler )
216+                            for  i  in  range (contractors_number )]
217+ 
209218        if  not  contractors_can_perform_work_graph (self ._contractors , self ._wg ):
210219            raise  NoSufficientContractorError ('Contractors are not able to perform the graph of works' )
211220
@@ -236,17 +245,6 @@ def prioritization(head_nodes: list[GraphNode],
236245            print ('Trying to apply local optimizations to non-generic scheduler, ignoring it' )
237246
238247        match  self ._lag_optimize :
239-             case  LagOptimizationStrategy .NONE :
240-                 wg  =  self ._wg 
241-                 schedules  =  scheduler .schedule_with_cache (wg , self ._contractors ,
242-                                                           self ._spec ,
243-                                                           landscape = self ._landscape_config ,
244-                                                           assigned_parent_time = self ._assigned_parent_time ,
245-                                                           validate = validate )
246-                 node_orders  =  [node_order  for  _ , _ , _ , node_order  in  schedules ]
247-                 schedules  =  [schedule  for  schedule , _ , _ , _  in  schedules ]
248-                 self ._node_orders  =  node_orders 
249- 
250248            case  LagOptimizationStrategy .AUTO :
251249                # Searching the best 
252250                wg1  =  graph_restructuring (self ._wg , False )
@@ -278,7 +276,7 @@ def prioritization(head_nodes: list[GraphNode],
278276                    wg  =  wg2 
279277                    schedules  =  schedules2 
280278
281-             case  _ :
279+             case  LagOptimizationStrategy . TRUE ,  LagOptimizationStrategy . FALSE :
282280                wg  =  graph_restructuring (self ._wg , self ._lag_optimize .value )
283281                schedules  =  scheduler .schedule_with_cache (wg , self ._contractors ,
284282                                                          self ._spec ,
@@ -288,6 +286,16 @@ def prioritization(head_nodes: list[GraphNode],
288286                node_orders  =  [node_order  for  _ , _ , _ , node_order  in  schedules ]
289287                schedules  =  [schedule  for  schedule , _ , _ , _  in  schedules ]
290288                self ._node_orders  =  node_orders 
289+             case  _:
290+                 wg  =  self ._wg 
291+                 schedules  =  scheduler .schedule_with_cache (wg , self ._contractors ,
292+                                                           self ._spec ,
293+                                                           landscape = self ._landscape_config ,
294+                                                           assigned_parent_time = self ._assigned_parent_time ,
295+                                                           validate = validate )
296+                 node_orders  =  [node_order  for  _ , _ , _ , node_order  in  schedules ]
297+                 schedules  =  [schedule  for  schedule , _ , _ , _  in  schedules ]
298+                 self ._node_orders  =  node_orders 
291299
292300        return  DefaultSchedulePipeline (self , wg , schedules )
293301
0 commit comments