@@ -2,15 +2,12 @@ package io.hasura.mysql
2
2
3
3
import io.hasura.ndc.common.ConnectorConfiguration
4
4
import io.hasura.ndc.common.NDCScalar
5
- import io.hasura.ndc.common.NativeQueryInfo
6
- import io.hasura.ndc.common.NativeQueryPart
7
5
import io.hasura.ndc.ir.*
8
6
import io.hasura.ndc.ir.Field.ColumnField
9
7
import io.hasura.ndc.ir.Field as IRField
10
8
import io.hasura.ndc.sqlgen.BaseQueryGenerator
11
9
import org.jooq.*
12
10
import org.jooq.Field
13
- import org.jooq.impl.CustomField
14
11
import org.jooq.impl.DSL
15
12
import org.jooq.impl.SQLDataType
16
13
@@ -21,7 +18,7 @@ object JsonQueryGenerator : BaseQueryGenerator() {
21
18
return DSL
22
19
.with (buildVarsCTE(request))
23
20
.select(
24
- jsonArrayAgg(
21
+ DSL . jsonArrayAgg(
25
22
buildJSONSelectionForQueryRequest(request)
26
23
)
27
24
)
@@ -34,12 +31,12 @@ object JsonQueryGenerator : BaseQueryGenerator() {
34
31
return queryRequestToSQLInternal(request)
35
32
}
36
33
37
- fun queryRequestToSQLInternal (
34
+ private fun queryRequestToSQLInternal (
38
35
request : QueryRequest ,
39
36
): SelectSelectStep <* > {
40
37
// JOOQ is smart enough to not generate CTEs if there are no native queries
41
38
return mkNativeQueryCTEs(request).select(
42
- jsonArrayAgg(
39
+ DSL . jsonArrayAgg(
43
40
buildJSONSelectionForQueryRequest(request)
44
41
)
45
42
)
@@ -53,7 +50,8 @@ object JsonQueryGenerator : BaseQueryGenerator() {
53
50
54
51
val baseSelection = DSL .select(
55
52
DSL .table(DSL .name(request.collection)).asterisk()
56
- ).from(
53
+ ).select(getSelectOrderFields(request))
54
+ .from(
57
55
if (request.query.predicate == null ) {
58
56
DSL .table(DSL .name(request.collection))
59
57
} else {
@@ -77,8 +75,12 @@ object JsonQueryGenerator : BaseQueryGenerator() {
77
75
)
78
76
)
79
77
}
78
+
80
79
}
81
80
).apply {
81
+ addJoinsRequiredForOrderByFields(this , request)
82
+ }
83
+ .apply {
82
84
if (request.query.predicate != null ) {
83
85
where(getWhereConditions(request))
84
86
}
@@ -116,7 +118,7 @@ object JsonQueryGenerator : BaseQueryGenerator() {
116
118
DSL .jsonEntry(
117
119
" rows" ,
118
120
DSL .select(
119
- jsonArrayAgg(
121
+ DSL . jsonArrayAgg(
120
122
DSL .jsonObject(
121
123
(request.query.fields ? : emptyMap()).map { (alias, field) ->
122
124
when (field) {
@@ -163,6 +165,8 @@ object JsonQueryGenerator : BaseQueryGenerator() {
163
165
}
164
166
}
165
167
)
168
+ ).orderBy(
169
+ getConcatOrderFields(request)
166
170
)
167
171
).from(
168
172
baseSelection
@@ -193,11 +197,7 @@ object JsonQueryGenerator : BaseQueryGenerator() {
193
197
)
194
198
}
195
199
196
- fun jsonArrayAgg (field : JSONObjectNullStep <* >) = CustomField .of(" mysql_json_arrayagg" , SQLDataType .JSON ) {
197
- it.visit(DSL .field(" json_arrayagg({0})" , field))
198
- }
199
-
200
- fun collectRequiredJoinTablesForWhereClause (
200
+ private fun collectRequiredJoinTablesForWhereClause (
201
201
where : Expression ,
202
202
collectionRelationships : Map <String , Relationship >,
203
203
previousTableName : String? = null
@@ -230,7 +230,7 @@ object JsonQueryGenerator : BaseQueryGenerator() {
230
230
}
231
231
}
232
232
233
- fun ndcScalarTypeToSQLDataType (scalarType : NDCScalar ): DataType <out Any > = when (scalarType) {
233
+ private fun ndcScalarTypeToSQLDataType (scalarType : NDCScalar ): DataType <out Any > = when (scalarType) {
234
234
NDCScalar .BOOLEAN -> SQLDataType .BOOLEAN
235
235
NDCScalar .INT -> SQLDataType .INTEGER
236
236
NDCScalar .FLOAT -> SQLDataType .FLOAT
@@ -292,4 +292,23 @@ object JsonQueryGenerator : BaseQueryGenerator() {
292
292
return collection.split(' .' ).last()
293
293
}
294
294
295
+ private const val ORDER_FIELD_SUFFIX = " _order_field"
296
+
297
+ private fun getSelectOrderFields (request : QueryRequest ) : List <Field <* >>{
298
+ val sortFields = translateIROrderByField(request, request.collection)
299
+ return sortFields.map { it.`$field`().`as `(it.name + ORDER_FIELD_SUFFIX ) }
300
+ }
301
+
302
+ private fun getConcatOrderFields (request : QueryRequest ) : List <SortField <* >>{
303
+ val sortFields = translateIROrderByField(request, request.collection)
304
+ return sortFields.map {
305
+ val field = DSL .field(DSL .name(it.name + ORDER_FIELD_SUFFIX ))
306
+ when (it.order) {
307
+ SortOrder .ASC -> field.asc().nullsLast()
308
+ SortOrder .DESC -> field.desc().nullsFirst()
309
+ else -> field.asc().nullsLast()
310
+ }
311
+ }
312
+ }
313
+
295
314
}
0 commit comments