1
1
package io.hasura.ndc.sqlgen
2
2
3
+ import io.hasura.ndc.common.ConnectorConfiguration
4
+ import io.hasura.ndc.common.NDCScalar
3
5
import io.hasura.ndc.ir.*
4
6
import org.jooq.Condition
5
7
import org.jooq.impl.DSL
8
+ import org.jooq.impl.SQLDataType
6
9
import org.jooq.Field
7
10
8
11
sealed interface BaseGenerator {
@@ -28,7 +31,8 @@ sealed interface BaseGenerator {
28
31
fun buildComparison (
29
32
col : Field <Any >,
30
33
operator : ApplyBinaryComparisonOperator ,
31
- listVal : List <Field <Any >>
34
+ listVal : List <Field <Any >>,
35
+ columnType : NDCScalar ?
32
36
): Condition {
33
37
if (operator != ApplyBinaryComparisonOperator .IN && listVal.size != 1 ) {
34
38
error(" Only the IN operator supports multiple values" )
@@ -37,20 +41,67 @@ sealed interface BaseGenerator {
37
41
// unwrap single value for use in all but the IN operator
38
42
// OR return falseCondition if listVal is empty
39
43
val singleVal = listVal.firstOrNull() ? : return DSL .falseCondition()
44
+ val castedValue = castValue(singleVal, columnType)
40
45
41
46
return when (operator ) {
42
- ApplyBinaryComparisonOperator .EQ -> col.eq(singleVal )
43
- ApplyBinaryComparisonOperator .GT -> col.gt(singleVal )
44
- ApplyBinaryComparisonOperator .GTE -> col.ge(singleVal )
45
- ApplyBinaryComparisonOperator .LT -> col.lt(singleVal )
46
- ApplyBinaryComparisonOperator .LTE -> col.le(singleVal )
47
- ApplyBinaryComparisonOperator .IN -> col.`in `(listVal)
47
+ ApplyBinaryComparisonOperator .EQ -> col.eq(castedValue )
48
+ ApplyBinaryComparisonOperator .GT -> col.gt(castedValue )
49
+ ApplyBinaryComparisonOperator .GTE -> col.ge(castedValue )
50
+ ApplyBinaryComparisonOperator .LT -> col.lt(castedValue )
51
+ ApplyBinaryComparisonOperator .LTE -> col.le(castedValue )
52
+ ApplyBinaryComparisonOperator .IN -> col.`in `(listVal.map { castValue(it, columnType) } )
48
53
ApplyBinaryComparisonOperator .IS_NULL -> col.isNull
49
54
ApplyBinaryComparisonOperator .LIKE -> col.like(singleVal as Field <String >)
50
55
ApplyBinaryComparisonOperator .CONTAINS -> col.contains(singleVal as Field <String >)
51
56
}
52
57
}
53
58
59
+ fun castValue (value : Any , scalarType : NDCScalar ? ): Any {
60
+ return when (scalarType) {
61
+ NDCScalar .TIMESTAMPTZ -> DSL .cast(value, SQLDataType .TIMESTAMPWITHTIMEZONE )
62
+ NDCScalar .TIMESTAMP -> DSL .cast(value, SQLDataType .TIMESTAMP )
63
+ NDCScalar .DATE -> DSL .cast(value, SQLDataType .DATE )
64
+ else -> value
65
+ }
66
+ }
67
+
68
+ fun getColumnType (col : ComparisonTarget , request : QueryRequest ): NDCScalar ? {
69
+ val connectorConfig = ConnectorConfiguration .Loader .config
70
+ val collection = getCollectionForCompCol(col, request)
71
+ val collectionIsTable = connectorConfig.tables.any { it.tableName == collection }
72
+ val collectionIsNativeQuery = connectorConfig.nativeQueries.containsKey(collection)
73
+ val columnType = when {
74
+ collectionIsTable -> {
75
+ val table = connectorConfig.tables.find { it.tableName == collection }
76
+ ? : error(" Table $collection not found in connector configuration" )
77
+
78
+ val column = table.columns.find { it.name == col.name }
79
+ ? : error(" Column ${col.name} not found in table $collection " )
80
+
81
+ column.type
82
+ }
83
+
84
+ collectionIsNativeQuery -> {
85
+ val nativeQuery = connectorConfig.nativeQueries[collection]
86
+ ? : error(" Native query $collection not found in connector configuration" )
87
+
88
+ val column = nativeQuery.columns[col.name]
89
+ ? : error(" Column ${col.name} not found in native query $collection " )
90
+
91
+ Type .extractBaseType(column)
92
+ }
93
+
94
+ else -> error(" Collection $collection not found in connector configuration" )
95
+ }
96
+
97
+ return when {
98
+ columnType == " DATE" -> NDCScalar .DATE
99
+ columnType.contains(" TIMESTAMP" ) && ! columnType.contains(" TIME ZONE" ) -> NDCScalar .TIMESTAMP
100
+ columnType.contains(" TIMESTAMP" ) && columnType.contains(" TIME ZONE" ) -> NDCScalar .TIMESTAMPTZ
101
+ else -> null
102
+ }
103
+ }
104
+
54
105
private fun getCollectionForCompCol (
55
106
col : ComparisonTarget ,
56
107
request : QueryRequest
@@ -124,6 +175,7 @@ sealed interface BaseGenerator {
124
175
}
125
176
126
177
is Expression .ApplyBinaryComparison -> {
178
+ val columnType = getColumnType(e.column, request)
127
179
val column = DSL .field(
128
180
DSL .name(
129
181
splitCollectionName(getCollectionForCompCol(e.column, request)) + e.column.name
@@ -143,7 +195,7 @@ sealed interface BaseGenerator {
143
195
144
196
is ComparisonValue .VariableComp -> listOf (DSL .field(DSL .name(listOf (" vars" , v.name))))
145
197
}
146
- return buildComparison(column, e.operator , comparisonValue)
198
+ return buildComparison(column, e.operator , comparisonValue, columnType )
147
199
}
148
200
149
201
is Expression .ApplyUnaryComparison -> {
0 commit comments