Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 13 additions & 19 deletions src/MongoDB.Driver/FieldValueSerializerHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
using MongoDB.Bson;
using MongoDB.Bson.Serialization;
using MongoDB.Bson.Serialization.Serializers;
using MongoDB.Driver.Support;
using MongoDB.Driver.Linq.Linq3Implementation.Misc;

namespace MongoDB.Driver
{
Expand Down Expand Up @@ -63,7 +63,7 @@ public static IBsonSerializer GetSerializerForValueType(IBsonSerializer fieldSer
var fieldSerializerInterfaceType = typeof(IBsonSerializer<>).MakeGenericType(fieldType);

// synthesize a NullableSerializer using the field serializer
if (valueType.IsNullable() && valueType.GetNullableUnderlyingType() == fieldType)
if (valueType.IsNullable(out var nonNullableValueType) && nonNullableValueType == fieldType)
{
var nullableSerializerType = typeof(NullableSerializer<>).MakeGenericType(fieldType);
var nullableSerializerConstructor = nullableSerializerType.GetTypeInfo().GetConstructor(new[] { fieldSerializerInterfaceType });
Expand All @@ -80,24 +80,21 @@ public static IBsonSerializer GetSerializerForValueType(IBsonSerializer fieldSer
return (IBsonSerializer)enumConvertingSerializerConstructor.Invoke(new object[] { fieldSerializer });
}

if (valueType.IsNullable() && valueType.GetNullableUnderlyingType().IsConvertibleToEnum())
if (valueType.IsNullable(out nonNullableValueType) && nonNullableValueType.IsConvertibleToEnum())
{
var underlyingValueType = valueType.GetNullableUnderlyingType();
var underlyingValueSerializerInterfaceType = typeof(IBsonSerializer<>).MakeGenericType(underlyingValueType);
var enumConvertingSerializerType = typeof(EnumConvertingSerializer<,>).MakeGenericType(underlyingValueType, fieldType);
var nonNullableValueSerializerInterfaceType = typeof(IBsonSerializer<>).MakeGenericType(nonNullableValueType);
var enumConvertingSerializerType = typeof(EnumConvertingSerializer<,>).MakeGenericType(nonNullableValueType, fieldType);
var enumConvertingSerializerConstructor = enumConvertingSerializerType.GetTypeInfo().GetConstructor(new[] { fieldSerializerInterfaceType });
var enumConvertingSerializer = enumConvertingSerializerConstructor.Invoke(new object[] { fieldSerializer });
var nullableSerializerType = typeof(NullableSerializer<>).MakeGenericType(underlyingValueType);
var nullableSerializerConstructor = nullableSerializerType.GetTypeInfo().GetConstructor(new[] { underlyingValueSerializerInterfaceType });
var nullableSerializerType = typeof(NullableSerializer<>).MakeGenericType(nonNullableValueType);
var nullableSerializerConstructor = nullableSerializerType.GetTypeInfo().GetConstructor(new[] { nonNullableValueSerializerInterfaceType });
return (IBsonSerializer)nullableSerializerConstructor.Invoke(new object[] { enumConvertingSerializer });
}
}

// synthesize a NullableEnumConvertingSerializer using the field serializer
if (fieldType.IsNullableEnum() && valueType.IsNullable())
if (fieldType.IsNullableEnum(out var nonNullableFieldType) && valueType.IsNullable(out nonNullableValueType))
{
var nonNullableFieldType = fieldType.GetNullableUnderlyingType();
var nonNullableValueType = valueType.GetNullableUnderlyingType();
var nonNullableFieldSerializer = ((IChildSerializerConfigurable)fieldSerializer).ChildSerializer;
var nonNullableFieldSerializerInterfaceType = typeof(IBsonSerializer<>).MakeGenericType(nonNullableFieldType);
var nullableEnumConvertingSerializerType = typeof(NullableEnumConvertingSerializer<,>).MakeGenericType(nonNullableValueType, nonNullableFieldType);
Expand All @@ -106,18 +103,15 @@ public static IBsonSerializer GetSerializerForValueType(IBsonSerializer fieldSer
}

// synthesize an IEnumerableSerializer serializer using the item serializer from the field serializer
Type fieldIEnumerableInterfaceType;
Type valueIEnumerableInterfaceType;
Type itemType;
if (
(fieldIEnumerableInterfaceType = fieldType.FindIEnumerable()) != null &&
(valueIEnumerableInterfaceType = valueType.FindIEnumerable()) != null &&
(itemType = fieldIEnumerableInterfaceType.GetSequenceElementType()) == valueIEnumerableInterfaceType.GetSequenceElementType() &&
fieldType.ImplementsIEnumerable(out var fieldItemType) &&
valueType.ImplementsIEnumerable(out var valueItemType) &&
fieldItemType == valueItemType &&
fieldSerializer is IChildSerializerConfigurable)
{
var itemSerializer = ((IChildSerializerConfigurable)fieldSerializer).ChildSerializer;
var itemSerializerInterfaceType = typeof(IBsonSerializer<>).MakeGenericType(itemType);
var ienumerableSerializerType = typeof(IEnumerableSerializer<>).MakeGenericType(itemType);
var itemSerializerInterfaceType = typeof(IBsonSerializer<>).MakeGenericType(fieldItemType);
var ienumerableSerializerType = typeof(IEnumerableSerializer<>).MakeGenericType(fieldItemType);
var ienumerableSerializerConstructor = ienumerableSerializerType.GetTypeInfo().GetConstructor(new[] { itemSerializerInterfaceType });
return (IBsonSerializer)ienumerableSerializerConstructor.Invoke(new object[] { itemSerializer });
}
Expand Down
175 changes: 110 additions & 65 deletions src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While we are here, could we also remove the unnecessary else blocks in some of the methods.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I noticed that the methods: TryGetIEnumerableGenericInterface; TryGetIListGenericInterface; TryGetIQueryableGenericInterface all have nearly identical implementations. Could they be consolidated to use the existing TryGetGenericInterface helper?

something like:

  public static bool TryGetIEnumerableGenericInterface(this Type type, out Type ienumerableGenericInterface)
  {
      return TryGetGenericInterface(type, new[] { typeof(IEnumerable<>) }, out ienumerableGenericInterface);
  }

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done except slightly differently. I added an overload of TryGetGenerricInterface that only takes a single generic interface definition to match against. Should be slightly more efficient than the array version when there is only one definition involved.

Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using MongoDB.Bson;

namespace MongoDB.Driver.Linq.Linq3Implementation.Misc
{
internal static class TypeExtensions
{
private static readonly Type[] __dictionaryInterfaces =
private static readonly Type[] __dictionaryInterfaceDefinitions =
{
typeof(IDictionary<,>),
typeof(IReadOnlyDictionary<,>)
Expand Down Expand Up @@ -52,6 +54,14 @@ internal static class TypeExtensions
typeof(ValueTuple<,,,,,,,>)
};

public static object GetDefaultValue(this Type type)
{
var genericMethod = typeof(TypeExtensions)
.GetMethod(nameof(GetDefaultValueGeneric), BindingFlags.NonPublic | BindingFlags.Static)
.MakeGenericMethod(type);
return genericMethod.Invoke(null, null);
}

public static Type GetIEnumerableGenericInterface(this Type enumerableType)
{
if (enumerableType.TryGetIEnumerableGenericInterface(out var ienumerableGenericInterface))
Expand Down Expand Up @@ -92,7 +102,7 @@ public static bool Implements(this Type type, Type @interface)

public static bool ImplementsDictionaryInterface(this Type type, out Type keyType, out Type valueType)
{
if (TryGetGenericInterface(type, __dictionaryInterfaces, out var dictionaryInterface))
if (TryGetGenericInterface(type, __dictionaryInterfaceDefinitions, out var dictionaryInterface))
{
var genericArguments = dictionaryInterface.GetGenericArguments();
keyType = genericArguments[0];
Expand Down Expand Up @@ -136,6 +146,18 @@ public static bool ImplementsIList(this Type type, out Type itemType)
return false;
}

public static bool ImplementsIQueryable(this Type type, out Type itemType)
{
if (TryGetIQueryableGenericInterface(type, out var iqueryableType))
{
itemType = iqueryableType.GetGenericArguments()[0];
return true;
}

itemType = null;
return false;
}

public static bool Is(this Type type, Type comparand)
{
if (type == comparand)
Expand Down Expand Up @@ -175,41 +197,49 @@ public static bool IsArray(this Type type, out Type itemType)
return false;
}

public static bool IsBooleanOrNullableBoolean(this Type type)
{
return
type == typeof(bool) ||
type.IsNullable(out var valueType) && valueType == typeof(bool);
}

public static bool IsConvertibleToEnum(this Type type)
{
return
type == typeof(sbyte) ||
type == typeof(short) ||
type == typeof(int) ||
type == typeof(long) ||
type == typeof(byte) ||
type == typeof(ushort) ||
type == typeof(uint) ||
type == typeof(ulong) ||
type == typeof(Enum) ||
type == typeof(string);
}

public static bool IsEnum(this Type type, out Type underlyingType)
{
if (type.IsEnum)
{
underlyingType = Enum.GetUnderlyingType(type);
return true;
}
else
{
underlyingType = null;
return false;
}

underlyingType = null;
return false;
}

public static bool IsEnum(this Type type, out Type enumType, out Type underlyingType)
public static bool IsEnumOrNullableEnum(this Type type, out Type enumType, out Type underlyingType)
{
if (type.IsEnum)
if (type.IsEnum(out underlyingType))
{
enumType = type;
underlyingType = Enum.GetUnderlyingType(type);
return true;
}
else
{
enumType = null;
underlyingType = null;
return false;
}
}

public static bool IsEnumOrNullableEnum(this Type type, out Type enumType, out Type underlyingType)
{
return
type.IsEnum(out enumType, out underlyingType) ||
type.IsNullableEnum(out enumType, out underlyingType);
return IsNullableEnum(type, out enumType, out underlyingType);
}

public static bool IsNullable(this Type type)
Expand All @@ -224,23 +254,39 @@ public static bool IsNullable(this Type type, out Type valueType)
valueType = type.GetGenericArguments()[0];
return true;
}
else
{
valueType = null;
return false;
}

valueType = null;
return false;
}

public static bool IsNullableEnum(this Type type)
{
return type.IsNullable(out var valueType) && valueType.IsEnum;
}

public static bool IsNullableEnum(this Type type, out Type enumType)
{
if (type.IsNullable(out var valueType) && valueType.IsEnum)
{
enumType = valueType;
return true;
}

enumType = null;
return false;
}

public static bool IsNullableEnum(this Type type, out Type enumType, out Type underlyingType)
{
if (type.IsNullable(out var valueType) && valueType.IsEnum(out underlyingType))
{
enumType = valueType;
return true;
}

enumType = null;
underlyingType = null;
return type.IsNullable(out var valueType) && valueType.IsEnum(out enumType, out underlyingType);
return false;
}

public static bool IsNullableOf(this Type type, Type valueType)
Expand All @@ -256,6 +302,24 @@ public static bool IsReadOnlySpanOf(this Type type, Type itemType)
type.GetGenericArguments()[0] == itemType;
}

public static bool IsNumeric(this Type type)
{
return
type == typeof(int) ||
type == typeof(long) ||
type == typeof(double) ||
type == typeof(float) ||
type == typeof(decimal) ||
type == typeof(Decimal128);
}

public static bool IsNumericOrNullableNumeric(this Type type)
{
return
type.IsNumeric() ||
type.IsNullable(out var valueType) && valueType.IsNumeric();
}

public static bool IsSameAsOrNullableOf(this Type type, Type valueType)
{
return type == valueType || type.IsNullableOf(valueType);
Expand Down Expand Up @@ -298,55 +362,36 @@ public static bool IsValueTuple(this Type type)
__valueTupleTypeDefinitions.Contains(typeDefinition);
}

public static bool TryGetGenericInterface(this Type type, Type[] interfaceDefinitions, out Type genericInterface)
public static bool TryGetGenericInterface(this Type type, Type genericInterfaceDefintion, out Type genericInterface)
{
genericInterface =
type.IsConstructedGenericType && interfaceDefinitions.Contains(type.GetGenericTypeDefinition()) ?
type.IsConstructedGenericType && type.GetGenericTypeDefinition() == genericInterfaceDefintion ?
type :
type.GetInterfaces().FirstOrDefault(i => i.IsConstructedGenericType && interfaceDefinitions.Contains(i.GetGenericTypeDefinition()));
type.GetInterfaces().FirstOrDefault(i => i.IsConstructedGenericType && i.GetGenericTypeDefinition() == genericInterfaceDefintion);
return genericInterface != null;
}

public static bool TryGetIEnumerableGenericInterface(this Type type, out Type ienumerableGenericInterface)
public static bool TryGetGenericInterface(this Type type, Type[] genericInterfaceDefinitions, out Type genericInterface)
{
if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IEnumerable<>))
{
ienumerableGenericInterface = type;
return true;
}

foreach (var interfaceType in type.GetInterfaces())
{
if (interfaceType.IsGenericType && interfaceType.GetGenericTypeDefinition() == typeof(IEnumerable<>))
{
ienumerableGenericInterface = interfaceType;
return true;
}
}

ienumerableGenericInterface = null;
return false;
genericInterface =
type.IsConstructedGenericType && genericInterfaceDefinitions.Contains(type.GetGenericTypeDefinition()) ?
type :
type.GetInterfaces().FirstOrDefault(i => i.IsConstructedGenericType && genericInterfaceDefinitions.Contains(i.GetGenericTypeDefinition()));
return genericInterface != null;
}

public static bool TryGetIEnumerableGenericInterface(this Type type, out Type ienumerableGenericInterface)
=> TryGetGenericInterface(type, typeof(IEnumerable<>), out ienumerableGenericInterface);

public static bool TryGetIListGenericInterface(this Type type, out Type ilistGenericInterface)
{
if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IList<>))
{
ilistGenericInterface = type;
return true;
}
=> TryGetGenericInterface(type, typeof(IList<>), out ilistGenericInterface);

foreach (var interfaceType in type.GetInterfaces())
{
if (interfaceType.IsGenericType && interfaceType.GetGenericTypeDefinition() == typeof(IList<>))
{
ilistGenericInterface = interfaceType;
return true;
}
}
public static bool TryGetIQueryableGenericInterface(this Type type, out Type iqueryableGenericInterface)
=> TryGetGenericInterface(type, typeof(IQueryable<>), out iqueryableGenericInterface);

ilistGenericInterface = null;
return false;
private static TValue GetDefaultValueGeneric<TValue>()
{
return default(TValue);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
using MongoDB.Bson;
using MongoDB.Bson.Serialization;
using MongoDB.Driver.Core.Misc;
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToExecutableQueryTranslators;
using MongoDB.Driver.Support;

namespace MongoDB.Driver.Linq.Linq3Implementation
{
Expand Down Expand Up @@ -105,7 +105,11 @@ internal MongoQueryProvider(
// public methods
public override IQueryable CreateQuery(Expression expression)
{
var outputType = expression.Type.GetSequenceElementType();
if (!expression.Type.ImplementsIQueryable(out var outputType))
{
throw new ExpressionNotSupportedException(expression, because: "expression type does not implement IQueryable");
}

var queryType = typeof(MongoQuery<,>).MakeGenericType(typeof(TDocument), outputType);
return (IQueryable)Activator.CreateInstance(queryType, new object[] { this, expression });
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,9 @@
using MongoDB.Bson.Serialization;
using MongoDB.Bson.Serialization.Serializers;
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
using MongoDB.Driver.Linq.Linq3Implementation.ExtensionMethods;
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
using MongoDB.Driver.Linq.Linq3Implementation.Serializers;
using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators;
using MongoDB.Driver.Support;

namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
using MongoDB.Driver.Linq.Linq3Implementation.Reflection;
using MongoDB.Driver.Linq.Linq3Implementation.Serializers;
using MongoDB.Driver.Support;

namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
using MongoDB.Driver.Linq.Linq3Implementation.Reflection;
using MongoDB.Driver.Support;

namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators
{
Expand Down
Loading