diff --git a/.gitignore b/.gitignore index 073fbd7..fa8b7e0 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,4 @@ _ReSharper*/ Sources/packages *.nupkg /Sources/.vs/Equ/v15/Server/sqlite3/* +**/.vs \ No newline at end of file diff --git a/.vs/Equ/v16/TestStore/0/000.testlog b/.vs/Equ/v16/TestStore/0/000.testlog new file mode 100644 index 0000000..0fdfab8 Binary files /dev/null and b/.vs/Equ/v16/TestStore/0/000.testlog differ diff --git a/.vs/Equ/v16/TestStore/0/testlog.manifest b/.vs/Equ/v16/TestStore/0/testlog.manifest new file mode 100644 index 0000000..e92ede2 Binary files /dev/null and b/.vs/Equ/v16/TestStore/0/testlog.manifest differ diff --git a/.vs/ProjectSettings.json b/.vs/ProjectSettings.json new file mode 100644 index 0000000..f8b4888 --- /dev/null +++ b/.vs/ProjectSettings.json @@ -0,0 +1,3 @@ +{ + "CurrentProjectSetting": null +} \ No newline at end of file diff --git a/.vs/slnx.sqlite b/.vs/slnx.sqlite new file mode 100644 index 0000000..2dfc3af Binary files /dev/null and b/.vs/slnx.sqlite differ diff --git a/Sources/Equ.Test/NestedClassTests.cs b/Sources/Equ.Test/NestedClassTests.cs new file mode 100644 index 0000000..c814139 --- /dev/null +++ b/Sources/Equ.Test/NestedClassTests.cs @@ -0,0 +1,61 @@ +using System.Collections.Generic; +using Xunit; + +namespace Equ.Test +{ + public class NestedClassTest + { + public class Contained + { + public string BasicProperty { get; set; } + public Contained Nested { get; set; } + } + + public class Container + { + public Contained Nested { get; set; } + } + + public static IEnumerable ShouldDetermineCorrectEqualityTests => new List + { + new object[] { new Container(), new Container(), MemberwiseEqualityComparer.ByProperties, true }, + new object[] { + new Container { Nested = new Contained() }, + new Container { }, + MemberwiseEqualityComparer.ByProperties, + false + }, + new object[] { + new Container { Nested = new Contained() }, + new Container { Nested = new Contained() }, + MemberwiseEqualityComparer.ByProperties, + false + }, + new object[] { + new Container { Nested = new Contained() }, + new Container { Nested = new Contained() }, + MemberwiseEqualityComparer.ByPropertiesRecursive, + true + }, + new object[] { + new Container { Nested = new Contained { Nested = new Contained() } }, + new Container { Nested = new Contained() }, + MemberwiseEqualityComparer.ByPropertiesRecursive, + false + }, + new object[] { + new Container { Nested = new Contained { Nested = new Contained() } }, + new Container { Nested = new Contained { Nested = new Contained() } }, + MemberwiseEqualityComparer.ByPropertiesRecursive, + true + } + }; + + [Theory] + [MemberData(nameof(ShouldDetermineCorrectEqualityTests))] + public void ShouldDetermineCorrectEquality(Container x, Container y, MemberwiseEqualityComparer equalityComparer, bool expected) + { + Assert.Equal(expected, equalityComparer.Equals(x, y)); + } + } +} diff --git a/Sources/Equ.Test/NestedCollectionsTest.cs b/Sources/Equ.Test/NestedCollectionsTest.cs new file mode 100644 index 0000000..21a3109 --- /dev/null +++ b/Sources/Equ.Test/NestedCollectionsTest.cs @@ -0,0 +1,87 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Xunit; + +namespace Equ.Test +{ + public class NestedCollectionsTest + { + public class Level2 + { + public string BasicProperty { get; set; } + } + + public class Level1 + { + public ICollection Items { get; set; } + } + + public class Container + { + public ICollection Items { get; set; } + } + + public static IEnumerable ShouldDetermineCorrectEqualityTests => new List + { + new object[] { new Container(), new Container(), MemberwiseEqualityComparer.ByProperties, true }, + new object[] { + new Container { Items = new List() }, + new Container { }, + MemberwiseEqualityComparer.ByProperties, + false + }, + new object[] { + new Container { Items = new List() }, + new Container { Items = new List() }, + MemberwiseEqualityComparer.ByProperties, + true + }, + new object[] { + new Container { Items = new List() }, + new Container { Items = Array.Empty() }, + MemberwiseEqualityComparer.ByProperties, + true + }, + new object[] { + new Container { Items = new [] { new Level1() } }, + new Container { Items = Array.Empty() }, + MemberwiseEqualityComparer.ByProperties, + false + }, + new object[] { + new Container { Items = new [] { new Level1() } }, + new Container { Items = new [] { new Level1() } }, + MemberwiseEqualityComparer.ByProperties, + false + }, + new object[] { + new Container { Items = new [] { new Level1() } }, + new Container { Items = new [] { new Level1() } }, + MemberwiseEqualityComparer.ByPropertiesRecursive, + true + }, + new object[] { + new Container { Items = new [] { new Level1 { Items = new[] { new Level2() } } } }, + new Container { Items = new [] { new Level1 { Items = new[] { new Level2() } } } }, + MemberwiseEqualityComparer.ByProperties, + false + }, + new object[] { + new Container { Items = new [] { new Level1 { Items = new[] { new Level2() } } } }, + new Container { Items = new [] { new Level1 { Items = new[] { new Level2() } } } }, + MemberwiseEqualityComparer.ByPropertiesRecursive, + true + }, + }; + + [Theory] + [MemberData(nameof(ShouldDetermineCorrectEqualityTests))] + public void ShouldDetermineCorrectEquality(Container x, Container y, MemberwiseEqualityComparer equalityComparer, bool expected) + { + Assert.Equal(expected, equalityComparer.Equals(x, y)); + } + } +} diff --git a/Sources/Equ/ElementwiseSequenceEqualityComparer.cs b/Sources/Equ/ElementwiseSequenceEqualityComparer.cs index 08b2a2a..50b40b9 100644 --- a/Sources/Equ/ElementwiseSequenceEqualityComparer.cs +++ b/Sources/Equ/ElementwiseSequenceEqualityComparer.cs @@ -1,5 +1,6 @@ namespace Equ { + using System; using System.Collections; using System.Collections.Generic; using System.Linq; @@ -13,12 +14,63 @@ /// The type of the enumerable, i.e. a type implementing public class ElementwiseSequenceEqualityComparer : EqualityComparer where T : IEnumerable { + private static readonly Type EnumerableType = typeof(T) + .GetTypeInfo() + .GetInterfaces() + .Where(type => type.GetTypeInfo().IsGenericType && typeof(IEnumerable<>).GetTypeInfo().IsAssignableFrom(type.GetGenericTypeDefinition())) + .SelectMany(type => type.GetTypeInfo().GetGenericArguments()) + .SingleOrDefault(); + + private static readonly MethodInfo SequenceEqualsMethodInfo = (EnumerableType == null) ? null : typeof(Enumerable) + .GetTypeInfo() + .GetMethods(BindingFlags.Static | BindingFlags.Public) + .Where(methodInfo => methodInfo.Name.Equals(nameof(Enumerable.SequenceEqual)) && methodInfo.GetParameters().Length == 3) + .Single() + .MakeGenericMethod(EnumerableType); + + private static readonly MethodInfo ScrambledEqualsMethodInfo = (EnumerableType == null) ? null : typeof(ElementwiseSequenceEqualityComparer) + .GetTypeInfo() + .GetMethods(BindingFlags.Static | BindingFlags.NonPublic) + .Where(methodInfo => methodInfo.Name.Equals(nameof(ElementwiseSequenceEqualityComparer.ScrambledEquals)) && methodInfo.IsGenericMethod) + .Single() + .MakeGenericMethod(EnumerableType); + // ReSharper disable once UnusedMember.Global - public new static ElementwiseSequenceEqualityComparer Default => new ElementwiseSequenceEqualityComparer(); + public new static ElementwiseSequenceEqualityComparer Default => new ElementwiseSequenceEqualityComparer(MemberwiseEqualityMode.None); + public static ElementwiseSequenceEqualityComparer ByFieldsRecursive => new ElementwiseSequenceEqualityComparer(MemberwiseEqualityMode.ByFieldsRecursive); + public static ElementwiseSequenceEqualityComparer ByPropertiesRecursive => new ElementwiseSequenceEqualityComparer(MemberwiseEqualityMode.ByPropertiesRecursive); // ReSharper disable once StaticMemberInGenericType private static readonly bool _typeHasDefinedOrder = !IsDictionaryType() && !IsSetType(); - + + private bool _recursive; + private Lazy _memberwiseEqualityComparer; + + public ElementwiseSequenceEqualityComparer() : this(MemberwiseEqualityMode.None) { } + + internal ElementwiseSequenceEqualityComparer(MemberwiseEqualityMode mode) + { + _recursive = (mode == MemberwiseEqualityMode.ByFieldsRecursive || mode == MemberwiseEqualityMode.ByPropertiesRecursive); + _memberwiseEqualityComparer = new Lazy(() => CreateMemberwiseEqualityComparer(mode)); + } + + private object CreateMemberwiseEqualityComparer(MemberwiseEqualityMode mode) + { + var propertyName = mode switch + { + MemberwiseEqualityMode.ByFieldsRecursive => nameof(MemberwiseEqualityComparer.ByFieldsRecursive), + MemberwiseEqualityMode.ByPropertiesRecursive => nameof(MemberwiseEqualityComparer.ByPropertiesRecursive), + _ => nameof(MemberwiseEqualityComparer.ByProperties) + }; + + return typeof(MemberwiseEqualityComparer<>) + .GetTypeInfo() + .MakeGenericType(EnumerableType) + .GetTypeInfo() + .GetProperty(propertyName, BindingFlags.Static | BindingFlags.Public) + .GetValue(null); + } + public override bool Equals(T left, T right) { if (ReferenceEquals(left, right)) @@ -34,10 +86,21 @@ public override bool Equals(T left, T right) return false; } - var leftEnumerable = left.Cast(); - var rightEnumerable = right.Cast(); - - return _typeHasDefinedOrder ? leftEnumerable.SequenceEqual(rightEnumerable) : ScrambledEquals(leftEnumerable, rightEnumerable); + return _typeHasDefinedOrder ? SequenceEqual(left, right) : ScrambledEquals(left, right); + } + + private bool SequenceEqual(IEnumerable left, IEnumerable right) + { + if (_recursive && EnumerableType != null && !EnumerableType.GetTypeInfo().IsPrimitive) + { + return (bool)SequenceEqualsMethodInfo.Invoke(null, new [] { left, right, _memberwiseEqualityComparer.Value }); + } + else + { + var leftEnumerable = left.Cast(); + var rightEnumerable = right.Cast(); + return leftEnumerable.SequenceEqual(rightEnumerable); + } } public override int GetHashCode(T obj) @@ -98,9 +161,24 @@ private static bool IsSetType() return type.IsGenericType && typeof(ISet<>).GetTypeInfo().IsAssignableFrom(type.GetGenericTypeDefinition()); } - private static bool ScrambledEquals(IEnumerable list1, IEnumerable list2) + private bool ScrambledEquals(IEnumerable left, IEnumerable right) + { + if (_recursive && EnumerableType != null && !EnumerableType.GetTypeInfo().IsPrimitive) + { + return (bool)ScrambledEqualsMethodInfo.Invoke(null, new [] { left, right, _memberwiseEqualityComparer.Value }); + } + else + { + var leftEnumerable = left.Cast(); + var rightEnumerable = right.Cast(); + + return ScrambledEquals(leftEnumerable, rightEnumerable, EqualityComparer.Default); + } + } + + private static bool ScrambledEquals(IEnumerable list1, IEnumerable list2, IEqualityComparer equalityComparer) { - var counters = new Dictionary(); + var counters = new Dictionary(equalityComparer); foreach (var element in list1) { if (counters.ContainsKey(element)) diff --git a/Sources/Equ/Equ.csproj b/Sources/Equ/Equ.csproj index fd28f2d..8e4da0e 100644 --- a/Sources/Equ/Equ.csproj +++ b/Sources/Equ/Equ.csproj @@ -16,6 +16,7 @@ See https://github.com/thedmi/Equ/blob/master/ReleaseNotes.md true ..\key.snk + 8.0 \ No newline at end of file diff --git a/Sources/Equ/EqualityFunctionContext.cs b/Sources/Equ/EqualityFunctionContext.cs new file mode 100644 index 0000000..b7b2fe7 --- /dev/null +++ b/Sources/Equ/EqualityFunctionContext.cs @@ -0,0 +1,48 @@ +using System; +using System.Collections.Generic; + +namespace Equ +{ + public class EqualityFunctionContext + { + private readonly Dictionary _equalityComparers; + + public EqualityFunctionContext(MemberwiseEqualityMode mode) + { + _equalityComparers = new Dictionary(); + + Mode = mode; + + MemberwiseEqualityComparerProperty = mode switch + { + MemberwiseEqualityMode.ByFields => nameof(MemberwiseEqualityComparer.ByFields), + MemberwiseEqualityMode.ByFieldsRecursive => nameof(MemberwiseEqualityComparer.ByFieldsRecursive), + MemberwiseEqualityMode.ByProperties => nameof(MemberwiseEqualityComparer.ByProperties), + MemberwiseEqualityMode.ByPropertiesRecursive => nameof(MemberwiseEqualityComparer.ByPropertiesRecursive), + _ => string.Empty + }; + + ElementwiseSequenceEqualityComparerProperty = mode switch + { + MemberwiseEqualityMode.ByFieldsRecursive => nameof(ElementwiseSequenceEqualityComparer>.ByFieldsRecursive), + MemberwiseEqualityMode.ByPropertiesRecursive => nameof(ElementwiseSequenceEqualityComparer>.ByPropertiesRecursive), + _ => nameof(ElementwiseSequenceEqualityComparer>.Default) + }; + } + + public bool TryGetEqualityComparer(Type type, out MemberwiseEqualityComparer equalityComparer) + { + return _equalityComparers.TryGetValue(type, out equalityComparer); + } + + public void Add(MemberwiseEqualityComparer equalityComparer) + { + _equalityComparers.Add(typeof(T), equalityComparer); + } + + public MemberwiseEqualityMode Mode { get; } + public bool IsRecursive => Mode == MemberwiseEqualityMode.ByFieldsRecursive || Mode == MemberwiseEqualityMode.ByPropertiesRecursive; + public string MemberwiseEqualityComparerProperty { get; } + public string ElementwiseSequenceEqualityComparerProperty { get; } + } +} diff --git a/Sources/Equ/EqualityFunctionGenerator.cs b/Sources/Equ/EqualityFunctionGenerator.cs index 26e39bb..8e019ae 100644 --- a/Sources/Equ/EqualityFunctionGenerator.cs +++ b/Sources/Equ/EqualityFunctionGenerator.cs @@ -16,8 +16,8 @@ public class EqualityFunctionGenerator private readonly Type _type; private readonly Func> _fieldSelector; - private readonly Func> _propertySelector; + private readonly EqualityFunctionContext _context; /// /// Creates a generator for type . The generated functions will consider all fields @@ -27,11 +27,12 @@ public class EqualityFunctionGenerator /// /// /// - public EqualityFunctionGenerator(Type type, Func> fieldSelector, Func> propertySelector) + public EqualityFunctionGenerator(Type type, Func> fieldSelector, Func> propertySelector, EqualityFunctionContext context) { _type = type; _fieldSelector = fieldSelector; _propertySelector = propertySelector; + _context = context; } /// @@ -47,7 +48,7 @@ public Func MakeGetHashCodeMethod() var objParam = Expression.Convert(objRaw, _type); // compound XOR expression - var getHashCodeExprs = GetIncludedMembers(_type).Select(mi => MakeGetHashCodeExpression(mi, objParam)); + var getHashCodeExprs = GetIncludedMembers(_type).Select(mi => MakeGetHashCodeExpression(mi, objParam, _context)); var xorChainExpr = getHashCodeExprs.Aggregate((Expression)Expression.Constant(29), LinkHashCodeExpression); return Expression.Lambda>(xorChainExpr, objRaw).Compile(); @@ -68,7 +69,7 @@ public Func MakeEqualsMethod() var rightParam = Expression.Convert(rightRaw, _type); // AND expression using short-circuit evaluation - var equalsExprs = GetIncludedMembers(_type).Select(mi => MakeEqualsExpression(mi, leftParam, rightParam)); + var equalsExprs = GetIncludedMembers(_type).Select(mi => MakeEqualsExpression(mi, leftParam, rightParam, _context)); var andChainExpr = equalsExprs.Aggregate((Expression)Expression.Constant(true), Expression.AndAlso); // call Object.Equals if second parameter doesn't match type @@ -81,6 +82,13 @@ public Func MakeEqualsMethod() return Expression.Lambda>(useTypedEqualsExpression, leftRaw, rightRaw).Compile(); } + public Func MakeEqualsMethod(MemberwiseEqualityComparer equalityComparer) + { + _context.Add(equalityComparer); + + return MakeEqualsMethod(); + } + private IEnumerable GetIncludedMembers(Type type) { return _fieldSelector(type).Cast().Concat(_propertySelector(type)); @@ -92,7 +100,7 @@ private static Expression LinkHashCodeExpression(Expression left, Expression rig return Expression.ExclusiveOr(leftMultiplied, right); } - private static Expression MakeEqualsExpression(MemberInfo member, Expression left, Expression right) + private static Expression MakeEqualsExpression(MemberInfo member, Expression left, Expression right, EqualityFunctionContext context) { var leftMemberExpr = Expression.MakeMemberAccess(left, member); var rightMemberExpr = Expression.MakeMemberAccess(right, member); @@ -103,25 +111,32 @@ private static Expression MakeEqualsExpression(MemberInfo member, Expression lef { var boxedLeftMemberExpr = Expression.Convert(leftMemberExpr, typeof(object)); var boxedRightMemberExpr = Expression.Convert(rightMemberExpr, typeof(object)); - return MakeReferenceTypeEqualExpression(boxedLeftMemberExpr, boxedRightMemberExpr); + return MakeObjectEqualsExpression(boxedLeftMemberExpr, boxedRightMemberExpr); } return ReflectionUtils.IsSequenceType(memberType) - ? MakeSequenceTypeEqualExpression(leftMemberExpr, rightMemberExpr, memberType) - : MakeReferenceTypeEqualExpression(leftMemberExpr, rightMemberExpr); + ? MakeSequenceTypeEqualExpression(leftMemberExpr, rightMemberExpr, memberType, context) + : MakeReferenceTypeEqualExpression(leftMemberExpr, rightMemberExpr, memberType, context); } - private static Expression MakeSequenceTypeEqualExpression(Expression left, Expression right, Type enumerableType) + private static Expression MakeSequenceTypeEqualExpression(Expression left, Expression right, Type enumerableType, EqualityFunctionContext context) { - return MakeCallOnSequenceEqualityComparerExpression("Equals", enumerableType, left, right); + return MakeCallOnSequenceEqualityComparerExpression("Equals", enumerableType, context, left, right); } - private static Expression MakeReferenceTypeEqualExpression(Expression left, Expression right) + private static Expression MakeReferenceTypeEqualExpression(Expression left, Expression right, Type memberType, EqualityFunctionContext context) + { + return context.IsRecursive + ? MakeCallOnMemberwiseEqualityComparerExpression("Equals", memberType, context, left, right) + : MakeObjectEqualsExpression(left, right); + } + + private static Expression MakeObjectEqualsExpression(Expression left, Expression right) { return Expression.Call(_objectEqualsMethod, left, right); } - private static Expression MakeGetHashCodeExpression(MemberInfo member, Expression obj) + private static Expression MakeGetHashCodeExpression(MemberInfo member, Expression obj, EqualityFunctionContext context) { var memberAccessExpr = Expression.MakeMemberAccess(obj, member); var memberAccessAsObjExpr = Expression.Convert(memberAccessExpr, typeof(object)); @@ -129,7 +144,7 @@ private static Expression MakeGetHashCodeExpression(MemberInfo member, Expressio var memberType = memberAccessExpr.Type; var getHashCodeExpr = ReflectionUtils.IsSequenceType(memberType) - ? MakeCallOnSequenceEqualityComparerExpression("GetHashCode", memberType, memberAccessExpr) + ? MakeCallOnSequenceEqualityComparerExpression("GetHashCode", memberType, context, memberAccessExpr) : Expression.Call(memberAccessAsObjExpr, "GetHashCode", Type.EmptyTypes); return Expression.Condition( @@ -138,13 +153,27 @@ private static Expression MakeGetHashCodeExpression(MemberInfo member, Expressio getHashCodeExpr); // Return the actual getHashCode call } - private static Expression MakeCallOnSequenceEqualityComparerExpression(string methodName, Type enumerableType, params Expression[] parameterExpressions) + private static Expression MakeCallOnSequenceEqualityComparerExpression(string methodName, Type enumerableType, EqualityFunctionContext context, params Expression[] parameterExpressions) { var comparerType = typeof(ElementwiseSequenceEqualityComparer<>).MakeGenericType(enumerableType); - var comparerInstance = comparerType.GetTypeInfo().GetProperty("Default", BindingFlags.Static | BindingFlags.Public).GetValue(null); + var comparerInstance = comparerType.GetTypeInfo().GetProperty(context.ElementwiseSequenceEqualityComparerProperty, BindingFlags.Static | BindingFlags.Public).GetValue(null); var comparerExpr = Expression.Constant(comparerInstance); return Expression.Call(comparerExpr, methodName, Type.EmptyTypes, parameterExpressions); } + + private static Expression MakeCallOnMemberwiseEqualityComparerExpression(string methodName, Type memberType, EqualityFunctionContext context, params Expression[] parameterExpressions) + { + if (!context.TryGetEqualityComparer(memberType, out MemberwiseEqualityComparer comparerInstance)) + { + var equalityGenerator = new EqualityFunctionGenerator(memberType, t => new List(), MemberwiseEqualityComparer.AllPropertiesExceptIgnored, context); + var comparerType = typeof(MemberwiseEqualityComparer<>).MakeGenericType(memberType); + + comparerInstance = (MemberwiseEqualityComparer) comparerType.GetTypeInfo().GetMethod(nameof(MemberwiseEqualityComparer.Custom), BindingFlags.Static | BindingFlags.Public).Invoke(null, new[] { equalityGenerator }); + } + + var comparerExpr = Expression.Constant(comparerInstance); + return Expression.Call(comparerExpr, methodName, Type.EmptyTypes, parameterExpressions); + } } } \ No newline at end of file diff --git a/Sources/Equ/MemberwiseEqualityComparer.cs b/Sources/Equ/MemberwiseEqualityComparer.cs index 06a22f2..203642f 100644 --- a/Sources/Equ/MemberwiseEqualityComparer.cs +++ b/Sources/Equ/MemberwiseEqualityComparer.cs @@ -5,6 +5,39 @@ using System.Linq; using System.Reflection; + public class MemberwiseEqualityComparer + { + private static BindingFlags AllInstanceMembers => BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic; + + private static bool IsNotMarkedAsIgnore(MemberInfo memberInfo) + { + var isSelfIgnored = memberInfo.GetCustomAttributes(typeof(MemberwiseEqualityIgnoreAttribute), true).Any(); + + var propertyInfo = ReflectionUtils.GetPropertyForBackingField(memberInfo); + var isPropertyIgnored = propertyInfo != null + && propertyInfo.GetCustomAttributes(typeof(MemberwiseEqualityIgnoreAttribute), true).Any(); + + return !isSelfIgnored && !isPropertyIgnored; + } + + private static bool IsNotIndexed(PropertyInfo propertyInfo) + { + var indexParamaters = propertyInfo.GetIndexParameters(); + + return !indexParamaters.Any(); + } + + public static IEnumerable AllFieldsExceptIgnored(Type t) + { + return t.GetTypeInfo().GetFields(AllInstanceMembers).Where(IsNotMarkedAsIgnore); + } + + public static IEnumerable AllPropertiesExceptIgnored(Type t) + { + return t.GetTypeInfo().GetProperties(AllInstanceMembers).Where(info => IsNotMarkedAsIgnore(info) && IsNotIndexed(info)); + } + } + /// /// Provides an implementation of that performs memberwise /// equality comparison of objects of type T. Use the or @@ -14,33 +47,59 @@ /// For more advanced scenarios, use the creator and pass a /// that matches your requirements. /// - public class MemberwiseEqualityComparer : IEqualityComparer + public class MemberwiseEqualityComparer : MemberwiseEqualityComparer, IEqualityComparer { private readonly Func _equalsFunc; private readonly Func _getHashCodeFunc; - private static readonly Lazy> _fieldsComparer = + private static readonly Lazy> _defaultFieldsComparer = + new Lazy>( + () => + new MemberwiseEqualityComparer( + new EqualityFunctionGenerator( + typeof(T), + AllFieldsExceptIgnored, + t => new List(), + new EqualityFunctionContext(MemberwiseEqualityMode.ByFields)))); + + private static readonly Lazy> _defaultPropertiesComparer = + new Lazy>( + () => + new MemberwiseEqualityComparer( + new EqualityFunctionGenerator( + typeof(T), + t => new List(), + AllPropertiesExceptIgnored, + new EqualityFunctionContext(MemberwiseEqualityMode.ByProperties)))); + + private static readonly Lazy> _recursiveFieldsComparer = new Lazy>( () => new MemberwiseEqualityComparer( new EqualityFunctionGenerator( typeof(T), AllFieldsExceptIgnored, - t => new List()))); + t => new List(), + new EqualityFunctionContext(MemberwiseEqualityMode.ByFieldsRecursive)))); - private static readonly Lazy> _propertiesComparer = + private static readonly Lazy> _recursivePropertiesComparer = new Lazy>( () => new MemberwiseEqualityComparer( new EqualityFunctionGenerator( typeof(T), t => new List(), - AllPropertiesExceptIgnored))); + AllPropertiesExceptIgnored, + new EqualityFunctionContext(MemberwiseEqualityMode.ByPropertiesRecursive)))); + + public static MemberwiseEqualityComparer ByFields => _defaultFieldsComparer.Value; - public static MemberwiseEqualityComparer ByFields => _fieldsComparer.Value; + public static MemberwiseEqualityComparer ByProperties => _defaultPropertiesComparer.Value; - public static MemberwiseEqualityComparer ByProperties => _propertiesComparer.Value; + public static MemberwiseEqualityComparer ByFieldsRecursive => _recursiveFieldsComparer.Value; + + public static MemberwiseEqualityComparer ByPropertiesRecursive => _recursivePropertiesComparer.Value; public static MemberwiseEqualityComparer Custom(EqualityFunctionGenerator equalityFunctionGenerator) { @@ -49,40 +108,10 @@ public static MemberwiseEqualityComparer Custom(EqualityFunctionGenerator equ private MemberwiseEqualityComparer(EqualityFunctionGenerator equalityFunctionGenerator) { - _equalsFunc = equalityFunctionGenerator.MakeEqualsMethod(); + _equalsFunc = equalityFunctionGenerator.MakeEqualsMethod(this); _getHashCodeFunc = equalityFunctionGenerator.MakeGetHashCodeMethod(); } - private static IEnumerable AllFieldsExceptIgnored(Type t) - { - return t.GetTypeInfo().GetFields(AllInstanceMembers).Where(IsNotMarkedAsIgnore); - } - - private static IEnumerable AllPropertiesExceptIgnored(Type t) - { - return t.GetTypeInfo().GetProperties(AllInstanceMembers).Where(info => IsNotMarkedAsIgnore(info) && IsNotIndexed(info)); - } - - private static BindingFlags AllInstanceMembers => BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic; - - private static bool IsNotMarkedAsIgnore(MemberInfo memberInfo) - { - var isSelfIgnored = memberInfo.GetCustomAttributes(typeof(MemberwiseEqualityIgnoreAttribute), true).Any(); - - var propertyInfo = ReflectionUtils.GetPropertyForBackingField(memberInfo); - var isPropertyIgnored = propertyInfo != null - && propertyInfo.GetCustomAttributes(typeof(MemberwiseEqualityIgnoreAttribute), true).Any(); - - return !isSelfIgnored && !isPropertyIgnored; - } - - private static bool IsNotIndexed(PropertyInfo propertyInfo) - { - var indexParamaters = propertyInfo.GetIndexParameters(); - - return !indexParamaters.Any(); - } - /// /// This method delegates to the generated equality function. Note that first a reference check /// on and (reference equality and null-check) is performed. diff --git a/Sources/Equ/MemberwiseEqualityMode.cs b/Sources/Equ/MemberwiseEqualityMode.cs new file mode 100644 index 0000000..35833e5 --- /dev/null +++ b/Sources/Equ/MemberwiseEqualityMode.cs @@ -0,0 +1,11 @@ +namespace Equ +{ + public enum MemberwiseEqualityMode + { + None, + ByFields, + ByFieldsRecursive, + ByProperties, + ByPropertiesRecursive + } +}