Skip to content

Commit 865365e

Browse files
authored
Fix parameter detection related to sequence selectors for Linq provider (#2467)
Fix #2465
1 parent c8812b8 commit 865365e

File tree

7 files changed

+257
-0
lines changed

7 files changed

+257
-0
lines changed

src/NHibernate.Test/Async/Linq/ParameterTests.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,20 @@ public async Task UsingParameterInEvaluatableExpressionAsync()
142142
await (db.Users.Where(x => names.Length == 0 || names.Contains(x.Name)).ToListAsync());
143143
}
144144

145+
[Test]
146+
public async Task UsingParameterOnSelectorsAsync()
147+
{
148+
var user = new User() {Id = 1};
149+
await (db.Users.Where(o => o == user).ToListAsync());
150+
await (db.Users.FirstOrDefaultAsync(o => o == user));
151+
await (db.Timesheets.Where(o => o.Users.Any(u => u == user)).ToListAsync());
152+
153+
var users = new[] {new User() {Id = 1}};
154+
await (db.Users.Where(o => users.Contains(o)).ToListAsync());
155+
await (db.Users.FirstOrDefaultAsync(o => users.Contains(o)));
156+
await (db.Timesheets.Where(o => o.Users.Any(u => users.Contains(u))).ToListAsync());
157+
}
158+
145159
[Test]
146160
public async Task UsingNegateValueTypeParameterTwiceAsync()
147161
{
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
//------------------------------------------------------------------------------
2+
// <auto-generated>
3+
// This code was generated by AsyncGenerator.
4+
//
5+
// Changes to this file may cause incorrect behavior and will be lost if
6+
// the code is regenerated.
7+
// </auto-generated>
8+
//------------------------------------------------------------------------------
9+
10+
11+
using System.Linq;
12+
using NUnit.Framework;
13+
using NHibernate.Linq;
14+
15+
namespace NHibernate.Test.NHSpecificTest.GH2465
16+
{
17+
using System.Threading.Tasks;
18+
[TestFixture]
19+
public class FixtureAsync : BugTestCase
20+
{
21+
protected override bool AppliesTo(Dialect.Dialect dialect)
22+
{
23+
return dialect.SupportsScalarSubSelects;
24+
}
25+
26+
protected override void OnSetUp()
27+
{
28+
using (var session = OpenSession())
29+
using (var transaction = session.BeginTransaction())
30+
{
31+
var applicant = new Entity {IdentityNames = {"name1", "name2"}};
32+
session.Save(applicant);
33+
34+
transaction.Commit();
35+
}
36+
}
37+
38+
protected override void OnTearDown()
39+
{
40+
using (var session = OpenSession())
41+
using (var transaction = session.BeginTransaction())
42+
{
43+
session.Delete("from System.Object");
44+
transaction.Commit();
45+
}
46+
}
47+
48+
[Test]
49+
public async Task ContainsInsideValueCollectionAsync()
50+
{
51+
using (var session = OpenSession())
52+
using (var transaction = session.BeginTransaction())
53+
{
54+
var identityNames = new[] {"name1", "x"};
55+
await (session
56+
.Query<Entity>()
57+
.Where(a => a.IdentityNames.Any(n => identityNames.Contains(n)))
58+
.ToListAsync());
59+
await (session
60+
.Query<Entity>()
61+
.Where(a => a.IdentityNames.All(n => identityNames.Contains(n)))
62+
.ToListAsync());
63+
await (session
64+
.Query<Entity>()
65+
.Where(a => a.IdentityNames.FirstOrDefault(n => identityNames.Contains(n)) == "test")
66+
.ToListAsync());
67+
68+
await (transaction.CommitAsync());
69+
}
70+
}
71+
72+
[Test]
73+
public async Task EqualsInsideValueCollectionAsync()
74+
{
75+
using (var session = OpenSession())
76+
using (var transaction = session.BeginTransaction())
77+
{
78+
var value = "test";
79+
await (session
80+
.Query<Entity>()
81+
.Where(a => a.IdentityNames.Any(n => n == value))
82+
.ToListAsync());
83+
await (session
84+
.Query<Entity>()
85+
.Where(a => a.IdentityNames.Any(n => (string) n == value))
86+
.ToListAsync());
87+
await (session
88+
.Query<Entity>()
89+
.Where(a => a.IdentityNames.All(n => n == value))
90+
.ToListAsync());
91+
await (session
92+
.Query<Entity>()
93+
.Where(a => a.IdentityNames.FirstOrDefault(n => n == "test") == "test")
94+
.ToListAsync());
95+
96+
await (transaction.CommitAsync());
97+
}
98+
}
99+
}
100+
}

src/NHibernate.Test/Linq/ParameterTests.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,20 @@ public void UsingParameterInEvaluatableExpression()
130130
db.Users.Where(x => names.Length == 0 || names.Contains(x.Name)).ToList();
131131
}
132132

133+
[Test]
134+
public void UsingParameterOnSelectors()
135+
{
136+
var user = new User() {Id = 1};
137+
db.Users.Where(o => o == user).ToList();
138+
db.Users.FirstOrDefault(o => o == user);
139+
db.Timesheets.Where(o => o.Users.Any(u => u == user)).ToList();
140+
141+
var users = new[] {new User() {Id = 1}};
142+
db.Users.Where(o => users.Contains(o)).ToList();
143+
db.Users.FirstOrDefault(o => users.Contains(o));
144+
db.Timesheets.Where(o => o.Users.Any(u => users.Contains(u))).ToList();
145+
}
146+
133147
[Test]
134148
public void ValidateMixingTwoParametersCacheKeys()
135149
{
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using System;
2+
using System.Collections.Generic;
3+
4+
namespace NHibernate.Test.NHSpecificTest.GH2465
5+
{
6+
public class Entity
7+
{
8+
private readonly IList<string> _identityNames = new List<string>();
9+
10+
public virtual Guid Id { get; set; }
11+
12+
public virtual IList<string> IdentityNames => _identityNames;
13+
}
14+
}
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
using System.Linq;
2+
using NUnit.Framework;
3+
4+
namespace NHibernate.Test.NHSpecificTest.GH2465
5+
{
6+
[TestFixture]
7+
public class Fixture : BugTestCase
8+
{
9+
protected override bool AppliesTo(Dialect.Dialect dialect)
10+
{
11+
return dialect.SupportsScalarSubSelects;
12+
}
13+
14+
protected override void OnSetUp()
15+
{
16+
using (var session = OpenSession())
17+
using (var transaction = session.BeginTransaction())
18+
{
19+
var applicant = new Entity {IdentityNames = {"name1", "name2"}};
20+
session.Save(applicant);
21+
22+
transaction.Commit();
23+
}
24+
}
25+
26+
protected override void OnTearDown()
27+
{
28+
using (var session = OpenSession())
29+
using (var transaction = session.BeginTransaction())
30+
{
31+
session.Delete("from System.Object");
32+
transaction.Commit();
33+
}
34+
}
35+
36+
[Test]
37+
public void ContainsInsideValueCollection()
38+
{
39+
using (var session = OpenSession())
40+
using (var transaction = session.BeginTransaction())
41+
{
42+
var identityNames = new[] {"name1", "x"};
43+
session
44+
.Query<Entity>()
45+
.Where(a => a.IdentityNames.Any(n => identityNames.Contains(n)))
46+
.ToList();
47+
session
48+
.Query<Entity>()
49+
.Where(a => a.IdentityNames.All(n => identityNames.Contains(n)))
50+
.ToList();
51+
session
52+
.Query<Entity>()
53+
.Where(a => a.IdentityNames.FirstOrDefault(n => identityNames.Contains(n)) == "test")
54+
.ToList();
55+
56+
transaction.Commit();
57+
}
58+
}
59+
60+
[Test]
61+
public void EqualsInsideValueCollection()
62+
{
63+
using (var session = OpenSession())
64+
using (var transaction = session.BeginTransaction())
65+
{
66+
var value = "test";
67+
session
68+
.Query<Entity>()
69+
.Where(a => a.IdentityNames.Any(n => n == value))
70+
.ToList();
71+
session
72+
.Query<Entity>()
73+
.Where(a => a.IdentityNames.Any(n => (string) n == value))
74+
.ToList();
75+
session
76+
.Query<Entity>()
77+
.Where(a => a.IdentityNames.All(n => n == value))
78+
.ToList();
79+
session
80+
.Query<Entity>()
81+
.Where(a => a.IdentityNames.FirstOrDefault(n => n == "test") == "test")
82+
.ToList();
83+
84+
transaction.Commit();
85+
}
86+
}
87+
}
88+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
<?xml version="1.0" encoding="utf-8" ?>
2+
<hibernate-mapping xmlns="urn:nhibernate-mapping-2.2" assembly="NHibernate.Test" namespace="NHibernate.Test.NHSpecificTest.GH2465">
3+
<class name="Entity">
4+
<id name="Id" generator="guid.comb"/>
5+
<bag name="IdentityNames" access="nosetter.camelcase-underscore">
6+
<key column="EntityId"/>
7+
<element type="String">
8+
<column name="IdentityName"/>
9+
</element>
10+
</bag>
11+
</class>
12+
</hibernate-mapping>

src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Linq.Expressions;
55
using NHibernate.Engine;
66
using NHibernate.Param;
7+
using NHibernate.Persister.Collection;
78
using NHibernate.Type;
89
using NHibernate.Util;
910
using Remotion.Linq;
@@ -110,6 +111,12 @@ internal static void SetParameterTypes(
110111
out _,
111112
out _))
112113
{
114+
if (type.IsAssociationType && visitor.SequenceSelectorExpressions.Contains(memberExpression))
115+
{
116+
var collection = (IQueryableCollection) ((IAssociationType) type).GetAssociatedJoinable(sessionFactory);
117+
type = collection.ElementType;
118+
}
119+
113120
break;
114121
}
115122
}
@@ -137,6 +144,7 @@ private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor
137144
new Dictionary<ConstantExpression, IType>();
138145
public readonly Dictionary<Expression, HashSet<Expression>> RelatedExpressions =
139146
new Dictionary<Expression, HashSet<Expression>>();
147+
public readonly HashSet<Expression> SequenceSelectorExpressions = new HashSet<Expression>();
140148

141149
public ConstantTypeLocatorVisitor(
142150
bool removeMappedAsCalls,
@@ -247,6 +255,13 @@ querySourceReference.ReferencedQuerySource is MainFromClause mainFromClause &&
247255
}
248256
else
249257
{
258+
// In case a parameter is related to a sequence selector we will have to get the underlying item type
259+
// (e.g. q.Where(o => o.Users.Any(u => u == user)))
260+
if (node.QueryModel.ResultOperators.Any(o => o is ValueFromSequenceResultOperatorBase))
261+
{
262+
SequenceSelectorExpressions.Add(node.QueryModel.SelectClause.Selector);
263+
}
264+
250265
node.QueryModel.TransformExpressions(Visit);
251266
}
252267

0 commit comments

Comments
 (0)