@@ -22,54 +22,74 @@ public class JetMathTranslator : IMethodCallTranslator
22
22
23
23
private static readonly Dictionary < MethodInfo , string > _supportedMethodTranslationsDirect = new Dictionary < MethodInfo , string >
24
24
{
25
- { typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Abs ) , new [ ] { typeof ( decimal ) } ) , "ABS" } ,
26
- { typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Abs ) , new [ ] { typeof ( double ) } ) , "ABS" } ,
27
- { typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Abs ) , new [ ] { typeof ( float ) } ) , "ABS" } ,
28
- { typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Abs ) , new [ ] { typeof ( int ) } ) , "ABS" } ,
29
- { typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Abs ) , new [ ] { typeof ( long ) } ) , "ABS" } ,
30
- { typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Abs ) , new [ ] { typeof ( sbyte ) } ) , "ABS" } ,
31
- { typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Abs ) , new [ ] { typeof ( short ) } ) , "ABS" } ,
32
- { typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Pow ) , new [ ] { typeof ( double ) , typeof ( double ) } ) , "POW" } , // This is handled by JetQuerySqlGenerator
33
- { typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Exp ) , new [ ] { typeof ( double ) } ) , "EXP" } ,
34
- { typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Log ) , new [ ] { typeof ( double ) } ) , "LOG" } ,
35
- { typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Sqrt ) , new [ ] { typeof ( double ) } ) , "SQR" } ,
36
- { typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Atan ) , new [ ] { typeof ( double ) } ) , "ATN" } ,
37
- { typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Cos ) , new [ ] { typeof ( double ) } ) , "COS" } ,
38
- { typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Sin ) , new [ ] { typeof ( double ) } ) , "SIN" } ,
39
- { typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Tan ) , new [ ] { typeof ( double ) } ) , "TAN" } ,
40
- { typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Sign ) , new [ ] { typeof ( decimal ) } ) , "SGN" } ,
41
- { typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Sign ) , new [ ] { typeof ( double ) } ) , "SGN" } ,
42
- { typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Sign ) , new [ ] { typeof ( float ) } ) , "SGN" } ,
43
- { typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Sign ) , new [ ] { typeof ( int ) } ) , "SGN" } ,
44
- { typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Sign ) , new [ ] { typeof ( long ) } ) , "SGN" } ,
45
- { typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Sign ) , new [ ] { typeof ( sbyte ) } ) , "SGN" } ,
46
- { typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Sign ) , new [ ] { typeof ( short ) } ) , "SGN" }
25
+ { typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Abs ) , new [ ] { typeof ( decimal ) } ) , "ABS" } ,
26
+ { typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Abs ) , new [ ] { typeof ( double ) } ) , "ABS" } ,
27
+ { typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Abs ) , new [ ] { typeof ( float ) } ) , "ABS" } ,
28
+ { typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Abs ) , new [ ] { typeof ( int ) } ) , "ABS" } ,
29
+ { typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Abs ) , new [ ] { typeof ( long ) } ) , "ABS" } ,
30
+ { typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Abs ) , new [ ] { typeof ( sbyte ) } ) , "ABS" } ,
31
+ { typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Abs ) , new [ ] { typeof ( short ) } ) , "ABS" } ,
32
+ { typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Pow ) , new [ ] { typeof ( double ) , typeof ( double ) } ) , "POW" } , // This is handled by JetQuerySqlGenerator
33
+ { typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Exp ) , new [ ] { typeof ( double ) } ) , "EXP" } ,
34
+ { typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Log ) , new [ ] { typeof ( double ) } ) , "LOG" } ,
35
+ { typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Sqrt ) , new [ ] { typeof ( double ) } ) , "SQR" } ,
36
+ { typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Atan ) , new [ ] { typeof ( double ) } ) , "ATN" } ,
37
+ { typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Cos ) , new [ ] { typeof ( double ) } ) , "COS" } ,
38
+ { typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Sin ) , new [ ] { typeof ( double ) } ) , "SIN" } ,
39
+ { typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Tan ) , new [ ] { typeof ( double ) } ) , "TAN" } ,
40
+ { typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Sign ) , new [ ] { typeof ( decimal ) } ) , "SGN" } ,
41
+ { typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Sign ) , new [ ] { typeof ( double ) } ) , "SGN" } ,
42
+ { typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Sign ) , new [ ] { typeof ( float ) } ) , "SGN" } ,
43
+ { typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Sign ) , new [ ] { typeof ( int ) } ) , "SGN" } ,
44
+ { typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Sign ) , new [ ] { typeof ( long ) } ) , "SGN" } ,
45
+ { typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Sign ) , new [ ] { typeof ( sbyte ) } ) , "SGN" } ,
46
+ { typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Sign ) , new [ ] { typeof ( short ) } ) , "SGN" } ,
47
+ { typeof ( MathF ) . GetRequiredRuntimeMethod ( nameof ( MathF . Abs ) , typeof ( float ) ) , "ABS" } ,
48
+ { typeof ( MathF ) . GetRequiredRuntimeMethod ( nameof ( MathF . Pow ) , typeof ( float ) , typeof ( float ) ) , "POW" } ,
49
+ { typeof ( MathF ) . GetRequiredRuntimeMethod ( nameof ( MathF . Exp ) , typeof ( float ) ) , "EXP" } ,
50
+ { typeof ( MathF ) . GetRequiredRuntimeMethod ( nameof ( MathF . Log ) , typeof ( float ) ) , "LOG" } ,
51
+ { typeof ( MathF ) . GetRequiredRuntimeMethod ( nameof ( MathF . Sqrt ) , typeof ( float ) ) , "SQR" } ,
52
+ { typeof ( MathF ) . GetRequiredRuntimeMethod ( nameof ( MathF . Atan ) , typeof ( float ) ) , "ATN" } ,
53
+ { typeof ( MathF ) . GetRequiredRuntimeMethod ( nameof ( MathF . Cos ) , typeof ( float ) ) , "COS" } ,
54
+ { typeof ( MathF ) . GetRequiredRuntimeMethod ( nameof ( MathF . Sin ) , typeof ( float ) ) , "SIN" } ,
55
+ { typeof ( MathF ) . GetRequiredRuntimeMethod ( nameof ( MathF . Tan ) , typeof ( float ) ) , "TAN" } ,
56
+ { typeof ( MathF ) . GetRequiredRuntimeMethod ( nameof ( MathF . Sign ) , typeof ( float ) ) , "SGN" }
47
57
} ;
48
58
49
59
private static readonly MethodInfo [ ] _supportedMethodTranslationsIndirect = {
50
- typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Acos ) , new [ ] { typeof ( double ) } ) ,
51
- typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Asin ) , new [ ] { typeof ( double ) } ) ,
52
- typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Atan2 ) , new [ ] { typeof ( double ) , typeof ( double ) } ) ,
53
- typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Floor ) , new [ ] { typeof ( decimal ) } ) ,
54
- typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Floor ) , new [ ] { typeof ( double ) } ) ,
55
- typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Ceiling ) , new [ ] { typeof ( decimal ) } ) ,
56
- typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Ceiling ) , new [ ] { typeof ( double ) } ) ,
57
- typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Log10 ) , new [ ] { typeof ( double ) } ) ,
58
- typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Log ) , new [ ] { typeof ( double ) , typeof ( double ) } )
60
+ typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Acos ) , new [ ] { typeof ( double ) } ) ,
61
+ typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Asin ) , new [ ] { typeof ( double ) } ) ,
62
+ typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Atan2 ) , new [ ] { typeof ( double ) , typeof ( double ) } ) ,
63
+ typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Floor ) , new [ ] { typeof ( decimal ) } ) ,
64
+ typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Floor ) , new [ ] { typeof ( double ) } ) ,
65
+ typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Ceiling ) , new [ ] { typeof ( decimal ) } ) ,
66
+ typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Ceiling ) , new [ ] { typeof ( double ) } ) ,
67
+ typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Log10 ) , new [ ] { typeof ( double ) } ) ,
68
+ typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Log ) , new [ ] { typeof ( double ) , typeof ( double ) } ) ,
69
+ typeof ( MathF ) . GetRequiredRuntimeMethod ( nameof ( MathF . Ceiling ) , typeof ( float ) ) ,
70
+ typeof ( MathF ) . GetRequiredRuntimeMethod ( nameof ( MathF . Floor ) , typeof ( float ) ) ,
71
+ typeof ( MathF ) . GetRequiredRuntimeMethod ( nameof ( MathF . Log10 ) , typeof ( float ) ) ,
72
+ typeof ( MathF ) . GetRequiredRuntimeMethod ( nameof ( MathF . Log ) , typeof ( float ) , typeof ( float ) ) ,
73
+ typeof ( MathF ) . GetRequiredRuntimeMethod ( nameof ( MathF . Acos ) , typeof ( float ) ) ,
74
+ typeof ( MathF ) . GetRequiredRuntimeMethod ( nameof ( MathF . Asin ) , typeof ( float ) ) ,
75
+ typeof ( MathF ) . GetRequiredRuntimeMethod ( nameof ( MathF . Atan2 ) , typeof ( float ) , typeof ( float ) ) ,
59
76
} ;
60
77
61
78
private static readonly IEnumerable < MethodInfo > _truncateMethodInfos = new [ ]
62
79
{
63
- typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Truncate ) , new [ ] { typeof ( decimal ) } ) ,
64
- typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Truncate ) , new [ ] { typeof ( double ) } )
80
+ typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Truncate ) , new [ ] { typeof ( decimal ) } ) ,
81
+ typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Truncate ) , new [ ] { typeof ( double ) } ) ,
82
+ typeof ( MathF ) . GetRequiredRuntimeMethod ( nameof ( MathF . Truncate ) , typeof ( float ) )
65
83
} ;
66
84
67
85
private static readonly IEnumerable < MethodInfo > _roundMethodInfos = new [ ]
68
86
{
69
- typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Round ) , new [ ] { typeof ( decimal ) } ) ,
70
- typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Round ) , new [ ] { typeof ( double ) } ) ,
71
- typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Round ) , new [ ] { typeof ( decimal ) , typeof ( int ) } ) ,
72
- typeof ( Math ) . GetRuntimeMethod ( nameof ( Math . Round ) , new [ ] { typeof ( double ) , typeof ( int ) } )
87
+ typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Round ) , new [ ] { typeof ( decimal ) } ) ,
88
+ typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Round ) , new [ ] { typeof ( double ) } ) ,
89
+ typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Round ) , new [ ] { typeof ( decimal ) , typeof ( int ) } ) ,
90
+ typeof ( Math ) . GetRequiredRuntimeMethod ( nameof ( Math . Round ) , new [ ] { typeof ( double ) , typeof ( int ) } ) ,
91
+ typeof ( MathF ) . GetRequiredRuntimeMethod ( nameof ( MathF . Round ) , typeof ( float ) ) ,
92
+ typeof ( MathF ) . GetRequiredRuntimeMethod ( nameof ( MathF . Round ) , typeof ( float ) , typeof ( int ) )
73
93
} ;
74
94
75
95
public JetMathTranslator ( ISqlExpressionFactory sqlExpressionFactory )
@@ -81,12 +101,25 @@ public SqlExpression Translate(SqlExpression instance, MethodInfo method, IReadO
81
101
82
102
if ( _supportedMethodTranslationsDirect . TryGetValue ( method , out var sqlFunctionName ) )
83
103
{
104
+ var typeMapping = arguments . Count == 1
105
+ ? ExpressionExtensions . InferTypeMapping ( arguments [ 0 ] )
106
+ : ExpressionExtensions . InferTypeMapping ( arguments [ 0 ] , arguments [ 1 ] ) ;
107
+
108
+ var newArguments = new SqlExpression [ arguments . Count ] ;
109
+ newArguments [ 0 ] = _sqlExpressionFactory . ApplyTypeMapping ( arguments [ 0 ] , typeMapping ) ;
110
+
111
+ if ( arguments . Count == 2 )
112
+ {
113
+ newArguments [ 1 ] = _sqlExpressionFactory . ApplyTypeMapping ( arguments [ 1 ] , typeMapping ) ;
114
+ }
115
+
84
116
return _sqlExpressionFactory . Function (
85
117
sqlFunctionName ,
86
- arguments ,
87
- true ,
88
- new [ ] { true } ,
89
- method . ReturnType ) ;
118
+ newArguments ,
119
+ nullable : true ,
120
+ argumentsPropagateNullability : newArguments . Select ( a => true ) . ToArray ( ) ,
121
+ method . ReturnType ,
122
+ sqlFunctionName == "SIGN" ? null : typeMapping ) ;
90
123
}
91
124
92
125
if ( _supportedMethodTranslationsIndirect . Contains ( method ) )
@@ -173,24 +206,38 @@ public SqlExpression Translate(SqlExpression instance, MethodInfo method, IReadO
173
206
174
207
if ( _truncateMethodInfos . Contains ( method ) )
175
208
{
176
- return _sqlExpressionFactory . Function (
209
+ var argument = arguments [ 0 ] ;
210
+ var result = ( SqlExpression ) _sqlExpressionFactory . Function (
177
211
"INT" ,
178
- new [ ] { arguments [ 0 ] } ,
212
+ new [ ] { argument } ,
179
213
true ,
180
214
new [ ] { true } ,
181
- method . ReturnType ) ;
215
+ typeof ( double ) ) ;
216
+ if ( argument . Type == typeof ( float ) )
217
+ {
218
+ result = _sqlExpressionFactory . Convert ( result , typeof ( float ) ) ;
219
+ }
220
+ return _sqlExpressionFactory . ApplyTypeMapping ( result , argument . TypeMapping ) ;
182
221
}
183
222
184
223
if ( _roundMethodInfos . Contains ( method ) )
185
224
{
186
- return _sqlExpressionFactory . Function (
225
+ var argument = arguments [ 0 ] ;
226
+ var digits = arguments . Count == 2 ? arguments [ 1 ] : _sqlExpressionFactory . Constant ( 0 ) ;
227
+ // Result of ROUND for float/double is always double in server side
228
+ var result = ( SqlExpression ) _sqlExpressionFactory . Function (
187
229
"ROUND" ,
188
- arguments . Count == 1
189
- ? new [ ] { arguments [ 0 ] , _sqlExpressionFactory . Constant ( 0 ) }
190
- : new [ ] { arguments [ 0 ] , arguments [ 1 ] } ,
191
- true ,
192
- new [ ] { true } ,
193
- method . ReturnType ) ;
230
+ new [ ] { argument , digits } ,
231
+ nullable : true ,
232
+ argumentsPropagateNullability : new [ ] { true , true } ,
233
+ typeof ( double ) ) ;
234
+
235
+ if ( argument . Type == typeof ( float ) )
236
+ {
237
+ result = _sqlExpressionFactory . Convert ( result , typeof ( float ) ) ;
238
+ }
239
+
240
+ return _sqlExpressionFactory . ApplyTypeMapping ( result , argument . TypeMapping ) ;
194
241
}
195
242
196
243
return null ;
0 commit comments