diff --git a/src/Storage.MongoDB/MongoNonceRepository.cs b/src/Storage.MongoDB/MongoNonceRepository.cs index 0a9cb46..0ced544 100644 --- a/src/Storage.MongoDB/MongoNonceRepository.cs +++ b/src/Storage.MongoDB/MongoNonceRepository.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Text.Json; using System.Threading; using System.Threading.Tasks; using MongoDB.Bson.Serialization; @@ -53,19 +54,16 @@ public MongoNonceRepository(IMongoDatabase database, MongoNonceOptions options) public async ValueTask InsertOneAsync( Token token, CancellationToken cancellationToken) { - await _collection.InsertOneAsync(token, cancellationToken: cancellationToken); + token.ExtraProperties = token.ExtraProperties ?? new Dictionary(); - if (token.ExtraProperties != null) - { - var indexOptions = new CreateIndexOptions { Background = true }; + IReadOnlySet propertyNamesWithPrimitiveValueType = + GetPropertyNamesWithPrimitiveValueType(token.ExtraProperties); - foreach (KeyValuePair extraProperty in token.ExtraProperties) - { - _collection.Indexes.CreateOne(new CreateIndexModel( - Builders.IndexKeys.Ascending( - $"{nameof(token.ExtraProperties)}.{extraProperty.Key}"), indexOptions)); - } - } + SetJsonStringValue(token.ExtraProperties, propertyNamesWithPrimitiveValueType); + + await _collection.InsertOneAsync(token, cancellationToken: cancellationToken); + + CreateIndexes(propertyNamesWithPrimitiveValueType); } public async ValueTask TakeOneAsync( @@ -105,5 +103,52 @@ public static void Initialize() { //ensure static constructor is called } + + private void CreateIndexes(IEnumerable extraPropertyNames) + { + var indexOptions = new CreateIndexOptions { Background = true }; + + foreach (string extraPropertyName in extraPropertyNames) + { + _collection.Indexes.CreateOne(new CreateIndexModel( + Builders.IndexKeys.Ascending( + $"{nameof(Token.ExtraProperties)}.{extraPropertyName}"), indexOptions)); + } + } + + private IReadOnlySet GetPropertyNamesWithPrimitiveValueType( + Dictionary extraProperties) + { + HashSet names = new HashSet(); + + foreach (KeyValuePair keyValue in extraProperties) + { + if (keyValue.Value == null) + { + continue; + } + + if (TypeChecker.IsPrimitiveType(keyValue.Value.GetType())) + { + names.Add(keyValue.Key); + } + } + + return names; + } + + private void SetJsonStringValue( + Dictionary extraProperties, IReadOnlySet namesToSkip) + { + foreach(string name in extraProperties.Keys) + { + if (namesToSkip.Contains(name)) + { + continue; + } + + extraProperties[name] = JsonSerializer.Serialize(extraProperties[name]); + } + } } } diff --git a/src/Storage.MongoDB/TypeChecker.cs b/src/Storage.MongoDB/TypeChecker.cs new file mode 100644 index 0000000..5131958 --- /dev/null +++ b/src/Storage.MongoDB/TypeChecker.cs @@ -0,0 +1,55 @@ +using System; +using System.Collections.Generic; +using MongoDB.Bson; + +#nullable enable + +namespace Bewit.Storage.MongoDB +{ + public static class TypeChecker + { + private static readonly HashSet PrimitiveTypes = new HashSet + { + typeof(string), + typeof(bool), + typeof(byte), + typeof(sbyte), + typeof(char), + typeof(decimal), + typeof(double), + typeof(float), + typeof(int), + typeof(uint), + typeof(long), + typeof(ulong), + typeof(short), + typeof(ushort), + typeof(DateTime), + typeof(DateTimeOffset), + typeof(TimeSpan), + typeof(DateOnly), + typeof(TimeOnly), + typeof(Guid), + typeof(ObjectId), + typeof(BsonObjectId), + typeof(BsonDateTime) + }; + + public static bool IsPrimitiveType(object value) + { + if (value == null) + return false; + + Type type = value.GetType(); + return IsPrimitiveType(type); + } + + public static bool IsPrimitiveType(Type type) + { + Type underlyingType = Nullable.GetUnderlyingType(type) ?? type; + + return PrimitiveTypes.Contains(underlyingType) || + underlyingType.IsEnum; + } + } +} diff --git a/test/IntegrationTests/HotChocolateServer/HCServerHelper.cs b/test/IntegrationTests/HotChocolateServer/HCServerHelper.cs index d3c484b..169bf7a 100644 --- a/test/IntegrationTests/HotChocolateServer/HCServerHelper.cs +++ b/test/IntegrationTests/HotChocolateServer/HCServerHelper.cs @@ -70,7 +70,12 @@ internal static TestServer CreateHotChocolateServer( .Resolve(ctx => { ctx.AddBewitTokenExtraProperties( - new Dictionary { ["foo"] = "bar" }); + new Dictionary + { + ["foo"] = "bar", + ["customType"] = new{}, + ["nullValue"] = null + }); return "http://foo.bar/api/dummy/WithBewitProtection?foo=bar&baz=qux"; })