Skip to content

Commit 6dbe1e1

Browse files
committed
CSHARP-5563: Create AstExpression extension methods to make testing for constants easier.
1 parent 167afae commit 6dbe1e1

File tree

8 files changed

+132
-86
lines changed

8 files changed

+132
-86
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpression.cs

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,7 @@ public static AstExpression Convert(AstExpression input, AstExpression to, AstEx
260260
Ensure.IsNotNull(input, nameof(input));
261261
Ensure.IsNotNull(to, nameof(to));
262262

263-
if (to is AstConstantExpression toConstantExpression &&
264-
(toConstantExpression.Value as BsonString)?.Value is string toValue &&
265-
toValue != null &&
263+
if (to.IsStringConstant(out var toValue) &&
266264
onError == null &&
267265
onNull == null)
268266
{
@@ -365,26 +363,26 @@ public static AstExpression DerivativeOrIntegralWindowExpression(AstDerivativeOr
365363

366364
public static AstExpression Divide(AstExpression arg1, AstExpression arg2)
367365
{
368-
if (arg1 is AstConstantExpression constant1 && arg2 is AstConstantExpression constant2)
366+
if (arg1.IsConstant(out var constant1) && arg2.IsConstant(out var constant2))
369367
{
370368
return Divide(constant1, constant2);
371369
}
372370

373371
return new AstBinaryExpression(AstBinaryOperator.Divide, arg1, arg2);
374372

375-
static AstExpression Divide(AstConstantExpression constant1, AstConstantExpression constant2)
373+
static AstExpression Divide(BsonValue constant1, BsonValue constant2)
376374
{
377-
return (constant1.Value.BsonType, constant2.Value.BsonType) switch
375+
return (constant1.BsonType, constant2.BsonType) switch
378376
{
379-
(BsonType.Double, BsonType.Double) => constant1.Value.AsDouble / constant2.Value.AsDouble,
380-
(BsonType.Double, BsonType.Int32) => constant1.Value.AsDouble / constant2.Value.AsInt32,
381-
(BsonType.Double, BsonType.Int64) => constant1.Value.AsDouble / constant2.Value.AsInt64,
382-
(BsonType.Int32, BsonType.Double) => constant1.Value.AsInt32 / constant2.Value.AsDouble,
383-
(BsonType.Int32, BsonType.Int32) => (double)constant1.Value.AsInt32 / constant2.Value.AsInt32,
384-
(BsonType.Int32, BsonType.Int64) => (double)constant1.Value.AsInt32 / constant2.Value.AsInt64,
385-
(BsonType.Int64, BsonType.Double) => constant1.Value.AsInt64 / constant2.Value.AsDouble,
386-
(BsonType.Int64, BsonType.Int32) => (double)constant1.Value.AsInt64 / constant2.Value.AsInt32,
387-
(BsonType.Int64, BsonType.Int64) => (double)constant1.Value.AsInt64 / constant2.Value.AsInt64,
377+
(BsonType.Double, BsonType.Double) => constant1.AsDouble / constant2.AsDouble,
378+
(BsonType.Double, BsonType.Int32) => constant1.AsDouble / constant2.AsInt32,
379+
(BsonType.Double, BsonType.Int64) => constant1.AsDouble / constant2.AsInt64,
380+
(BsonType.Int32, BsonType.Double) => constant1.AsInt32 / constant2.AsDouble,
381+
(BsonType.Int32, BsonType.Int32) => (double)constant1.AsInt32 / constant2.AsInt32,
382+
(BsonType.Int32, BsonType.Int64) => (double)constant1.AsInt32 / constant2.AsInt64,
383+
(BsonType.Int64, BsonType.Double) => constant1.AsInt64 / constant2.AsDouble,
384+
(BsonType.Int64, BsonType.Int32) => (double)constant1.AsInt64 / constant2.AsInt32,
385+
(BsonType.Int64, BsonType.Int64) => (double)constant1.AsInt64 / constant2.AsInt64,
388386
_ => new AstBinaryExpression(AstBinaryOperator.Divide, constant1, constant2)
389387
};
390388
}
@@ -819,9 +817,9 @@ public static AstExpression StrLenBytes(AstExpression arg)
819817

820818
public static AstExpression StrLenCP(AstExpression arg)
821819
{
822-
if (arg is AstConstantExpression constantExpression && constantExpression.Value.BsonType == BsonType.String)
820+
if (arg.IsStringConstant(out var stringConstant))
823821
{
824-
var value = constantExpression.Value.AsString.Length;
822+
var value = stringConstant.Length;
825823
return new AstConstantExpression(value);
826824
}
827825
return new AstUnaryExpression(AstUnaryOperator.StrLenCP, arg);
@@ -880,9 +878,9 @@ public static AstExpression Switch(IEnumerable<(AstExpression Case, AstExpressio
880878

881879
public static AstExpression ToLower(AstExpression arg)
882880
{
883-
if (arg is AstConstantExpression constantExpression && constantExpression.Value.BsonType == BsonType.String)
881+
if (arg.IsStringConstant(out var stringConstant))
884882
{
885-
var value = constantExpression.Value.AsString.ToLowerInvariant();
883+
var value = stringConstant.ToLowerInvariant();
886884
return new AstConstantExpression(value);
887885
}
888886

@@ -896,9 +894,9 @@ public static AstExpression ToString(AstExpression arg)
896894

897895
public static AstExpression ToUpper(AstExpression arg)
898896
{
899-
if (arg is AstConstantExpression constantExpression && constantExpression.Value.BsonType == BsonType.String)
897+
if (arg.IsStringConstant(out var stringConstant))
900898
{
901-
var value = constantExpression.Value.AsString.ToUpperInvariant();
899+
var value = stringConstant.ToUpperInvariant();
902900
return new AstConstantExpression(value);
903901
}
904902

@@ -975,7 +973,7 @@ public static AstExpression Zip(IEnumerable<AstExpression> inputs, bool? useLong
975973
// private static methods
976974
private static bool AllArgsAreConstantBools(AstExpression[] args, out List<bool> values)
977975
{
978-
if (args.All(arg => arg is AstConstantExpression constantExpression && constantExpression.Value.BsonType == BsonType.Boolean))
976+
if (args.All(arg => arg.IsBooleanConstant()))
979977
{
980978
values = args.Select(arg => ((AstConstantExpression)arg).Value.AsBoolean).ToList();
981979
return true;
@@ -987,7 +985,7 @@ private static bool AllArgsAreConstantBools(AstExpression[] args, out List<bool>
987985

988986
private static bool AllArgsAreConstantInt32s(AstExpression[] args, out List<int> values)
989987
{
990-
if (args.All(arg => arg is AstConstantExpression constantExpression && constantExpression.Value.BsonType == BsonType.Int32))
988+
if (args.All(arg => arg.IsInt32Constant()))
991989
{
992990
values = args.Select(arg => ((AstConstantExpression)arg).Value.AsInt32).ToList();
993991
return true;

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpressionExtensions.cs

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,25 +19,100 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions
1919
{
2020
internal static class AstExpressionExtensions
2121
{
22+
public static bool IsBooleanConstant(this AstExpression expression)
23+
=>
24+
expression is AstConstantExpression constantExpression &&
25+
constantExpression.Value.IsBoolean;
26+
27+
public static bool IsBooleanConstant(this AstExpression expression, out bool value)
28+
{
29+
if (expression is AstConstantExpression constantExpression && constantExpression.Value is BsonBoolean bsonBoolean)
30+
{
31+
value = bsonBoolean.Value;
32+
return true;
33+
}
34+
35+
value = default;
36+
return false;
37+
}
38+
39+
public static bool IsBsonNull(this AstExpression expression)
40+
=>
41+
expression is AstConstantExpression constantExpression &&
42+
constantExpression.Value.IsBsonNull;
43+
44+
public static bool IsConstant(this AstExpression expression, out BsonValue value)
45+
{
46+
if (expression is AstConstantExpression constantExpression)
47+
{
48+
value = constantExpression.Value;
49+
return true;
50+
}
51+
52+
value = null;
53+
return false;
54+
}
55+
56+
public static bool IsConstant<TBsonValue>(this AstExpression expression, out TBsonValue value)
57+
where TBsonValue : BsonValue
58+
{
59+
if (expression is AstConstantExpression constantExpression && constantExpression.Value is TBsonValue bsonValue)
60+
{
61+
value = bsonValue;
62+
return true;
63+
}
64+
65+
value = null;
66+
return false;
67+
}
68+
69+
public static bool IsInt32Constant(this AstExpression expression)
70+
=>
71+
expression is AstConstantExpression constantExpression &&
72+
constantExpression.Value.IsInt32;
73+
74+
public static bool IsInt32Constant(this AstExpression expression, int value)
75+
=>
76+
expression is AstConstantExpression constantExpression &&
77+
constantExpression.Value is BsonInt32 bsonInt32 &&
78+
bsonInt32.Value == value;
79+
2280
public static bool IsInt32Constant(this AstExpression expression, out int value)
2381
{
24-
if (expression is AstConstantExpression constantExpression &&
25-
constantExpression.Value is BsonInt32 bsonInt32)
82+
if (expression is AstConstantExpression constantExpression && constantExpression.Value is BsonInt32 bsonInt32)
2683
{
2784
value = bsonInt32.Value;
2885
return true;
2986
}
3087

3188
value = default;
3289
return false;
33-
}
90+
}
3491

3592
public static bool IsMaxInt32(this AstExpression expression)
3693
=> expression.IsInt32Constant(out var value) && value == int.MaxValue;
3794

3895
public static bool IsRootVar(this AstExpression expression)
3996
=> expression is AstVarExpression varExpression && varExpression.Name == "ROOT" && varExpression.IsCurrent;
4097

98+
public static bool IsStringConstant(this AstExpression expression, string value)
99+
=>
100+
expression is AstConstantExpression constantExpression &&
101+
constantExpression.Value is BsonString bsonString &&
102+
bsonString.Value == value;
103+
104+
public static bool IsStringConstant(this AstExpression expression, out string value)
105+
{
106+
if (expression is AstConstantExpression constantExpression && constantExpression.Value is BsonString bsonString)
107+
{
108+
value = bsonString.Value;
109+
return true;
110+
}
111+
112+
value = default;
113+
return false;
114+
}
115+
41116
public static bool IsZero(this AstExpression expression)
42117
=> expression is AstConstantExpression constantExpression && constantExpression.Value == 0;
43118
}

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstGetFieldExpression.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,9 @@ public override string ConvertToFieldPath()
6363

6464
public bool HasSafeFieldName(out string fieldName)
6565
{
66-
if (_fieldName is AstConstantExpression constantFieldName &&
67-
constantFieldName.Value is BsonString stringfieldName)
66+
if (_fieldName.IsStringConstant(out var constantFieldName))
6867
{
69-
fieldName = stringfieldName.Value;
68+
fieldName = constantFieldName;
7069
if (fieldName.Length > 0 && fieldName[0] != '$' && !fieldName.Contains('.'))
7170
{
7271
return true;

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupingPipelineOptimizer.cs

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -352,9 +352,7 @@ public override AstNode VisitFilterField(AstFilterField node)
352352

353353
public override AstNode VisitGetFieldExpression(AstGetFieldExpression node)
354354
{
355-
if (node.FieldName is AstConstantExpression constantFieldName &&
356-
constantFieldName.Value.IsString &&
357-
constantFieldName.Value.AsString == "_elements")
355+
if (node.FieldName.IsStringConstant("_elements"))
358356
{
359357
throw new UnableToRemoveReferenceToElementsException();
360358
}
@@ -366,9 +364,7 @@ public override AstNode VisitMapExpression(AstMapExpression node)
366364
{
367365
// { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } => { __agg0 : { $push : f(x => element) } } + "$__agg0"
368366
if (node.Input is AstGetFieldExpression mapInputGetFieldExpression &&
369-
mapInputGetFieldExpression.FieldName is AstConstantExpression mapInputconstantFieldExpression &&
370-
mapInputconstantFieldExpression.Value.IsString &&
371-
mapInputconstantFieldExpression.Value.AsString == "_elements" &&
367+
mapInputGetFieldExpression.FieldName.IsStringConstant("_elements") &&
372368
mapInputGetFieldExpression.Input.IsRootVar())
373369
{
374370
var rewrittenArg = (AstExpression)AstNodeReplacer.Replace(node.In, (node.As, _element));
@@ -386,9 +382,7 @@ public override AstNode VisitPickExpression(AstPickExpression node)
386382
// => { __agg0 : { $pickAccumulatorOperator : { sortBy : s, selector : f(x => element) } } } + "$__agg0"
387383
if (node.Source is AstGetFieldExpression getFieldExpression &&
388384
getFieldExpression.Input.IsRootVar() &&
389-
getFieldExpression.FieldName is AstConstantExpression constantFieldNameExpression &&
390-
constantFieldNameExpression.Value.IsString &&
391-
constantFieldNameExpression.Value.AsString == "_elements")
385+
getFieldExpression.FieldName.IsStringConstant("_elements"))
392386
{
393387
var @operator = node.Operator.ToAccumulatorOperator();
394388
var rewrittenSelector = (AstExpression)AstNodeReplacer.Replace(node.Selector, (node.As, _element));
@@ -425,9 +419,7 @@ bool TryOptimizeSizeOfElements(out AstExpression optimizedExpression)
425419
if (node.Operator == AstUnaryOperator.Size)
426420
{
427421
if (node.Arg is AstGetFieldExpression argGetFieldExpression &&
428-
argGetFieldExpression.FieldName is AstConstantExpression constantFieldNameExpression &&
429-
constantFieldNameExpression.Value.IsString &&
430-
constantFieldNameExpression.Value.AsString == "_elements")
422+
argGetFieldExpression.FieldName.IsStringConstant("_elements"))
431423
{
432424
var accumulatorExpression = AstExpression.UnaryAccumulator(AstUnaryAccumulatorOperator.Sum, 1);
433425
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
@@ -445,9 +437,7 @@ bool TryOptimizeAccumulatorOfElements(out AstExpression optimizedExpression)
445437
// { $accumulator : { $getField : { input : "$$ROOT", field : "_elements" } } } => { __agg0 : { $accumulator : element } } + "$__agg0"
446438
if (node.Operator.IsAccumulator(out var accumulatorOperator) &&
447439
node.Arg is AstGetFieldExpression getFieldExpression &&
448-
getFieldExpression.FieldName is AstConstantExpression getFieldConstantFieldNameExpression &&
449-
getFieldConstantFieldNameExpression.Value.IsString &&
450-
getFieldConstantFieldNameExpression.Value == "_elements" &&
440+
getFieldExpression.FieldName.IsStringConstant("_elements") &&
451441
getFieldExpression.Input.IsRootVar())
452442
{
453443
var accumulatorExpression = AstExpression.UnaryAccumulator(accumulatorOperator, _element);
@@ -467,9 +457,7 @@ bool TryOptimizeAccumulatorOfMappedElements(out AstExpression optimizedExpressio
467457
if (node.Operator.IsAccumulator(out var accumulatorOperator) &&
468458
node.Arg is AstMapExpression mapExpression &&
469459
mapExpression.Input is AstGetFieldExpression mapInputGetFieldExpression &&
470-
mapInputGetFieldExpression.FieldName is AstConstantExpression mapInputconstantFieldExpression &&
471-
mapInputconstantFieldExpression.Value.IsString &&
472-
mapInputconstantFieldExpression.Value.AsString == "_elements" &&
460+
mapInputGetFieldExpression.FieldName.IsStringConstant("_elements") &&
473461
mapInputGetFieldExpression.Input.IsRootVar())
474462
{
475463
var rewrittenArg = (AstExpression)AstNodeReplacer.Replace(mapExpression.In, (mapExpression.As, _element));

0 commit comments

Comments
 (0)