@@ -42,22 +42,25 @@ pub struct OptimizeInputsTask {
42
42
expr_id : ExprId ,
43
43
continue_from : Option < ContinueTask > ,
44
44
pruning : bool ,
45
+ upper_bound : Option < f64 > ,
45
46
}
46
47
47
48
impl OptimizeInputsTask {
48
- pub fn new ( expr_id : ExprId , pruning : bool ) -> Self {
49
+ pub fn new ( expr_id : ExprId , pruning : bool , upper_bound : Option < f64 > ) -> Self {
49
50
Self {
50
51
expr_id,
51
52
continue_from : None ,
52
53
pruning,
54
+ upper_bound,
53
55
}
54
56
}
55
57
56
- fn continue_from ( & self , cont : ContinueTask , pruning : bool ) -> Self {
58
+ fn continue_from ( & self , cont : ContinueTask , pruning : bool , upper_bound : Option < f64 > ) -> Self {
57
59
Self {
58
60
expr_id : self . expr_id ,
59
61
continue_from : Some ( cont) ,
60
62
pruning,
63
+ upper_bound,
61
64
}
62
65
}
63
66
@@ -153,6 +156,19 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeInputsTask {
153
156
154
157
trace ! ( event = "task_begin" , task = "optimize_inputs" , expr_id = %self . expr_id, continue_from = %ContinueTaskDisplay ( & self . continue_from) , total_children = %children_group_ids. len( ) ) ;
155
158
159
+ let upper_bound = if self . pruning {
160
+ if let Some ( upper_bound) = self . upper_bound {
161
+ Some ( upper_bound)
162
+ } else if let Some ( winner) = optimizer. get_group_info ( group_id) . winner . as_full_winner ( )
163
+ {
164
+ Some ( winner. total_weighted_cost )
165
+ } else {
166
+ None
167
+ }
168
+ } else {
169
+ None
170
+ } ;
171
+
156
172
if let Some ( ContinueTask {
157
173
next_group_idx,
158
174
return_from_optimize_group,
@@ -219,9 +235,9 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeInputsTask {
219
235
winner_weighted_cost = %trace_fmt( & group_info. winner) ,
220
236
current_processing = %next_group_idx,
221
237
total_child_groups = %children_group_ids. len( ) ) ;
222
- if let Some ( winner ) = group_info . winner . as_full_winner ( ) {
238
+ if let Some ( upper_bound ) = upper_bound {
223
239
let cost_so_far = cost. weighted_cost ( & total_cost) ;
224
- if winner . total_weighted_cost <= cost_so_far {
240
+ if upper_bound <= cost_so_far {
225
241
trace ! ( event = "task_finish" , task = "optimize_inputs" , expr_id = %self . expr_id, result = "pruned" ) ;
226
242
return Ok ( vec ! [ ] ) ;
227
243
}
@@ -232,7 +248,7 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeInputsTask {
232
248
let child_group_id = children_group_ids[ next_group_idx] ;
233
249
let group_idx = next_group_idx;
234
250
let child_group_info = optimizer. get_group_info ( child_group_id) ;
235
- if ! child_group_info. winner . has_full_winner ( ) {
251
+ let Some ( child_winner ) = child_group_info. winner . as_full_winner ( ) else {
236
252
if !return_from_optimize_group {
237
253
trace ! ( event = "task_yield" , task = "optimize_inputs" , expr_id = %self . expr_id, group_idx = %group_idx, yield_to = "optimize_group" , optimize_group_id = %child_group_id) ;
238
254
return Ok ( vec ! [
@@ -242,22 +258,25 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeInputsTask {
242
258
return_from_optimize_group: true ,
243
259
} ,
244
260
self . pruning,
261
+ upper_bound,
245
262
) ) as Box <dyn Task <T , M >>,
246
- Box :: new( OptimizeGroupTask :: new( child_group_id) ) as Box <dyn Task <T , M >>,
263
+ Box :: new( OptimizeGroupTask :: new( child_group_id, upper_bound) )
264
+ as Box <dyn Task <T , M >>,
247
265
] ) ;
248
266
} else {
249
267
self . update_winner_impossible ( optimizer) ;
250
268
trace ! ( event = "task_finish" , task = "optimize_inputs" , expr_id = %self . expr_id, result = "impossible" ) ;
251
269
return Ok ( vec ! [ ] ) ;
252
270
}
253
- }
271
+ } ;
254
272
trace ! ( event = "task_yield" , task = "optimize_inputs" , expr_id = %self . expr_id, group_idx = %group_idx, yield_to = "next_optimize_input" ) ;
255
273
Ok ( vec ! [ Box :: new( self . continue_from(
256
274
ContinueTask {
257
275
next_group_idx: group_idx + 1 ,
258
276
return_from_optimize_group: false ,
259
277
} ,
260
278
self . pruning,
279
+ upper_bound. map( |bound| bound - child_winner. total_weighted_cost) ,
261
280
) ) as Box <dyn Task <T , M >>] )
262
281
} else {
263
282
self . update_winner ( input_statistics_ref, operation_cost, total_cost, optimizer) ;
@@ -272,6 +291,7 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeInputsTask {
272
291
return_from_optimize_group: false ,
273
292
} ,
274
293
self . pruning,
294
+ upper_bound,
275
295
) ) as Box <dyn Task <T , M >>] )
276
296
}
277
297
}
0 commit comments