Skip to content

Commit 8ee0305

Browse files
authored
Tweaks to ValueComparer nullability (#24410)
* Make Snapshot accept/receive non-nullable (nulls are sanitized externally). * Make ValueComparer<T>.GetHashCode accept non-nullable object.
1 parent 6412d18 commit 8ee0305

15 files changed

+48
-79
lines changed

src/EFCore.Cosmos/ChangeTracking/Internal/ListComparer.cs

+2-7
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,8 @@ private static int GetHashCode(TCollection source, ValueComparer<TElement> eleme
7777
return hash.ToHashCode();
7878
}
7979

80-
private static TCollection? Snapshot(TCollection? source, ValueComparer<TElement> elementComparer, bool readOnly)
80+
private static TCollection Snapshot(TCollection source, ValueComparer<TElement> elementComparer, bool readOnly)
8181
{
82-
if (source is null)
83-
{
84-
return null;
85-
}
86-
8782
if (readOnly)
8883
{
8984
return source;
@@ -92,7 +87,7 @@ private static int GetHashCode(TCollection source, ValueComparer<TElement> eleme
9287
var snapshot = new List<TElement>(((IReadOnlyList<TElement>)source).Count);
9388
foreach (var e in source)
9489
{
95-
snapshot.Add(elementComparer.Snapshot(e)!);
90+
snapshot.Add(e is null ? default! : elementComparer.Snapshot(e));
9691
}
9792

9893
return (TCollection)(object)snapshot;

src/EFCore.Cosmos/ChangeTracking/Internal/NullableListComparer.cs

+1-6
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,8 @@ private static int GetHashCode(TCollection source, ValueComparer<TElement> eleme
8989
return hash.ToHashCode();
9090
}
9191

92-
private static TCollection? Snapshot(TCollection? source, ValueComparer<TElement> elementComparer, bool readOnly)
92+
private static TCollection Snapshot(TCollection source, ValueComparer<TElement> elementComparer, bool readOnly)
9393
{
94-
if (source is null)
95-
{
96-
return null;
97-
}
98-
9994
if (readOnly)
10095
{
10196
return source;

src/EFCore.Cosmos/ChangeTracking/Internal/NullableSingleDimensionalArrayComparer.cs

+1-6
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,8 @@ private static int GetHashCode(TElement?[] source, ValueComparer<TElement> eleme
8888
}
8989

9090
[return: NotNullIfNotNull("source")]
91-
private static TElement?[]? Snapshot(TElement?[]? source, ValueComparer<TElement> elementComparer)
91+
private static TElement?[] Snapshot(TElement?[] source, ValueComparer<TElement> elementComparer)
9292
{
93-
if (source is null)
94-
{
95-
return null;
96-
}
97-
9893
var snapshot = new TElement?[source.Length];
9994
for (var i = 0; i < source.Length; i++)
10095
{

src/EFCore.Cosmos/ChangeTracking/Internal/NullableStringDictionaryComparer.cs

+1-6
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,8 @@ private static int GetHashCode(TCollection source, ValueComparer<TElement> eleme
9595
return hash.ToHashCode();
9696
}
9797

98-
private static TCollection? Snapshot(TCollection? source, ValueComparer<TElement> elementComparer, bool readOnly)
98+
private static TCollection Snapshot(TCollection source, ValueComparer<TElement> elementComparer, bool readOnly)
9999
{
100-
if (source is null)
101-
{
102-
return null;
103-
}
104-
105100
if (readOnly)
106101
{
107102
return source;

src/EFCore.Cosmos/ChangeTracking/Internal/SingleDimensionalArrayComparer.cs

+3-7
Original file line numberDiff line numberDiff line change
@@ -76,17 +76,13 @@ private static int GetHashCode(TElement[] source, ValueComparer<TElement> elemen
7676
}
7777

7878
[return: NotNullIfNotNull("source")]
79-
private static TElement[]? Snapshot(TElement[]? source, ValueComparer<TElement> elementComparer)
79+
private static TElement[] Snapshot(TElement[] source, ValueComparer<TElement> elementComparer)
8080
{
81-
if (source is null)
82-
{
83-
return null;
84-
}
85-
8681
var snapshot = new TElement[source.Length];
8782
for (var i = 0; i < source.Length; i++)
8883
{
89-
snapshot[i] = elementComparer.Snapshot(source[i])!;
84+
var element = source[i];
85+
snapshot[i] = element is null ? default! : elementComparer.Snapshot(source[i]);
9086
}
9187
return snapshot;
9288
}

src/EFCore.Cosmos/ChangeTracking/Internal/StringDictionaryComparer.cs

+2-7
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,8 @@ private static int GetHashCode(TCollection source, ValueComparer<TElement> eleme
7979
return hash.ToHashCode();
8080
}
8181

82-
private static TCollection? Snapshot(TCollection? source, ValueComparer<TElement> elementComparer, bool readOnly)
82+
private static TCollection Snapshot(TCollection source, ValueComparer<TElement> elementComparer, bool readOnly)
8383
{
84-
if (source is null)
85-
{
86-
return null;
87-
}
88-
8984
if (readOnly)
9085
{
9186
return source;
@@ -94,7 +89,7 @@ private static int GetHashCode(TCollection source, ValueComparer<TElement> eleme
9489
var snapshot = new Dictionary<string, TElement>(((IReadOnlyDictionary<string, TElement>)source).Count);
9590
foreach (var e in source)
9691
{
97-
snapshot.Add(e.Key, elementComparer.Snapshot(e.Value)!);
92+
snapshot.Add(e.Key, e.Value is null ? default! : elementComparer.Snapshot(e.Value));
9893
}
9994

10095
return (TCollection)(object)snapshot;

src/EFCore.SqlServer/Storage/Internal/SqlServerTypeMappingSource.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ private readonly SqlServerByteArrayTypeMapping _rowversion
5252
comparer: new ValueComparer<byte[]>(
5353
(v1, v2) => StructuralComparisons.StructuralEqualityComparer.Equals(v1, v2),
5454
v => StructuralComparisons.StructuralEqualityComparer.GetHashCode(v),
55-
v => v == null ? null : v.ToArray()),
55+
v => v.ToArray()),
5656
storeTypePostfix: StoreTypePostfix.None);
5757

5858
private readonly IntTypeMapping _int

src/EFCore/ChangeTracking/ArrayStructuralComparer.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ public ArrayStructuralComparer()
2424
: base(
2525
CreateDefaultEqualsExpression(),
2626
CreateDefaultHashCodeExpression(favorStructuralComparisons: true),
27-
v => v == null ? null : v.ToArray())
27+
v => v.ToArray())
2828
{
2929
}
3030
}

src/EFCore/ChangeTracking/GeometryValueComparer.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ public GeometryValueComparer()
5858
right);
5959
}
6060

61-
private static Expression<Func<TGeometry?, TGeometry?>> GetSnapshotExpression()
61+
private static Expression<Func<TGeometry, TGeometry>> GetSnapshotExpression()
6262
{
6363
var instance = Expression.Parameter(typeof(TGeometry), "instance");
6464

@@ -71,7 +71,7 @@ public GeometryValueComparer()
7171
body = Expression.Convert(body, typeof(TGeometry));
7272
}
7373

74-
return Expression.Lambda<Func<TGeometry?, TGeometry?>>(body, instance);
74+
return Expression.Lambda<Func<TGeometry, TGeometry>>(body, instance);
7575
}
7676
}
7777
}

src/EFCore/ChangeTracking/Internal/SimplePrincipalKeyValueFactory.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ public NoNullsCustomEqualityComparer(ValueComparer comparer)
148148
public bool Equals(TKey? x, TKey? y)
149149
=> _equals(x, y);
150150

151-
public int GetHashCode(TKey obj)
151+
public int GetHashCode([DisallowNull] TKey obj)
152152
=> _hashCode(obj);
153153
}
154154
}

src/EFCore/ChangeTracking/Internal/ValueComparerExtensions.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ public NonNullNullableValueComparer(
5858
: base(
5959
(Expression<Func<T?, T?, bool>>)equalsExpression,
6060
(Expression<Func<T, int>>)hashCodeExpression,
61-
(Expression<Func<T?, T?>>)snapshotExpression)
61+
(Expression<Func<T, T>>)snapshotExpression)
6262
{
6363
}
6464
}

src/EFCore/ChangeTracking/ValueComparer.cs

+14-12
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System;
55
using System.Collections;
66
using System.Collections.Generic;
7+
using System.Diagnostics.CodeAnalysis;
78
using System.Linq;
89
using System.Linq.Expressions;
910
using System.Reflection;
@@ -106,6 +107,7 @@ protected ValueComparer(
106107
/// </summary>
107108
/// <param name="instance"> The instance. </param>
108109
/// <returns> The snapshot. </returns>
110+
[return: NotNullIfNotNull("instance")]
109111
public abstract object? Snapshot(object? instance);
110112

111113
/// <summary>
@@ -196,31 +198,31 @@ public virtual Expression ExtractSnapshotBody(Expression expression)
196198
/// <returns> The <see cref="ValueComparer{T}" />. </returns>
197199
public static ValueComparer CreateDefault(Type type, bool favorStructuralComparisons)
198200
{
199-
var nonNullabletype = type.UnwrapNullableType();
201+
var nonNullableType = type.UnwrapNullableType();
200202

201203
// The equality operator returns false for NaNs, but the Equals methods returns true
202-
if (nonNullabletype == typeof(double))
204+
if (nonNullableType == typeof(double))
203205
{
204206
return new DefaultDoubleValueComparer(favorStructuralComparisons);
205207
}
206208

207-
if (nonNullabletype == typeof(float))
209+
if (nonNullableType == typeof(float))
208210
{
209211
return new DefaultFloatValueComparer(favorStructuralComparisons);
210212
}
211213

212-
if (nonNullabletype == typeof(DateTimeOffset))
214+
if (nonNullableType == typeof(DateTimeOffset))
213215
{
214216
return new DefaultDateTimeOffsetValueComparer(favorStructuralComparisons);
215217
}
216218

217-
var comparerType = nonNullabletype.IsInteger()
218-
|| nonNullabletype == typeof(decimal)
219-
|| nonNullabletype == typeof(bool)
220-
|| nonNullabletype == typeof(string)
221-
|| nonNullabletype == typeof(DateTime)
222-
|| nonNullabletype == typeof(Guid)
223-
|| nonNullabletype == typeof(TimeSpan)
219+
var comparerType = nonNullableType.IsInteger()
220+
|| nonNullableType == typeof(decimal)
221+
|| nonNullableType == typeof(bool)
222+
|| nonNullableType == typeof(string)
223+
|| nonNullableType == typeof(DateTime)
224+
|| nonNullableType == typeof(Guid)
225+
|| nonNullableType == typeof(TimeSpan)
224226
? typeof(DefaultValueComparer<>)
225227
: typeof(ValueComparer<>);
226228

@@ -253,7 +255,7 @@ public override Expression ExtractSnapshotBody(Expression expression)
253255
public override object? Snapshot(object? instance)
254256
=> instance;
255257

256-
public override T? Snapshot(T? instance)
258+
public override T Snapshot(T instance)
257259
=> instance;
258260
}
259261

src/EFCore/ChangeTracking/ValueComparer`.cs

+10-10
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public class ValueComparer<T> : ValueComparer, IEqualityComparer<T>
3333
{
3434
private Func<T?, T?, bool>? _equals;
3535
private Func<T, int>? _hashCode;
36-
private Func<T?, T?>? _snapshot;
36+
private Func<T, T>? _snapshot;
3737

3838
/// <summary>
3939
/// Creates a new <see cref="ValueComparer{T}" /> with a default comparison
@@ -82,7 +82,7 @@ public ValueComparer(
8282
public ValueComparer(
8383
Expression<Func<T?, T?, bool>> equalsExpression,
8484
Expression<Func<T, int>> hashCodeExpression,
85-
Expression<Func<T?, T?>> snapshotExpression)
85+
Expression<Func<T, T>> snapshotExpression)
8686
: base(equalsExpression, hashCodeExpression, snapshotExpression)
8787
{
8888
}
@@ -161,7 +161,7 @@ public ValueComparer(
161161
/// Creates an expression for creating a snapshot of a value.
162162
/// </summary>
163163
/// <returns> The snapshot expression. </returns>
164-
protected static Expression<Func<T?, T?>> CreateDefaultSnapshotExpression(bool favorStructuralComparisons)
164+
protected static Expression<Func<T, T>> CreateDefaultSnapshotExpression(bool favorStructuralComparisons)
165165
{
166166
if (!favorStructuralComparisons
167167
|| !typeof(T).IsArray)
@@ -178,7 +178,7 @@ public ValueComparer(
178178
// var destination = new T[length];
179179
// Array.Copy(source, destination, length);
180180
// return destination;
181-
return Expression.Lambda<Func<T?, T?>>(
181+
return Expression.Lambda<Func<T, T>>(
182182
Expression.Block(
183183
new[] { lengthVariable, destinationVariable },
184184
Expression.Assign(
@@ -257,8 +257,8 @@ public override bool Equals(object? left, object? right)
257257
/// </summary>
258258
/// <param name="instance"> The instance. </param>
259259
/// <returns> The hash code. </returns>
260-
public override int GetHashCode(object? instance)
261-
=> instance == null ? 0 : GetHashCode((T)instance);
260+
public override int GetHashCode(object instance)
261+
=> instance is null ? 0 : GetHashCode((T)instance);
262262

263263
/// <summary>
264264
/// Compares the two instances to determine if they are equal.
@@ -293,7 +293,7 @@ public virtual int GetHashCode(T instance)
293293
/// <param name="instance"> The instance. </param>
294294
/// <returns> The snapshot. </returns>
295295
public override object? Snapshot(object? instance)
296-
=> instance == null ? null : Snapshot((T?)instance);
296+
=> instance == null ? null : Snapshot((T)instance);
297297

298298
/// <summary>
299299
/// <para>
@@ -308,7 +308,7 @@ public virtual int GetHashCode(T instance)
308308
/// </summary>
309309
/// <param name="instance"> The instance. </param>
310310
/// <returns> The snapshot. </returns>
311-
public virtual T? Snapshot(T? instance)
311+
public virtual T Snapshot(T instance)
312312
=> NonCapturingLazyInitializer.EnsureInitialized(
313313
ref _snapshot, this, static c => c.SnapshotExpression.Compile())(instance);
314314

@@ -341,7 +341,7 @@ public override Type Type
341341
/// reference.
342342
/// </para>
343343
/// </summary>
344-
public new virtual Expression<Func<T?, T?>> SnapshotExpression
345-
=> (Expression<Func<T?, T?>>)base.SnapshotExpression;
344+
public new virtual Expression<Func<T, T>> SnapshotExpression
345+
=> (Expression<Func<T, T>>)base.SnapshotExpression;
346346
}
347347
}

test/EFCore.Cosmos.FunctionalTests/EndToEndCosmosTest.cs

+8-10
Original file line numberDiff line numberDiff line change
@@ -606,14 +606,13 @@ await Can_add_update_delete_with_collection(
606606
public async Task Can_add_update_delete_with_nested_collections()
607607
{
608608
await Can_add_update_delete_with_collection(
609-
new List<List<short>> { new List<short> { 1, 2 } },
609+
new List<List<short>> { new() { 1, 2 } },
610610
c =>
611611
{
612612
c.Collection.Clear();
613613
c.Collection.Add(new List<short> { 3 });
614614
},
615-
new List<List<short>> { new List<short> { 3 } });
616-
615+
new List<List<short>> { new() { 3 } });
617616
await Can_add_update_delete_with_collection<IList<byte?[]>>(
618617
new List<byte?[]>(),
619618
c =>
@@ -622,30 +621,29 @@ await Can_add_update_delete_with_collection(
622621
c.Collection.Add(null);
623622
},
624623
new List<byte?[]> { new byte?[] { 3, null }, null });
625-
626624
await Can_add_update_delete_with_collection<IReadOnlyList<Dictionary<string, string>>>(
627-
new Dictionary<string, string>[] { new Dictionary<string, string> { { "1", null } } },
625+
new Dictionary<string, string>[] { new() { { "1", null } } },
628626
c =>
629627
{
630628
var dictionary = c.Collection[0]["3"] = "2";
631629
},
632-
new List<Dictionary<string, string>> { new Dictionary<string, string> { { "1", null }, { "3", "2" } } });
630+
new List<Dictionary<string, string>> { new() { { "1", null }, { "3", "2" } } });
633631

634632
await Can_add_update_delete_with_collection(
635-
new List<float>[] { new List<float> { 1f }, new List<float> { 2 } },
633+
new List<float>[] { new() { 1f }, new() { 2 } },
636634
c =>
637635
{
638636
c.Collection[1][0] = 3f;
639637
},
640-
new List<float>[] { new List<float> { 1f }, new List<float> { 3f } });
638+
new List<float>[] { new() { 1f }, new() { 3f } });
641639

642640
await Can_add_update_delete_with_collection(
643-
new decimal?[][] { new decimal?[] { 1, null } },
641+
new[] { new decimal?[] { 1, null } },
644642
c =>
645643
{
646644
c.Collection[0][1] = 3;
647645
},
648-
new decimal?[][] { new decimal?[] { 1, 3 } });
646+
new[] { new decimal?[] { 1, 3 } });
649647

650648
await Can_add_update_delete_with_collection(
651649
new Dictionary<string, List<int>> { { "1", new List<int> { 1 } } },

test/EFCore.Tests/Storage/ValueComparerTest.cs

-2
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ private static ValueComparer CompareTest(Type type, object value1, object value2
8585
Assert.False(comparer.Equals(null, value2));
8686
Assert.True(comparer.Equals(null, null));
8787

88-
Assert.Equal(0, comparer.GetHashCode(null));
8988
Assert.Equal(hashCode ?? value1.GetHashCode(), comparer.GetHashCode(value1));
9089

9190
var keyComparer = (ValueComparer)Activator.CreateInstance(typeof(ValueComparer<>).MakeGenericType(type), new object[] { true });
@@ -102,7 +101,6 @@ private static ValueComparer CompareTest(Type type, object value1, object value2
102101
Assert.False(keyComparer.Equals(null, value2));
103102
Assert.True(keyComparer.Equals(null, null));
104103

105-
Assert.Equal(0, keyComparer.GetHashCode(null));
106104
Assert.Equal(hashCode ?? value1.GetHashCode(), keyComparer.GetHashCode(value1));
107105

108106
return comparer;

0 commit comments

Comments
 (0)