@@ -27,6 +27,8 @@ open System.Linq.Expressions
2727open System.Runtime .InteropServices
2828open System.Reflection
2929
30+ type private CompareDiscriminatorExpression < 'T , 'D > = Expression< Func< 'T, 'D, bool>>
31+
3032/// <summary>
3133/// Allows to specify discriminator comparison or discriminator getter
3234/// and a function that return discriminator value depending on entity type
@@ -59,7 +61,7 @@ open System.Reflection
5961/// </code></example>
6062[<Struct>]
6163type ObjectListFilterLinqOptions < 'T , 'D >
62- ([< Optional>] compareDiscriminator : Expression < Func < 'T, 'D, bool > > | null , [< Optional>] getDiscriminatorValue : ( Type -> 'D) | null ) =
64+ ([< Optional>] compareDiscriminator : CompareDiscriminatorExpression < 'T, 'D> | null , [< Optional>] getDiscriminatorValue : ( Type -> 'D) | null ) =
6365
6466 member _.CompareDiscriminator = compareDiscriminator |> ValueOption.ofObj
6567 member _.GetDiscriminatorValue = getDiscriminatorValue |> ValueOption.ofObj
@@ -74,7 +76,7 @@ type ObjectListFilterLinqOptions<'T, 'D>
7476
7577 new ( getDiscriminator : Expression < Func < 'T , 'D >>) =
7678 ObjectListFilterLinqOptions< 'T, 'D> ( ObjectListFilterLinqOptions.GetCompareDiscriminator getDiscriminator, null )
77- new ( compareDiscriminator : Expression < Func < 'T , 'D , bool > >) = ObjectListFilterLinqOptions< 'T, 'D> ( compareDiscriminator, null )
79+ new ( compareDiscriminator : CompareDiscriminatorExpression < 'T , 'D >) = ObjectListFilterLinqOptions< 'T, 'D> ( compareDiscriminator, null )
7880 new ( getDiscriminatorValue : Type -> 'D ) =
7981 ObjectListFilterLinqOptions< 'T, 'D> ( compareDiscriminator = null , getDiscriminatorValue = getDiscriminatorValue)
8082 new ( getDiscriminator : Expression < Func < 'T , 'D >>, getDiscriminatorValue : Type -> 'D ) =
@@ -227,6 +229,20 @@ module ObjectListFilter =
227229 let paramExpr = Expression.PropertyOrField ( param, f.FieldName)
228230 buildFilterExpr ( SourceExpression paramExpr) buildTypeDiscriminatorCheck f.Value
229231
232+ type private CompareDiscriminatorExpressionVisitor < 'T , 'D > (
233+ compareDiscriminator : CompareDiscriminatorExpression< 'T, 'D>,
234+ param : SourceExpression,
235+ value : obj
236+ ) =
237+ inherit ExpressionVisitor ()
238+ override _.VisitParameter ( node ) =
239+ if node = compareDiscriminator.Parameters.[ 0 ] then
240+ param.Value
241+ elif node = compareDiscriminator.Parameters.[ 1 ] then
242+ Expression.Constant( value) :> Expression
243+ else
244+ node :> Expression
245+
230246 let apply ( options : ObjectListFilterLinqOptions < 'T , 'D >) ( filter : ObjectListFilter ) ( query : IQueryable < 'T >) =
231247 // Helper for discriminator comparison
232248 let buildTypeDiscriminatorCheck ( param : SourceExpression ) ( t : Type ) =
@@ -239,13 +255,9 @@ module ObjectListFilter =
239255 Expression.Constant ( t.FullName)
240256 ) :> Expression
241257 | ValueSome discExpr, ValueNone ->
242- Expression.Invoke (
243- // Provided discriminator comparison
244- discExpr,
245- param,
246- // Default discriminator value gathered from type
247- Expression.Constant( t.FullName)
248- ) :> Expression
258+ // Replace parameters from the original expression with our new ones
259+ let replacer = CompareDiscriminatorExpressionVisitor ( discExpr, param, t.FullName)
260+ replacer.Visit discExpr.Body
249261 | ValueNone, ValueSome discValueFn ->
250262 let discriminatorValue = discValueFn t
251263 Expression.Equal (
@@ -256,13 +268,9 @@ module ObjectListFilter =
256268 ) :> Expression
257269 | ValueSome discExpr, ValueSome discValueFn ->
258270 let discriminatorValue = discValueFn t
259- Expression.Invoke (
260- // Provided discriminator comparison
261- discExpr,
262- param,
263- // Provided discriminator value gathered from type
264- Expression.Constant ( discriminatorValue)
265- )
271+ // Replace parameters from the original expression with our new ones
272+ let replacer = CompareDiscriminatorExpressionVisitor ( discExpr, param, discriminatorValue)
273+ replacer.Visit discExpr.Body
266274 let queryExpr =
267275 let param = Expression.Parameter ( typeof< 'T>, " x" )
268276 let body = buildFilterExpr ( SourceExpression param) buildTypeDiscriminatorCheck filter
0 commit comments