Skip to content

Commit 53ff737

Browse files
authored
Support translating some string functions when they have a char parameters (#34999)
Fixes #32482
1 parent 49da213 commit 53ff737

File tree

9 files changed

+770
-51
lines changed

9 files changed

+770
-51
lines changed

src/EFCore.Cosmos/Query/Internal/Translators/CosmosStringMethodTranslator.cs

+39-18
Original file line numberDiff line numberDiff line change
@@ -13,31 +13,52 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
1313
/// </summary>
1414
public class CosmosStringMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator
1515
{
16-
private static readonly MethodInfo IndexOfMethodInfo
16+
private static readonly MethodInfo IndexOfMethodInfoString
1717
= typeof(string).GetRuntimeMethod(nameof(string.IndexOf), [typeof(string)])!;
1818

19-
private static readonly MethodInfo IndexOfMethodInfoWithStartingPosition
19+
private static readonly MethodInfo IndexOfMethodInfoChar
20+
= typeof(string).GetRuntimeMethod(nameof(string.IndexOf), [typeof(char)])!;
21+
22+
private static readonly MethodInfo IndexOfMethodInfoWithStartingPositionString
2023
= typeof(string).GetRuntimeMethod(nameof(string.IndexOf), [typeof(string), typeof(int)])!;
2124

22-
private static readonly MethodInfo ReplaceMethodInfo
25+
private static readonly MethodInfo IndexOfMethodInfoWithStartingPositionChar
26+
= typeof(string).GetRuntimeMethod(nameof(string.IndexOf), [typeof(char), typeof(int)])!;
27+
28+
private static readonly MethodInfo ReplaceMethodInfoString
2329
= typeof(string).GetRuntimeMethod(nameof(string.Replace), [typeof(string), typeof(string)])!;
2430

25-
private static readonly MethodInfo ContainsMethodInfo
31+
private static readonly MethodInfo ReplaceMethodInfoChar
32+
= typeof(string).GetRuntimeMethod(nameof(string.Replace), [typeof(char), typeof(char)])!;
33+
34+
private static readonly MethodInfo ContainsMethodInfoString
2635
= typeof(string).GetRuntimeMethod(nameof(string.Contains), [typeof(string)])!;
2736

28-
private static readonly MethodInfo ContainsWithStringComparisonMethodInfo
37+
private static readonly MethodInfo ContainsMethodInfoChar
38+
= typeof(string).GetRuntimeMethod(nameof(string.Contains), [typeof(char)])!;
39+
40+
private static readonly MethodInfo ContainsWithStringComparisonMethodInfoString
2941
= typeof(string).GetRuntimeMethod(nameof(string.Contains), [typeof(string), typeof(StringComparison)])!;
3042

31-
private static readonly MethodInfo StartsWithMethodInfo
43+
private static readonly MethodInfo ContainsWithStringComparisonMethodInfoChar
44+
= typeof(string).GetRuntimeMethod(nameof(string.Contains), [typeof(char), typeof(StringComparison)])!;
45+
46+
private static readonly MethodInfo StartsWithMethodInfoString
3247
= typeof(string).GetRuntimeMethod(nameof(string.StartsWith), [typeof(string)])!;
3348

34-
private static readonly MethodInfo StartsWithWithStringComparisonMethodInfo
49+
private static readonly MethodInfo StartsWithMethodInfoChar
50+
= typeof(string).GetRuntimeMethod(nameof(string.StartsWith), [typeof(char)])!;
51+
52+
private static readonly MethodInfo StartsWithWithStringComparisonMethodInfoString
3553
= typeof(string).GetRuntimeMethod(nameof(string.StartsWith), [typeof(string), typeof(StringComparison)])!;
3654

37-
private static readonly MethodInfo EndsWithMethodInfo
55+
private static readonly MethodInfo EndsWithMethodInfoString
3856
= typeof(string).GetRuntimeMethod(nameof(string.EndsWith), [typeof(string)])!;
3957

40-
private static readonly MethodInfo EndsWithWithStringComparisonMethodInfo
58+
private static readonly MethodInfo EndsWithMethodInfoChar
59+
= typeof(string).GetRuntimeMethod(nameof(string.EndsWith), [typeof(char)])!;
60+
61+
private static readonly MethodInfo EndsWithWithStringComparisonMethodInfoString
4162
= typeof(string).GetRuntimeMethod(nameof(string.EndsWith), [typeof(string), typeof(StringComparison)])!;
4263

4364
private static readonly MethodInfo ToLowerMethodInfo
@@ -109,27 +130,27 @@ private static readonly MethodInfo StringComparisonWithComparisonTypeArgumentSta
109130
{
110131
if (instance != null)
111132
{
112-
if (IndexOfMethodInfo.Equals(method))
133+
if (IndexOfMethodInfoString.Equals(method) || IndexOfMethodInfoChar.Equals(method))
113134
{
114135
return TranslateSystemFunction("INDEX_OF", typeof(int), instance, arguments[0]);
115136
}
116137

117-
if (IndexOfMethodInfoWithStartingPosition.Equals(method))
138+
if (IndexOfMethodInfoWithStartingPositionString.Equals(method) || IndexOfMethodInfoWithStartingPositionChar.Equals(method))
118139
{
119140
return TranslateSystemFunction("INDEX_OF", typeof(int), instance, arguments[0], arguments[1]);
120141
}
121142

122-
if (ReplaceMethodInfo.Equals(method))
143+
if (ReplaceMethodInfoString.Equals(method) || ReplaceMethodInfoChar.Equals(method))
123144
{
124145
return TranslateSystemFunction("REPLACE", method.ReturnType, instance, arguments[0], arguments[1]);
125146
}
126147

127-
if (ContainsMethodInfo.Equals(method))
148+
if (ContainsMethodInfoString.Equals(method) || ContainsMethodInfoChar.Equals(method))
128149
{
129150
return TranslateSystemFunction("CONTAINS", typeof(bool), instance, arguments[0]);
130151
}
131152

132-
if (ContainsWithStringComparisonMethodInfo.Equals(method))
153+
if (ContainsWithStringComparisonMethodInfoString.Equals(method) || ContainsWithStringComparisonMethodInfoChar.Equals(method))
133154
{
134155
if (arguments[1] is SqlConstantExpression { Value: StringComparison comparisonType })
135156
{
@@ -150,12 +171,12 @@ private static readonly MethodInfo StringComparisonWithComparisonTypeArgumentSta
150171
return null;
151172
}
152173

153-
if (StartsWithMethodInfo.Equals(method))
174+
if (StartsWithMethodInfoString.Equals(method) || StartsWithMethodInfoChar.Equals(method))
154175
{
155176
return TranslateSystemFunction("STARTSWITH", typeof(bool), instance, arguments[0]);
156177
}
157178

158-
if (StartsWithWithStringComparisonMethodInfo.Equals(method))
179+
if (StartsWithWithStringComparisonMethodInfoString.Equals(method))
159180
{
160181
if (arguments[1] is SqlConstantExpression { Value: StringComparison comparisonType })
161182
{
@@ -176,12 +197,12 @@ private static readonly MethodInfo StringComparisonWithComparisonTypeArgumentSta
176197
return null;
177198
}
178199

179-
if (EndsWithMethodInfo.Equals(method))
200+
if (EndsWithMethodInfoString.Equals(method) || EndsWithMethodInfoChar.Equals(method))
180201
{
181202
return TranslateSystemFunction("ENDSWITH", typeof(bool), instance, arguments[0]);
182203
}
183204

184-
if (EndsWithWithStringComparisonMethodInfo.Equals(method))
205+
if (EndsWithWithStringComparisonMethodInfoString.Equals(method))
185206
{
186207
if (arguments[1] is SqlConstantExpression { Value: StringComparison comparisonType })
187208
{

src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs

+57-6
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,24 @@ private static readonly HashSet<ExpressionType> ArithmeticOperatorTypes
5353
ExpressionType.Modulo
5454
];
5555

56-
private static readonly MethodInfo StringStartsWithMethodInfo
56+
private static readonly MethodInfo StringStartsWithMethodInfoString
5757
= typeof(string).GetRuntimeMethod(nameof(string.StartsWith), [typeof(string)])!;
5858

59-
private static readonly MethodInfo StringEndsWithMethodInfo
59+
private static readonly MethodInfo StringStartsWithMethodInfoChar
60+
= typeof(string).GetRuntimeMethod(nameof(string.StartsWith), [typeof(char)])!;
61+
62+
private static readonly MethodInfo StringEndsWithMethodInfoString
6063
= typeof(string).GetRuntimeMethod(nameof(string.EndsWith), [typeof(string)])!;
6164

62-
private static readonly MethodInfo StringContainsMethodInfo
65+
private static readonly MethodInfo StringEndsWithMethodInfoChar
66+
= typeof(string).GetRuntimeMethod(nameof(string.EndsWith), [typeof(char)])!;
67+
68+
private static readonly MethodInfo StringContainsMethodInfoString
6369
= typeof(string).GetRuntimeMethod(nameof(string.Contains), [typeof(string)])!;
6470

71+
private static readonly MethodInfo StringContainsMethodInfoChar
72+
= typeof(string).GetRuntimeMethod(nameof(string.Contains), [typeof(char)])!;
73+
6574
private static readonly MethodInfo StringJoinMethodInfo
6675
= typeof(string).GetRuntimeMethod(nameof(string.Join), [typeof(string), typeof(string[])])!;
6776

@@ -187,21 +196,21 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
187196
methodCallExpression.Type);
188197
}
189198

190-
if (method == StringStartsWithMethodInfo
199+
if ((method == StringStartsWithMethodInfoString || method == StringStartsWithMethodInfoChar)
191200
&& TryTranslateStartsEndsWithContains(
192201
methodCallExpression.Object!, methodCallExpression.Arguments[0], StartsEndsWithContains.StartsWith, out var translation1))
193202
{
194203
return translation1;
195204
}
196205

197-
if (method == StringEndsWithMethodInfo
206+
if ((method == StringEndsWithMethodInfoString || method == StringEndsWithMethodInfoChar)
198207
&& TryTranslateStartsEndsWithContains(
199208
methodCallExpression.Object!, methodCallExpression.Arguments[0], StartsEndsWithContains.EndsWith, out var translation2))
200209
{
201210
return translation2;
202211
}
203212

204-
if (method == StringContainsMethodInfo
213+
if ((method == StringContainsMethodInfoString || method == StringContainsMethodInfoChar)
205214
&& TryTranslateStartsEndsWithContains(
206215
methodCallExpression.Object!, methodCallExpression.Arguments[0], StartsEndsWithContains.Contains, out var translation3))
207216
{
@@ -328,6 +337,32 @@ string when _sqlServerSingletonOptions.EngineType is SqlServerEngineType.AzureSy
328337
}),
329338
_sqlExpressionFactory.Constant(LikeEscapeString)),
330339

340+
char s when !IsLikeWildChar(s)
341+
=> _sqlExpressionFactory.Like(
342+
translatedInstance,
343+
_sqlExpressionFactory.Constant(
344+
methodType switch
345+
{
346+
StartsEndsWithContains.StartsWith => s + "%",
347+
StartsEndsWithContains.EndsWith => "%" + s,
348+
StartsEndsWithContains.Contains => $"%{s}%",
349+
350+
_ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null)
351+
})),
352+
353+
char s => _sqlExpressionFactory.Like(
354+
translatedInstance,
355+
_sqlExpressionFactory.Constant(
356+
methodType switch
357+
{
358+
StartsEndsWithContains.StartsWith => LikeEscapeChar + s + "%",
359+
StartsEndsWithContains.EndsWith => "%" + LikeEscapeChar + s,
360+
StartsEndsWithContains.Contains => $"%{LikeEscapeChar}{s}%",
361+
362+
_ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null)
363+
}),
364+
_sqlExpressionFactory.Constant(LikeEscapeString)),
365+
331366
_ => throw new UnreachableException()
332367
};
333368

@@ -463,6 +498,22 @@ SqlExpression CharIndexGreaterThanZero()
463498
_ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null)
464499
},
465500

501+
char s when !IsLikeWildChar(s) => methodType switch
502+
{
503+
StartsEndsWithContains.StartsWith => s + "%",
504+
StartsEndsWithContains.EndsWith => "%" + s,
505+
StartsEndsWithContains.Contains => $"%{s}%",
506+
_ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null)
507+
},
508+
509+
char s => methodType switch
510+
{
511+
StartsEndsWithContains.StartsWith => LikeEscapeChar + s + "%",
512+
StartsEndsWithContains.EndsWith => "%" + LikeEscapeChar + s,
513+
StartsEndsWithContains.Contains => $"%{LikeEscapeChar}{s}%",
514+
_ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null)
515+
},
516+
466517
_ => throw new UnreachableException()
467518
};
468519

src/EFCore.SqlServer/Query/Internal/Translators/SqlServerStringMethodTranslator.cs

+20-9
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
55
using Microsoft.EntityFrameworkCore.SqlServer.Infrastructure.Internal;
6+
using CharTypeMapping = Microsoft.EntityFrameworkCore.Storage.CharTypeMapping;
67
using ExpressionExtensions = Microsoft.EntityFrameworkCore.Query.ExpressionExtensions;
78

89
// ReSharper disable once CheckNamespace
@@ -16,15 +17,24 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal;
1617
/// </summary>
1718
public class SqlServerStringMethodTranslator : IMethodCallTranslator
1819
{
19-
private static readonly MethodInfo IndexOfMethodInfo
20+
private static readonly MethodInfo IndexOfMethodInfoString
2021
= typeof(string).GetRuntimeMethod(nameof(string.IndexOf), [typeof(string)])!;
2122

22-
private static readonly MethodInfo IndexOfMethodInfoWithStartingPosition
23+
private static readonly MethodInfo IndexOfMethodInfoChar
24+
= typeof(string).GetRuntimeMethod(nameof(string.IndexOf), [typeof(char)])!;
25+
26+
private static readonly MethodInfo IndexOfMethodInfoWithStartingPositionString
2327
= typeof(string).GetRuntimeMethod(nameof(string.IndexOf), [typeof(string), typeof(int)])!;
2428

25-
private static readonly MethodInfo ReplaceMethodInfo
29+
private static readonly MethodInfo IndexOfMethodInfoWithStartingPositionChar
30+
= typeof(string).GetRuntimeMethod(nameof(string.IndexOf), [typeof(char), typeof(int)])!;
31+
32+
private static readonly MethodInfo ReplaceMethodInfoString
2633
= typeof(string).GetRuntimeMethod(nameof(string.Replace), [typeof(string), typeof(string)])!;
2734

35+
private static readonly MethodInfo ReplaceMethodInfoChar
36+
= typeof(string).GetRuntimeMethod(nameof(string.Replace), [typeof(char), typeof(char)])!;
37+
2838
private static readonly MethodInfo ToLowerMethodInfo
2939
= typeof(string).GetRuntimeMethod(nameof(string.ToLower), Type.EmptyTypes)!;
3040

@@ -115,25 +125,25 @@ public SqlServerStringMethodTranslator(ISqlExpressionFactory sqlExpressionFactor
115125
{
116126
if (instance != null)
117127
{
118-
if (IndexOfMethodInfo.Equals(method))
128+
if (IndexOfMethodInfoString.Equals(method) || IndexOfMethodInfoChar.Equals(method))
119129
{
120130
return TranslateIndexOf(instance, method, arguments[0], null);
121131
}
122132

123-
if (IndexOfMethodInfoWithStartingPosition.Equals(method))
133+
if (IndexOfMethodInfoWithStartingPositionString.Equals(method) || IndexOfMethodInfoWithStartingPositionChar.Equals(method))
124134
{
125135
return TranslateIndexOf(instance, method, arguments[0], arguments[1]);
126136
}
127137

128-
if (ReplaceMethodInfo.Equals(method))
138+
if (ReplaceMethodInfoString.Equals(method) || ReplaceMethodInfoChar.Equals(method))
129139
{
130140
var firstArgument = arguments[0];
131141
var secondArgument = arguments[1];
132142
var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance, firstArgument, secondArgument);
133143

134144
instance = _sqlExpressionFactory.ApplyTypeMapping(instance, stringTypeMapping);
135-
firstArgument = _sqlExpressionFactory.ApplyTypeMapping(firstArgument, stringTypeMapping);
136-
secondArgument = _sqlExpressionFactory.ApplyTypeMapping(secondArgument, stringTypeMapping);
145+
firstArgument = _sqlExpressionFactory.ApplyTypeMapping(firstArgument, firstArgument.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping);
146+
secondArgument = _sqlExpressionFactory.ApplyTypeMapping(secondArgument, secondArgument.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping);
137147

138148
return _sqlExpressionFactory.Function(
139149
"REPLACE",
@@ -323,7 +333,8 @@ private SqlExpression TranslateIndexOf(
323333
SqlExpression? startIndex)
324334
{
325335
var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance, searchExpression)!;
326-
searchExpression = _sqlExpressionFactory.ApplyTypeMapping(searchExpression, stringTypeMapping);
336+
searchExpression = _sqlExpressionFactory.ApplyTypeMapping(searchExpression, searchExpression.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping);
337+
327338
instance = _sqlExpressionFactory.ApplyTypeMapping(instance, stringTypeMapping);
328339

329340
var charIndexArguments = new List<SqlExpression> { searchExpression, instance };

src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs

+23-4
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,18 @@ public class SqliteSqlTranslatingExpressionVisitor : RelationalSqlTranslatingExp
1919
private readonly QueryCompilationContext _queryCompilationContext;
2020
private readonly ISqlExpressionFactory _sqlExpressionFactory;
2121

22-
private static readonly MethodInfo StringStartsWithMethodInfo
22+
private static readonly MethodInfo StringStartsWithMethodInfoString
2323
= typeof(string).GetRuntimeMethod(nameof(string.StartsWith), [typeof(string)])!;
2424

25-
private static readonly MethodInfo StringEndsWithMethodInfo
25+
private static readonly MethodInfo StringStartsWithMethodInfoChar
26+
= typeof(string).GetRuntimeMethod(nameof(string.StartsWith), [typeof(char)])!;
27+
28+
private static readonly MethodInfo StringEndsWithMethodInfoString
2629
= typeof(string).GetRuntimeMethod(nameof(string.EndsWith), [typeof(string)])!;
2730

31+
private static readonly MethodInfo StringEndsWithMethodInfoChar
32+
= typeof(string).GetRuntimeMethod(nameof(string.EndsWith), [typeof(char)])!;
33+
2834
private static readonly MethodInfo EscapeLikePatternParameterMethod =
2935
typeof(SqliteSqlTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(ConstructLikePatternParameter))!;
3036

@@ -255,14 +261,14 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
255261
{
256262
var method = methodCallExpression.Method;
257263

258-
if (method == StringStartsWithMethodInfo
264+
if ((method == StringStartsWithMethodInfoString || method == StringStartsWithMethodInfoChar)
259265
&& TryTranslateStartsEndsWith(
260266
methodCallExpression.Object!, methodCallExpression.Arguments[0], startsWith: true, out var translation1))
261267
{
262268
return translation1;
263269
}
264270

265-
if (method == StringEndsWithMethodInfo
271+
if ((method == StringEndsWithMethodInfoString || method == StringEndsWithMethodInfoChar)
266272
&& TryTranslateStartsEndsWith(
267273
methodCallExpression.Object!, methodCallExpression.Arguments[0], startsWith: false, out var translation2))
268274
{
@@ -316,6 +322,15 @@ bool TryTranslateStartsEndsWith(
316322
translatedInstance,
317323
_sqlExpressionFactory.Constant(startsWith ? s + '%' : '%' + s)),
318324

325+
char s => IsLikeWildChar(s)
326+
? _sqlExpressionFactory.Like(
327+
translatedInstance,
328+
_sqlExpressionFactory.Constant(startsWith ? LikeEscapeString + s + "%" : '%' + LikeEscapeString + s),
329+
_sqlExpressionFactory.Constant(LikeEscapeString))
330+
: _sqlExpressionFactory.Like(
331+
translatedInstance,
332+
_sqlExpressionFactory.Constant(startsWith ? s + "%" : "%" + s)),
333+
319334
_ => throw new UnreachableException()
320335
};
321336

@@ -443,6 +458,10 @@ bool TryTranslateStartsEndsWith(
443458

444459
string s => startsWith ? EscapeLikePattern(s) + '%' : '%' + EscapeLikePattern(s),
445460

461+
char s when IsLikeWildChar(s )=> startsWith ? LikeEscapeString + s + '%' : '%' + LikeEscapeString + s,
462+
463+
char s => startsWith ? s + "%" : "%" + s,
464+
446465
_ => throw new UnreachableException()
447466
};
448467

0 commit comments

Comments
 (0)