1
1
import pytest
2
-
3
2
from mindsdb_sql_parser import parse_sql
4
3
from mindsdb_sql_parser .ast import *
5
4
from mindsdb_sql_parser .ast .mindsdb .evaluate import Evaluate
6
5
from mindsdb_sql_parser .lexer import MindsDBLexer
7
-
6
+ from mindsdb_sql_parser . utils import to_single_line
8
7
9
8
class TestEvaluate :
10
9
def test_evaluate_lexer (self ):
11
- sql = "EVALUATE balanced_accuracy_score FROM (SELECT true , pred FROM table_1)"
10
+ sql = "EVALUATE balanced_accuracy_score FROM (SELECT ground_truth , pred FROM table_1)"
12
11
tokens = list (MindsDBLexer ().tokenize (sql ))
13
12
assert tokens [0 ].type == 'EVALUATE'
14
13
assert tokens [1 ].type == 'ID'
15
14
assert tokens [1 ].value == 'balanced_accuracy_score'
16
15
17
16
def test_evaluate_full_1 (self ):
18
- sql = "EVALUATE balanced_accuracy_score FROM (SELECT ground_truth, pred FROM table_1) USING adjusted=1, param2=2;" # noqa
17
+ sql = "EVALUATE balanced_accuracy_score FROM (SELECT ground_truth, pred FROM table_1) USING adjusted=1, param2=2;"
19
18
ast = parse_sql (sql )
20
19
expected_ast = Evaluate (
21
20
name = Identifier ('balanced_accuracy_score' ),
22
21
query_str = "SELECT ground_truth, pred FROM table_1" ,
23
22
using = {'adjusted' : 1 , 'param2' : 2 },
24
23
)
25
- assert ' ' .join (str (ast ).split ()).lower () == sql .lower ()
24
+ assert to_single_line (str (ast )).lower () == to_single_line (sql ).lower ()
25
+ assert to_single_line (str (ast )).lower () == to_single_line (str (expected_ast )).lower ()
26
26
assert ast .to_tree () == expected_ast .to_tree ()
27
- assert str (ast ) == str (expected_ast )
28
27
29
28
def test_evaluate_full_2 (self ):
30
- query_str = """SELECT t.rental_price as ground_truth, m.rental_price as prediction FROM example_db.demo_data.home_rentals as t JOIN mindsdb.home_rentals_model as m limit 100""" # noqa
29
+ query_str = """SELECT t.rental_price as ground_truth, m.rental_price as prediction FROM example_db.demo_data.home_rentals as t JOIN mindsdb.home_rentals_model as m limit 100"""
31
30
sql = f"""EVALUATE r2_score FROM ({ query_str } );"""
32
31
ast = parse_sql (sql )
33
32
expected_ast = Evaluate (
34
33
name = Identifier ('r2_score' ),
35
34
query_str = query_str ,
36
35
)
37
- assert ' ' . join (str (ast ). split ()) .lower () == sql .lower ()
38
- assert ast . to_tree () == expected_ast . to_tree ()
39
- assert str ( ast ). lower () == str ( expected_ast ). lower ()
36
+ assert to_single_line (str (ast )) .lower () == to_single_line ( sql ) .lower ()
37
+ assert to_single_line ( str ( ast )). lower () == to_single_line ( str ( expected_ast )). lower ()
38
+ assert ast . to_tree () == expected_ast . to_tree ()
0 commit comments