@@ -13,44 +13,49 @@ use optd_datafusion_bridge::{DatafusionCatalog, OptdQueryPlanner};
13
13
use optd_datafusion_repr:: cost:: BaseTableStats ;
14
14
use optd_datafusion_repr:: DatafusionOptimizer ;
15
15
use regex:: Regex ;
16
+ use std:: collections:: HashSet ;
16
17
use std:: sync:: Arc ;
17
18
18
19
#[ global_allocator]
19
20
static GLOBAL : MiMalloc = MiMalloc ;
20
21
21
- use anyhow:: { Context , Result } ;
22
+ use anyhow:: { bail , Result } ;
22
23
use async_trait:: async_trait;
23
24
24
25
#[ derive( Default ) ]
25
26
pub struct DatafusionDBMS {
26
27
ctx : SessionContext ,
27
28
/// Context enabling datafusion's logical optimizer.
28
29
use_df_logical_ctx : SessionContext ,
30
+ /// Shared optd optimizer (for tweaking config)
31
+ optd_optimizer : Option < Arc < OptdQueryPlanner > > ,
29
32
}
30
33
31
34
impl DatafusionDBMS {
32
35
pub async fn new ( ) -> Result < Self > {
33
- let ctx = DatafusionDBMS :: new_session_ctx ( false , None ) . await ?;
34
- let use_df_logical_ctx =
36
+ let ( ctx, optd_optimizer ) = DatafusionDBMS :: new_session_ctx ( false , None ) . await ?;
37
+ let ( use_df_logical_ctx, _ ) =
35
38
DatafusionDBMS :: new_session_ctx ( true , Some ( ctx. state ( ) . catalog_list ( ) . clone ( ) ) ) . await ?;
36
39
Ok ( Self {
37
40
ctx,
38
41
use_df_logical_ctx,
42
+ optd_optimizer : Some ( optd_optimizer) ,
39
43
} )
40
44
}
41
45
42
46
/// Creates a new session context. If the `use_df_logical` flag is set, datafusion's logical optimizer will be used.
43
47
async fn new_session_ctx (
44
48
use_df_logical : bool ,
45
49
catalog : Option < Arc < dyn CatalogList > > ,
46
- ) -> Result < SessionContext > {
50
+ ) -> Result < ( SessionContext , Arc < OptdQueryPlanner > ) > {
47
51
let mut session_config = SessionConfig :: from_env ( ) ?. with_information_schema ( true ) ;
48
52
if !use_df_logical {
49
53
session_config. options_mut ( ) . optimizer . max_passes = 0 ;
50
54
}
51
55
52
56
let rn_config = RuntimeConfig :: new ( ) ;
53
57
let runtime_env = RuntimeEnv :: new ( rn_config. clone ( ) ) ?;
58
+ let optd_optimizer;
54
59
55
60
let ctx = {
56
61
let mut state = if let Some ( catalog) = catalog {
@@ -73,20 +78,63 @@ impl DatafusionDBMS {
73
78
}
74
79
state = state. with_physical_optimizer_rules ( vec ! [ ] ) ;
75
80
// use optd-bridge query planner
76
- state = state. with_query_planner ( Arc :: new ( OptdQueryPlanner :: new ( optimizer) ) ) ;
81
+ optd_optimizer = Arc :: new ( OptdQueryPlanner :: new ( optimizer) ) ;
82
+ state = state. with_query_planner ( optd_optimizer. clone ( ) ) ;
77
83
SessionContext :: new_with_state ( state)
78
84
} ;
79
85
ctx. refresh_catalogs ( ) . await ?;
80
- Ok ( ctx)
86
+ Ok ( ( ctx, optd_optimizer ) )
81
87
}
82
88
83
- pub async fn execute ( & self , sql : & str , use_df_logical : bool ) -> Result < Vec < Vec < String > > > {
89
+ pub ( crate ) async fn execute ( & self , sql : & str , flags : & TestFlags ) -> Result < Vec < Vec < String > > > {
90
+ {
91
+ let mut guard = self
92
+ . optd_optimizer
93
+ . as_ref ( )
94
+ . unwrap ( )
95
+ . optimizer
96
+ . lock ( )
97
+ . unwrap ( ) ;
98
+ let optimizer = guard. as_mut ( ) . unwrap ( ) . optd_optimizer_mut ( ) ;
99
+ if flags. disable_explore_limit {
100
+ optimizer. disable_explore_limit ( ) ;
101
+ } else {
102
+ optimizer. enable_explore_limit ( ) ;
103
+ }
104
+ let rules = optimizer. rules ( ) ;
105
+ if flags. enable_logical_rules . is_empty ( ) {
106
+ for r in 0 ..rules. len ( ) {
107
+ optimizer. enable_rule ( r) ;
108
+ }
109
+ } else {
110
+ for ( rule_id, rule) in rules. as_ref ( ) . iter ( ) . enumerate ( ) {
111
+ if rule. rule . is_impl_rule ( ) {
112
+ optimizer. enable_rule ( rule_id) ;
113
+ } else {
114
+ optimizer. disable_rule ( rule_id) ;
115
+ }
116
+ }
117
+ let mut rules_to_enable = flags
118
+ . enable_logical_rules
119
+ . iter ( )
120
+ . map ( |x| x. as_str ( ) )
121
+ . collect :: < HashSet < _ > > ( ) ;
122
+ for ( rule_id, rule) in rules. as_ref ( ) . iter ( ) . enumerate ( ) {
123
+ if rules_to_enable. remove ( rule. rule . name ( ) ) {
124
+ optimizer. enable_rule ( rule_id) ;
125
+ }
126
+ }
127
+ if !rules_to_enable. is_empty ( ) {
128
+ bail ! ( "Unknown logical rule: {:?}" , rules_to_enable) ;
129
+ }
130
+ }
131
+ }
84
132
let sql = unescape_input ( sql) ?;
85
133
let dialect = Box :: new ( GenericDialect ) ;
86
134
let statements = DFParser :: parse_sql_with_dialect ( & sql, dialect. as_ref ( ) ) ?;
87
135
let mut result = Vec :: new ( ) ;
88
136
for statement in statements {
89
- let df = if use_df_logical {
137
+ let df = if flags . enable_df_logical {
90
138
let plan = self
91
139
. use_df_logical_ctx
92
140
. state ( )
@@ -95,6 +143,7 @@ impl DatafusionDBMS {
95
143
self . use_df_logical_ctx . execute_logical_plan ( plan) . await ?
96
144
} else {
97
145
let plan = self . ctx . state ( ) . statement_to_plan ( statement) . await ?;
146
+
98
147
self . ctx . execute_logical_plan ( plan) . await ?
99
148
} ;
100
149
@@ -123,10 +172,12 @@ impl DatafusionDBMS {
123
172
}
124
173
125
174
/// Executes the `execute` task.
126
- async fn task_execute ( & mut self , r : & mut String , sql : & str , flags : & [ String ] ) -> Result < ( ) > {
175
+ async fn task_execute ( & mut self , r : & mut String , sql : & str , flags : & TestFlags ) -> Result < ( ) > {
127
176
use std:: fmt:: Write ;
128
- let use_df_logical = flags. contains ( & "use_df_logical" . to_string ( ) ) ;
129
- let result = self . execute ( sql, use_df_logical) . await ?;
177
+ if flags. verbose {
178
+ bail ! ( "Verbose flag is not supported for execute task" ) ;
179
+ }
180
+ let result = self . execute ( sql, flags) . await ?;
130
181
writeln ! ( r, "{}" , result. into_iter( ) . map( |x| x. join( " " ) ) . join( "\n " ) ) ?;
131
182
writeln ! ( r) ?;
132
183
Ok ( ( ) )
@@ -138,19 +189,18 @@ impl DatafusionDBMS {
138
189
r : & mut String ,
139
190
sql : & str ,
140
191
task : & str ,
141
- flags : & [ String ] ,
192
+ flags : & TestFlags ,
142
193
) -> Result < ( ) > {
143
194
use std:: fmt:: Write ;
144
195
145
- let use_df_logical = flags. contains ( & "use_df_logical" . to_string ( ) ) ;
146
- let verbose = flags. contains ( & "verbose" . to_string ( ) ) ;
196
+ let verbose = flags. verbose ;
147
197
let explain_sql = if verbose {
148
198
format ! ( "explain verbose {}" , & sql)
149
199
} else {
150
200
format ! ( "explain {}" , & sql)
151
201
} ;
152
- let result = self . execute ( & explain_sql, use_df_logical ) . await ?;
153
- let subtask_start_pos = task. find ( ':' ) . unwrap ( ) + 1 ;
202
+ let result = self . execute ( & explain_sql, flags ) . await ?;
203
+ let subtask_start_pos = task. rfind ( ':' ) . unwrap ( ) + 1 ;
154
204
for subtask in task[ subtask_start_pos..] . split ( ',' ) {
155
205
let subtask = subtask. trim ( ) ;
156
206
if subtask == "logical_datafusion" {
@@ -163,7 +213,7 @@ impl DatafusionDBMS {
163
213
. map( |x| & x[ 1 ] )
164
214
. unwrap( )
165
215
) ?;
166
- } else if subtask == "logical_optd_heuristic" {
216
+ } else if subtask == "logical_optd_heuristic" || subtask == "optimized_logical_optd" {
167
217
writeln ! (
168
218
r,
169
219
"{}" ,
@@ -225,6 +275,8 @@ impl DatafusionDBMS {
225
275
. map( |x| & x[ 1 ] )
226
276
. unwrap( )
227
277
) ?;
278
+ } else {
279
+ bail ! ( "Unknown subtask: {}" , subtask) ;
228
280
}
229
281
}
230
282
@@ -235,10 +287,8 @@ impl DatafusionDBMS {
235
287
#[ async_trait]
236
288
impl sqlplannertest:: PlannerTestRunner for DatafusionDBMS {
237
289
async fn run ( & mut self , test_case : & sqlplannertest:: ParsedTestCase ) -> Result < String > {
238
- for before in & test_case. before_sql {
239
- self . execute ( before, true )
240
- . await
241
- . context ( "before execution error" ) ?;
290
+ if !test_case. before_sql . is_empty ( ) {
291
+ panic ! ( "before is not supported in optd-sqlplannertest, always specify the task type to run" ) ;
242
292
}
243
293
244
294
let mut result = String :: new ( ) ;
@@ -259,18 +309,42 @@ lazy_static! {
259
309
static ref FLAGS_REGEX : Regex = Regex :: new( r"\[(.*)\]" ) . unwrap( ) ;
260
310
}
261
311
312
+ #[ derive( Default , Debug ) ]
313
+ struct TestFlags {
314
+ verbose : bool ,
315
+ enable_df_logical : bool ,
316
+ enable_logical_rules : Vec < String > ,
317
+ disable_explore_limit : bool ,
318
+ }
319
+
262
320
/// Extract the flags from a task. The flags are specified in square brackets.
263
321
/// For example, the flags for the task `explain[use_df_logical, verbose]` are `["use_df_logical", "verbose"]`.
264
- fn extract_flags ( task : & str ) -> Result < Vec < String > > {
322
+ fn extract_flags ( task : & str ) -> Result < TestFlags > {
265
323
if let Some ( captures) = FLAGS_REGEX . captures ( task) {
266
- Ok ( captures
324
+ let flags = captures
267
325
. get ( 1 )
268
326
. unwrap ( )
269
327
. as_str ( )
270
328
. split ( ',' )
271
329
. map ( |x| x. trim ( ) . to_string ( ) )
272
- . collect ( ) )
330
+ . collect_vec ( ) ;
331
+ let mut options = TestFlags :: default ( ) ;
332
+ for flag in flags {
333
+ if flag == "verbose" {
334
+ options. verbose = true ;
335
+ } else if flag == "use_df_logical" {
336
+ options. enable_df_logical = true ;
337
+ } else if flag. starts_with ( "logical_rules" ) {
338
+ options. enable_logical_rules =
339
+ flag. split ( '+' ) . skip ( 1 ) . map ( |x| x. to_string ( ) ) . collect ( ) ;
340
+ } else if flag == "disable_explore_limit" {
341
+ options. disable_explore_limit = true ;
342
+ } else {
343
+ bail ! ( "Unknown flag: {}" , flag) ;
344
+ }
345
+ }
346
+ Ok ( options)
273
347
} else {
274
- Ok ( vec ! [ ] )
348
+ Ok ( TestFlags :: default ( ) )
275
349
}
276
350
}
0 commit comments