Skip to content

Commit 880391f

Browse files
rbrushcopybara-github
authored andcommitted
Add support for sorting by expressions in CQL.
This change adds support for sorting by expressions in CQL. This is done by adding a new model type, SortByExpression, which represents an expression that can be used to sort the results of a query. The parser is updated to parse sort by expressions, and the interpreter is updated to evaluate them. PiperOrigin-RevId: 678323803
1 parent 5915f54 commit 880391f

File tree

11 files changed

+353
-48
lines changed

11 files changed

+353
-48
lines changed

cql_test.go

+4
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ func TestCQL(t *testing.T) {
6363
retriever: enginetests.BuildRetriever(t),
6464
wantResult: newOrFatal(t, result.List{Value: []result.Value{
6565
newOrFatal(t, result.Named{Value: enginetests.RetrieveFHIRResource(t, "Encounter", "1"), RuntimeType: &types.Named{TypeName: "FHIR.Encounter"}}),
66+
newOrFatal(t, result.Named{Value: enginetests.RetrieveFHIRResource(t, "Encounter", "2"), RuntimeType: &types.Named{TypeName: "FHIR.Encounter"}}),
6667
},
6768
StaticType: &types.List{ElementType: &types.Named{TypeName: "FHIR.Encounter"}},
6869
}),
@@ -84,6 +85,7 @@ func TestCQL(t *testing.T) {
8485
wantSourceValues: []result.Value{
8586
newOrFatal(t, result.List{Value: []result.Value{
8687
newOrFatal(t, result.Named{Value: enginetests.RetrieveFHIRResource(t, "Encounter", "1"), RuntimeType: &types.Named{TypeName: "FHIR.Encounter"}}),
88+
newOrFatal(t, result.Named{Value: enginetests.RetrieveFHIRResource(t, "Encounter", "2"), RuntimeType: &types.Named{TypeName: "FHIR.Encounter"}}),
8789
},
8890
StaticType: &types.List{ElementType: &types.Named{TypeName: "FHIR.Encounter"}},
8991
}),
@@ -302,6 +304,7 @@ func TestCQL_MultipleEvals(t *testing.T) {
302304
wantResult: newOrFatal(t, result.List{
303305
Value: []result.Value{
304306
newOrFatal(t, result.Named{Value: enginetests.RetrieveFHIRResource(t, "Encounter", "1"), RuntimeType: &types.Named{TypeName: "FHIR.Encounter"}}),
307+
newOrFatal(t, result.Named{Value: enginetests.RetrieveFHIRResource(t, "Encounter", "2"), RuntimeType: &types.Named{TypeName: "FHIR.Encounter"}}),
305308
},
306309
StaticType: &types.List{ElementType: &types.Named{TypeName: "FHIR.Encounter"}},
307310
}),
@@ -324,6 +327,7 @@ func TestCQL_MultipleEvals(t *testing.T) {
324327
newOrFatal(t, result.List{
325328
Value: []result.Value{
326329
newOrFatal(t, result.Named{Value: enginetests.RetrieveFHIRResource(t, "Encounter", "1"), RuntimeType: &types.Named{TypeName: "FHIR.Encounter"}}),
330+
newOrFatal(t, result.Named{Value: enginetests.RetrieveFHIRResource(t, "Encounter", "2"), RuntimeType: &types.Named{TypeName: "FHIR.Encounter"}}),
327331
},
328332
StaticType: &types.List{ElementType: &types.Named{TypeName: "FHIR.Encounter"}},
329333
}),

internal/reference/reference.go

+30
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ type Resolver[T any, F any] struct {
5656
// defined. Aliases live in the same namespace as definitions.
5757
aliases []map[aliasKey]T
5858

59+
// scopedStructs hold the struct that are currently in scope for evaluation. For instance,
60+
// an an expression like `[Encounter] O sort by start of period` places each encounter in scope,
61+
// for the sorting criteria, and `period` is resolved against that encounter struct.
62+
scopedStructs []T
63+
5964
// libs holds the qualified identifier of all named libraries that have been parsed.
6065
libs map[namedLibKey]struct{}
6166

@@ -405,6 +410,31 @@ func (r *Resolver[T, F]) ExitScope() {
405410
}
406411
}
407412

413+
// EnterStructScope starts a new scope for a struct.
414+
func (r *Resolver[T, F]) EnterStructScope(q T) {
415+
r.scopedStructs = append(r.scopedStructs, q)
416+
}
417+
418+
// ExitStructScope clears the current struct scope.
419+
func (r *Resolver[T, F]) ExitStructScope() {
420+
if len(r.scopedStructs) > 0 {
421+
r.scopedStructs = r.scopedStructs[:len(r.scopedStructs)-1]
422+
}
423+
}
424+
425+
// HasScopedStruct returns true if there is a struct in the current scope.
426+
func (r *Resolver[T, F]) HasScopedStruct() bool {
427+
return len(r.scopedStructs) > 0
428+
}
429+
430+
// ScopedStruct returns the current struct scope.
431+
func (r *Resolver[T, F]) ScopedStruct() (T, error) {
432+
if len(r.scopedStructs) == 0 {
433+
return zero[T](), fmt.Errorf("no scoped structs were set")
434+
}
435+
return r.scopedStructs[len(r.scopedStructs)-1], nil
436+
}
437+
408438
// Alias creates a new alias within the current scope. When EndScope is called all aliases in the
409439
// scope will be removed. Calling ResolveLocal with the same name will return the stored type t.
410440
// Names must be unique within the CQL library.

internal/reference/reference_test.go

+51
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,57 @@ func TestParserAliasAndResolve(t *testing.T) {
541541
}
542542
}
543543

544+
func TestScopedStructs(t *testing.T) {
545+
// Test scoping and de-scoping of structs in context.
546+
r := NewResolver[result.Value, *model.FunctionDef]()
547+
548+
if r.HasScopedStruct() {
549+
t.Errorf("HasScopedStruct() got true, want false")
550+
}
551+
_, err := r.ScopedStruct()
552+
if err == nil {
553+
t.Errorf("ScopedStruct() with no scope expected error but got success")
554+
}
555+
556+
v1 := newOrFatal(1, t)
557+
r.EnterStructScope(v1)
558+
if !r.HasScopedStruct() {
559+
t.Errorf("HasScopedStruct() got false when struct was in scope")
560+
}
561+
562+
got, err := r.ScopedStruct()
563+
if err != nil {
564+
t.Fatalf("ScopedStruct() unexpected err: %v", err)
565+
}
566+
if diff := cmp.Diff(v1, got); diff != "" {
567+
t.Errorf("ScopedStruct() diff (-want +got):\n%s", diff)
568+
}
569+
570+
v2 := newOrFatal(2, t)
571+
r.EnterStructScope(v2)
572+
got, err = r.ScopedStruct()
573+
if err != nil {
574+
t.Fatalf("ScopedStruct() unexpected err: %v", err)
575+
}
576+
if diff := cmp.Diff(v2, got); diff != "" {
577+
t.Errorf("ScopedStruct() diff (-want +got):\n%s", diff)
578+
}
579+
580+
r.ExitStructScope()
581+
got, err = r.ScopedStruct()
582+
if err != nil {
583+
t.Fatalf("ScopedStruct() unexpected err: %v", err)
584+
}
585+
if diff := cmp.Diff(v1, got); diff != "" {
586+
t.Errorf("ScopedStruct() diff (-want +got):\n%s", diff)
587+
}
588+
589+
r.ExitStructScope()
590+
if r.HasScopedStruct() {
591+
t.Errorf("HasScopedStruct() got true when no struct should be in scope")
592+
}
593+
}
594+
544595
func TestResolveIncludedLibrary(t *testing.T) {
545596
// TEST SETUP - PREVIOUS PARSED LIBRARY
546597
//

interpreter/expressions.go

+20
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ func (i *interpreter) evalExpression(elem model.IExpression) (result.Value, erro
5959
return i.evalQueryLetRef(elem)
6060
case *model.AliasRef:
6161
return i.evalAliasRef(elem)
62+
case *model.IdentifierRef:
63+
return i.evalIdentifierRef(elem)
6264
case *model.CodeSystemRef:
6365
return i.evalCodeSystemRef(elem)
6466
case *model.ValuesetRef:
@@ -305,6 +307,24 @@ func (i *interpreter) evalAliasRef(a *model.AliasRef) (result.Value, error) {
305307
return i.refs.ResolveLocal(a.Name)
306308
}
307309

310+
func (i *interpreter) evalIdentifierRef(r *model.IdentifierRef) (result.Value, error) {
311+
obj, err := i.refs.ScopedStruct()
312+
if err != nil {
313+
return result.Value{}, err
314+
}
315+
316+
// Passing the static types here is likely unimportant, but we compute it for completeness.
317+
aType, err := i.modelInfo.PropertyTypeSpecifier(obj.RuntimeType(), r.Name)
318+
if err != nil {
319+
return result.Value{}, err
320+
}
321+
ap, err := i.valueProperty(obj, r.Name, aType)
322+
if err != nil {
323+
return result.Value{}, err
324+
}
325+
return ap, nil
326+
}
327+
308328
func (i *interpreter) evalOperandRef(a *model.OperandRef) (result.Value, error) {
309329
return i.refs.ResolveLocal(a.Name)
310330
}

interpreter/query.go

+41-32
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ func (i *interpreter) evalQuery(q *model.Query) (result.Value, error) {
121121
return result.Value{}, err
122122
}
123123
} else {
124-
i.sortByColumn(finalVals, q.Sort.ByItems)
124+
err := i.sortByColumnOrExpression(finalVals, q.Sort.ByItems)
125125
if err != nil {
126126
return result.Value{}, err
127127
}
@@ -490,49 +490,58 @@ func compareNumeralInt[t float64 | int64 | int32](left, right t) int {
490490
}
491491
}
492492

493-
func (i *interpreter) sortByColumn(objs []result.Value, sbis []model.ISortByItem) error {
494-
// Validate sort column types.
495-
for _, sortItems := range sbis {
496-
// TODO(b/316984809): Is this validation in advance necessary? What if other values (beyond
497-
// objs[0]) have a different runtime type for the property (e.g. if they're a choice type)?
498-
// Consider validating types inline during the sort instead.
499-
path := sortItems.(*model.SortByColumn).Path
500-
propertyType, err := i.modelInfo.PropertyTypeSpecifier(objs[0].RuntimeType(), path)
493+
func (i *interpreter) dateTimeOrError(v result.Value) (result.Value, error) {
494+
switch sr := v.GolangValue().(type) {
495+
case result.DateTime:
496+
return v, nil
497+
case result.Named:
498+
if sr.RuntimeType.Equal(&types.Named{TypeName: "FHIR.dateTime"}) {
499+
return i.protoProperty(sr, "value", types.DateTime)
500+
}
501+
}
502+
return result.Value{}, fmt.Errorf("sorting only currently supported on DateTime columns")
503+
}
504+
505+
// getSortValue returns the value to be used for the comparison-based sort. This
506+
// is typically a field or expression on the structure being sorted.
507+
func (i *interpreter) getSortValue(it model.ISortByItem, v result.Value) (result.Value, error) {
508+
var rv result.Value
509+
var err error
510+
switch iv := it.(type) {
511+
case *model.SortByColumn:
512+
// Passing the static types here is likely unimportant, but we compute it for completeness.
513+
t, err := i.modelInfo.PropertyTypeSpecifier(v.RuntimeType(), iv.Path)
501514
if err != nil {
502-
return err
515+
return result.Value{}, err
503516
}
504-
columnVal, err := i.valueProperty(objs[0], path, propertyType)
517+
rv, err = i.valueProperty(v, iv.Path, t)
505518
if err != nil {
506-
return err
519+
return result.Value{}, err
507520
}
508-
// Strictly only allow DateTimes for now.
509-
// TODO(b/316984809): add sorting support for other types.
510-
if !columnVal.RuntimeType().Equal(types.DateTime) {
511-
return fmt.Errorf("sort column of a query must evaluate to a date time, instead got %v", columnVal.RuntimeType())
521+
case *model.SortByExpression:
522+
i.refs.EnterStructScope(v)
523+
defer i.refs.ExitStructScope()
524+
rv, err = i.evalExpression(iv.SortExpression)
525+
if err != nil {
526+
return result.Value{}, err
512527
}
528+
default:
529+
return result.Value{}, fmt.Errorf("internal error - unsupported sort by item type: %T", iv)
513530
}
514531

532+
return i.dateTimeOrError(rv)
533+
}
534+
535+
func (i *interpreter) sortByColumnOrExpression(objs []result.Value, sbis []model.ISortByItem) error {
515536
var sortErr error = nil
516537
slices.SortFunc(objs[:], func(a, b result.Value) int {
517-
for _, sortItems := range sbis {
518-
sortCol := sortItems.(*model.SortByColumn)
519-
// Passing the static types here is likely unimportant, but we compute it for completeness.
520-
aType, err := i.modelInfo.PropertyTypeSpecifier(a.RuntimeType(), sortCol.Path)
521-
if err != nil {
522-
sortErr = err
523-
continue
524-
}
525-
ap, err := i.valueProperty(a, sortCol.Path, aType)
526-
if err != nil {
527-
sortErr = err
528-
continue
529-
}
530-
bType, err := i.modelInfo.PropertyTypeSpecifier(b.RuntimeType(), sortCol.Path)
538+
for _, sortItem := range sbis {
539+
ap, err := i.getSortValue(sortItem, a)
531540
if err != nil {
532541
sortErr = err
533542
continue
534543
}
535-
bp, err := i.valueProperty(b, sortCol.Path, bType)
544+
bp, err := i.getSortValue(sortItem, b)
536545
if err != nil {
537546
sortErr = err
538547
continue
@@ -544,7 +553,7 @@ func (i *interpreter) sortByColumn(objs []result.Value, sbis []model.ISortByItem
544553
// TODO(b/308012659): Implement dateTime comparison that doesn't take a precision.
545554
if av.Equal(bv) {
546555
continue
547-
} else if sortCol.SortByItem.Direction == model.DESCENDING {
556+
} else if sortItem.SortDirection() == model.DESCENDING {
548557
return bv.Compare(av)
549558
}
550559
return av.Compare(bv)

model/model.go

+17-4
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ type ReturnClause struct {
483483
// Follows format outlined in https://cql.hl7.org/elm/schema/expression.xsd.
484484
type ISortByItem interface {
485485
IElement
486-
isSortByItem()
486+
SortDirection() SortDirection
487487
}
488488

489489
// SortByItem is the base abstract type for all query types.
@@ -492,20 +492,25 @@ type SortByItem struct {
492492
Direction SortDirection
493493
}
494494

495+
// SortDirection returns the direction of the sort, e.g. ASCENDING or DESCENDING.
496+
func (s *SortByItem) SortDirection() SortDirection { return s.Direction }
497+
495498
// SortByDirection enables sorting non-tuple values by direction
496499
type SortByDirection struct {
497500
*SortByItem
498501
}
499502

500-
func (c *SortByDirection) isSortByItem() {}
501-
502503
// SortByColumn enables sorting by a given column and direction.
503504
type SortByColumn struct {
504505
*SortByItem
505506
Path string
506507
}
507508

508-
func (c *SortByColumn) isSortByItem() {}
509+
// SortByExpression enables sorting by an expression and direction.
510+
type SortByExpression struct {
511+
*SortByItem
512+
SortExpression IExpression
513+
}
509514

510515
// AliasedSource is a query source with an alias.
511516
type AliasedSource struct {
@@ -1158,6 +1163,14 @@ type OperandRef struct {
11581163
Name string
11591164
}
11601165

1166+
// IdentifierRef defines a reference to an identifier within a defined scope, such as a sort by.
1167+
// This is distinct from other references since it not a defined name, but will typically reference
1168+
// a field for some structure in scope of a sort expression.
1169+
type IdentifierRef struct {
1170+
*Expression
1171+
Name string
1172+
}
1173+
11611174
// UNARY EXPRESSION GETNAME()
11621175

11631176
// GetName returns the name of the system operator.

parser/expressions.go

+21
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,27 @@ func (v *visitor) VisitQuantityContext(ctx cql.IQuantityContext) (model.Quantity
412412
// visitor.
413413
func (v *visitor) VisitReferentialIdentifier(ctx cql.IReferentialIdentifierContext) model.IExpression {
414414
name := v.parseReferentialIdentifier(ctx)
415+
416+
if v.refs.HasScopedStruct() {
417+
sourceFn, err := v.refs.ScopedStruct()
418+
if err != nil {
419+
return v.badExpression(err.Error(), ctx)
420+
}
421+
422+
// If the query source has the expected property, return the identifier ref. Otherwise
423+
// fall through to the resolution logic below.
424+
source := sourceFn()
425+
elementType := source.GetResultType().(*types.List).ElementType
426+
427+
ptype, err := v.modelInfo.PropertyTypeSpecifier(elementType, name)
428+
if err == nil {
429+
return &model.IdentifierRef{
430+
Name: name,
431+
Expression: model.ResultType(ptype),
432+
}
433+
}
434+
}
435+
415436
if i := v.refs.ResolveInclude(name); i != nil {
416437
return v.badExpression(fmt.Sprintf("internal error - referential identifier %v is a local identifier to an included library", name), ctx)
417438
}

0 commit comments

Comments
 (0)