@@ -20,12 +20,26 @@ use async_trait::async_trait;
2020#[ derive( Default ) ]
2121pub struct DatafusionDb {
2222 ctx : SessionContext ,
23+ /// Context enabling datafusion's logical optimizer.
24+ with_logical_ctx : SessionContext ,
2325}
2426
2527impl DatafusionDb {
2628 pub async fn new ( ) -> Result < Self > {
29+ let ctx = DatafusionDb :: new_session_ctx ( false ) . await ?;
30+ let with_logical_ctx = DatafusionDb :: new_session_ctx ( true ) . await ?;
31+ Ok ( Self {
32+ ctx,
33+ with_logical_ctx,
34+ } )
35+ }
36+
37+ /// Creates a new session context. If the `with_logical` flag is set, datafusion's logical optimizer will be used.
38+ async fn new_session_ctx ( with_logical : bool ) -> Result < SessionContext > {
2739 let mut session_config = SessionConfig :: from_env ( ) ?. with_information_schema ( true ) ;
28- session_config. options_mut ( ) . optimizer . max_passes = 0 ;
40+ if !with_logical {
41+ session_config. options_mut ( ) . optimizer . max_passes = 0 ;
42+ }
2943
3044 let rn_config = RuntimeConfig :: new ( ) ;
3145 let runtime_env = RuntimeEnv :: new ( rn_config. clone ( ) ) ?;
@@ -36,26 +50,37 @@ impl DatafusionDb {
3650 let optimizer = DatafusionOptimizer :: new_physical ( Box :: new ( DatafusionCatalog :: new (
3751 state. catalog_list ( ) ,
3852 ) ) ) ;
39- // clean up optimizer rules so that we can plug in our own optimizer
40- state = state. with_optimizer_rules ( vec ! [ ] ) ;
41- state = state. with_physical_optimizer_rules ( vec ! [ ] ) ;
53+ if !with_logical {
54+ // clean up optimizer rules so that we can plug in our own optimizer
55+ state = state. with_optimizer_rules ( vec ! [ ] ) ;
56+ state = state. with_physical_optimizer_rules ( vec ! [ ] ) ;
57+ }
4258 // use optd-bridge query planner
4359 state = state. with_query_planner ( Arc :: new ( OptdQueryPlanner :: new ( optimizer) ) ) ;
4460 SessionContext :: new_with_state ( state)
4561 } ;
4662 ctx. refresh_catalogs ( ) . await ?;
47- Ok ( Self { ctx } )
63+ Ok ( ctx)
4864 }
4965
50- async fn execute ( & self , sql : & str ) -> Result < Vec < Vec < String > > > {
66+ async fn execute ( & self , sql : & str , with_logical : bool ) -> Result < Vec < Vec < String > > > {
5167 let sql = unescape_input ( sql) ?;
5268 let dialect = Box :: new ( GenericDialect ) ;
5369 let statements = DFParser :: parse_sql_with_dialect ( & sql, dialect. as_ref ( ) ) ?;
5470 let mut result = Vec :: new ( ) ;
5571 for statement in statements {
56- let plan = self . ctx . state ( ) . statement_to_plan ( statement) . await ?;
72+ let df = if with_logical {
73+ let plan = self
74+ . with_logical_ctx
75+ . state ( )
76+ . statement_to_plan ( statement)
77+ . await ?;
78+ self . with_logical_ctx . execute_logical_plan ( plan) . await ?
79+ } else {
80+ let plan = self . ctx . state ( ) . statement_to_plan ( statement) . await ?;
81+ self . ctx . execute_logical_plan ( plan) . await ?
82+ } ;
5783
58- let df = self . ctx . execute_logical_plan ( plan) . await ?;
5984 let batches = df. collect ( ) . await ?;
6085
6186 let options = FormatOptions :: default ( ) ;
@@ -79,53 +104,125 @@ impl DatafusionDb {
79104 }
80105 Ok ( result)
81106 }
107+
108+ /// Executes the `execute` task.
109+ async fn task_execute ( & mut self , r : & mut String , sql : & str , with_logical : bool ) -> Result < ( ) > {
110+ use std:: fmt:: Write ;
111+ let result = self . execute ( & sql, with_logical) . await ?;
112+ writeln ! ( r, "{}" , result. into_iter( ) . map( |x| x. join( " " ) ) . join( "\n " ) ) ?;
113+ writeln ! ( r) ?;
114+ Ok ( ( ) )
115+ }
116+
117+ /// Executes the `explain` task.
118+ async fn task_explain (
119+ & mut self ,
120+ r : & mut String ,
121+ sql : & str ,
122+ task : & str ,
123+ with_logical : bool ,
124+ ) -> Result < ( ) > {
125+ use std:: fmt:: Write ;
126+
127+ let result = self
128+ . execute ( & format ! ( "explain {}" , & sql) , with_logical)
129+ . await ?;
130+ let subtask_start_pos = if with_logical {
131+ "explain_with_logical:" . len ( )
132+ } else {
133+ "explain:" . len ( )
134+ } ;
135+ for subtask in task[ subtask_start_pos..] . split ( "," ) {
136+ let subtask = subtask. trim ( ) ;
137+ if subtask == "logical_datafusion" {
138+ writeln ! (
139+ r,
140+ "{}" ,
141+ result
142+ . iter( )
143+ . find( |x| x[ 0 ] == "logical_plan after datafusion" )
144+ . map( |x| & x[ 1 ] )
145+ . unwrap( )
146+ ) ?;
147+ } else if subtask == "logical_optd" {
148+ writeln ! (
149+ r,
150+ "{}" ,
151+ result
152+ . iter( )
153+ . find( |x| x[ 0 ] == "logical_plan after optd" )
154+ . map( |x| & x[ 1 ] )
155+ . unwrap( )
156+ ) ?;
157+ } else if subtask == "physical_optd" {
158+ writeln ! (
159+ r,
160+ "{}" ,
161+ result
162+ . iter( )
163+ . find( |x| x[ 0 ] == "physical_plan after optd" )
164+ . map( |x| & x[ 1 ] )
165+ . unwrap( )
166+ ) ?;
167+ } else if subtask == "join_orders" {
168+ writeln ! (
169+ r,
170+ "{}" ,
171+ result
172+ . iter( )
173+ . find( |x| x[ 0 ] == "physical_plan after optd-all-join-orders" )
174+ . map( |x| & x[ 1 ] )
175+ . unwrap( )
176+ ) ?;
177+ writeln ! ( r) ?;
178+ } else if subtask == "logical_join_orders" {
179+ writeln ! (
180+ r,
181+ "{}" ,
182+ result
183+ . iter( )
184+ . find( |x| x[ 0 ] == "physical_plan after optd-all-logical-join-orders" )
185+ . map( |x| & x[ 1 ] )
186+ . unwrap( )
187+ ) ?;
188+ writeln ! ( r) ?;
189+ } else if subtask == "physical_datafusion" {
190+ writeln ! (
191+ r,
192+ "{}" ,
193+ result
194+ . iter( )
195+ . find( |x| x[ 0 ] == "physical_plan" )
196+ . map( |x| & x[ 1 ] )
197+ . unwrap( )
198+ ) ?;
199+ }
200+ }
201+
202+ Ok ( ( ) )
203+ }
82204}
83205
84206#[ async_trait]
85207impl sqlplannertest:: PlannerTestRunner for DatafusionDb {
86208 async fn run ( & mut self , test_case : & sqlplannertest:: ParsedTestCase ) -> Result < String > {
87209 for before in & test_case. before_sql {
88- self . execute ( before)
210+ self . execute ( before, true )
89211 . await
90212 . context ( "before execution error" ) ?;
91213 }
92214
93- use std:: fmt:: Write ;
94215 let mut result = String :: new ( ) ;
95216 let r = & mut result;
96217 for task in & test_case. tasks {
97218 if task == "execute" {
98- let result = self . execute ( & test_case. sql ) . await ?;
99- writeln ! ( r , "{}" , result . into_iter ( ) . map ( |x| x . join ( " " ) ) . join ( " \n " ) ) ? ;
100- writeln ! ( r ) ?;
219+ self . task_execute ( r , & test_case. sql , false ) . await ?;
220+ } else if task == "execute_with_logical" {
221+ self . task_execute ( r , & test_case . sql , true ) . await ?;
101222 } else if task. starts_with ( "explain:" ) {
102- let result = self . execute ( & format ! ( "explain {}" , test_case. sql) ) . await ?;
103- for subtask in task[ "explain:" . len ( ) ..] . split ( "," ) {
104- let subtask = subtask. trim ( ) ;
105- if subtask == "join_orders" {
106- writeln ! (
107- r,
108- "{}" ,
109- result
110- . iter( )
111- . find( |x| x[ 0 ] == "physical_plan after optd-all-join-orders" )
112- . map( |x| & x[ 1 ] )
113- . unwrap( )
114- ) ?;
115- writeln ! ( r) ?;
116- } else if subtask == "logical_join_orders" {
117- writeln ! (
118- r,
119- "{}" ,
120- result
121- . iter( )
122- . find( |x| x[ 0 ] == "physical_plan after optd-all-logical-join-orders" )
123- . map( |x| & x[ 1 ] )
124- . unwrap( )
125- ) ?;
126- writeln ! ( r) ?;
127- }
128- }
223+ self . task_explain ( r, & test_case. sql , task, false ) . await ?;
224+ } else if task. starts_with ( "explain_with_logical:" ) {
225+ self . task_explain ( r, & test_case. sql , task, true ) . await ?;
129226 }
130227 }
131228 Ok ( result)
0 commit comments