From 38f560a5988ecb40c532d37fb07a1c1b75039b18 Mon Sep 17 00:00:00 2001 From: adelinowona Date: Tue, 29 Jul 2025 06:41:23 -0400 Subject: [PATCH 01/15] Add enumerable extensions for median and percentile --- src/MongoDB.Driver/Linq/MongoEnumerable.cs | 462 ++++++++++++++++++++- 1 file changed, 460 insertions(+), 2 deletions(-) diff --git a/src/MongoDB.Driver/Linq/MongoEnumerable.cs b/src/MongoDB.Driver/Linq/MongoEnumerable.cs index 787da860a4e..f3ba9c7d4c5 100644 --- a/src/MongoDB.Driver/Linq/MongoEnumerable.cs +++ b/src/MongoDB.Driver/Linq/MongoEnumerable.cs @@ -16,8 +16,6 @@ using System; using System.Collections.Generic; -using System.Linq; -using System.Runtime.CompilerServices; using MongoDB.Driver.Core.Misc; using MongoDB.Driver.Linq.Linq3Implementation.Misc; @@ -238,6 +236,226 @@ public static IEnumerable MaxN( throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } + /// + /// Computes the median of a sequence of values. + /// + /// The type of the elements in the source sequence. + /// A sequence of values to calculate the median of. + /// A transform function to apply to each element. + /// The median value. + public static double Median(this IEnumerable source, Func selector) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes the median of a sequence of values. + /// + /// The type of the elements in the source sequence. + /// A sequence of values to calculate the median of. + /// A transform function to apply to each element. + /// The median value. + public static double? Median(this IEnumerable source, Func selector) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes the median of a sequence of values. + /// + /// The type of the elements in the source sequence. + /// A sequence of values to calculate the median of. + /// A transform function to apply to each element. + /// The median value. + public static double Median(this IEnumerable source, Func selector) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes the median of a sequence of values. + /// + /// The type of the elements in the source sequence. + /// A sequence of values to calculate the median of. + /// A transform function to apply to each element. + /// The median value. + public static double? Median(this IEnumerable source, Func selector) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes the median of a sequence of values. + /// + /// The type of the elements in the source sequence. + /// A sequence of values to calculate the median of. + /// A transform function to apply to each element. + /// The median value. + public static double Median(this IEnumerable source, Func selector) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes the median of a sequence of values. + /// + /// The type of the elements in the source sequence. + /// A sequence of values to calculate the median of. + /// A transform function to apply to each element. + /// The median value. + public static double? Median(this IEnumerable source, Func selector) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes the median of a sequence of values. + /// + /// The type of the elements in the source sequence. + /// A sequence of values to calculate the median of. + /// A transform function to apply to each element. + /// The median value. + public static double Median(this IEnumerable source, Func selector) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes the median of a sequence of values. + /// + /// The type of the elements in the source sequence. + /// A sequence of values to calculate the median of. + /// A transform function to apply to each element. + /// The median value. + public static double? Median(this IEnumerable source, Func selector) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes the median of a sequence of values. + /// + /// The type of the elements in the source sequence. + /// A sequence of values to calculate the median of. + /// A transform function to apply to each element. + /// The median value. + public static double Median(this IEnumerable source, Func selector) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes the median of a sequence of values. + /// + /// The type of the elements in the source sequence. + /// A sequence of values to calculate the median of. + /// A transform function to apply to each element. + /// The median value. + public static double? Median(this IEnumerable source, Func selector) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes the median of a sequence of values. + /// + /// The sequence of values. + /// The median value. + public static double Median(this IEnumerable source) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes the median of a sequence of values. + /// + /// The sequence of values. + /// The median value. + public static double? Median(this IEnumerable source) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes the median of a sequence of values. + /// + /// The sequence of values. + /// The median value. + public static double Median(this IEnumerable source) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes the median of a sequence of values. + /// + /// The sequence of values. + /// The median value. + public static double? Median(this IEnumerable source) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes the median of a sequence of values. + /// + /// The sequence of values. + /// The median value. + public static double Median(this IEnumerable source) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes the median of a sequence of values. + /// + /// The sequence of values. + /// The median value. + public static double? Median(this IEnumerable source) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes the median of a sequence of values. + /// + /// The sequence of values. + /// The median value. + public static double Median(this IEnumerable source) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes the median of a sequence of values. + /// + /// The sequence of values. + /// The median value. + public static double? Median(this IEnumerable source) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes the median of a sequence of values. + /// + /// The sequence of values. + /// The median value. + public static double Median(this IEnumerable source) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes the median of a sequence of values. + /// + /// The sequence of values. + /// The median value. + public static double? Median(this IEnumerable source) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + /// /// Returns the min n results. /// @@ -275,6 +493,246 @@ public static IEnumerable MinN( throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } + /// + /// Computes multiple percentiles of a sequence of values. + /// + /// The type of the elements of . + /// A sequence of values to calculate the percentiles of. + /// A transform function to apply to each element. + /// The percentiles to compute (each between 0.0 and 1.0). + /// The percentiles of the sequence of values. + public static double[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes multiple percentiles of a sequence of values. + /// + /// The type of the elements of . + /// A sequence of values to calculate the percentiles of. + /// A transform function to apply to each element. + /// The percentiles to compute (each between 0.0 and 1.0). + /// The percentiles of the sequence of values. + public static double?[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes multiple percentiles of a sequence of values. + /// + /// The type of the elements of . + /// A sequence of values to calculate the percentiles of. + /// A transform function to apply to each element. + /// The percentiles to compute (each between 0.0 and 1.0). + /// The percentiles of the sequence of values. + public static double[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes multiple percentiles of a sequence of values. + /// + /// The type of the elements of . + /// A sequence of values to calculate the percentiles of. + /// A transform function to apply to each element. + /// The percentiles to compute (each between 0.0 and 1.0). + /// The percentiles of the sequence of values. + public static double?[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes multiple percentiles of a sequence of values. + /// + /// The type of the elements of . + /// A sequence of values to calculate the percentiles of. + /// A transform function to apply to each element. + /// The percentiles to compute (each between 0.0 and 1.0). + /// The percentiles of the sequence of values. + public static double[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes multiple percentiles of a sequence of values. + /// + /// The type of the elements of . + /// A sequence of values to calculate the percentiles of. + /// A transform function to apply to each element. + /// The percentiles to compute (each between 0.0 and 1.0). + /// The percentiles of the sequence of values. + public static double?[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes multiple percentiles of a sequence of values. + /// + /// The type of the elements of . + /// A sequence of values to calculate the percentiles of. + /// A transform function to apply to each element. + /// The percentiles to compute (each between 0.0 and 1.0). + /// The percentiles of the sequence of values. + public static double[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes multiple percentiles of a sequence of values. + /// + /// The type of the elements of . + /// A sequence of values to calculate the percentiles of. + /// A transform function to apply to each element. + /// The percentiles to compute (each between 0.0 and 1.0). + /// The percentiles of the sequence of values. + public static double?[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes multiple percentiles of a sequence of values. + /// + /// The type of the elements of . + /// A sequence of values to calculate the percentiles of. + /// A transform function to apply to each element. + /// The percentiles to compute (each between 0.0 and 1.0). + /// The percentiles of the sequence of values. + public static double[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes multiple percentiles of a sequence of values. + /// + /// The type of the elements of . + /// A sequence of values to calculate the percentiles of. + /// A transform function to apply to each element. + /// The percentiles to compute (each between 0.0 and 1.0). + /// The percentiles of the sequence of values. + public static double?[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes multiple percentiles of a sequence of values. + /// + /// A sequence of values to calculate the percentiles of. + /// The percentiles to compute (each between 0.0 and 1.0). + /// The percentiles of the sequence of values. + public static double[] Percentile(this IEnumerable source, IEnumerable percentiles) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes multiple percentiles of a sequence of values. + /// + /// A sequence of values to calculate the percentiles of. + /// The percentiles to compute (each between 0.0 and 1.0). + /// The percentiles of the sequence of values. + public static double?[] Percentile(this IEnumerable source, IEnumerable percentiles) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes multiple percentiles of a sequence of values. + /// + /// A sequence of values to calculate the percentiles of. + /// The percentiles to compute (each between 0.0 and 1.0). + /// The percentiles of the sequence of values. + public static double[] Percentile(this IEnumerable source, IEnumerable percentiles) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes multiple percentiles of a sequence of values. + /// + /// A sequence of values to calculate the percentiles of. + /// The percentiles to compute (each between 0.0 and 1.0). + /// The percentiles of the sequence of values. + public static double?[] Percentile(this IEnumerable source, IEnumerable percentiles) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes multiple percentiles of a sequence of values. + /// + /// A sequence of values to calculate the percentiles of. + /// The percentiles to compute (each between 0.0 and 1.0). + /// The percentiles of the sequence of values. + public static double[] Percentile(this IEnumerable source, IEnumerable percentiles) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes multiple percentiles of a sequence of values. + /// + /// A sequence of values to calculate the percentiles of. + /// The percentiles to compute (each between 0.0 and 1.0). + /// The percentiles of the sequence of values. + public static double?[] Percentile(this IEnumerable source, IEnumerable percentiles) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes multiple percentiles of a sequence of values. + /// + /// A sequence of values to calculate the percentiles of. + /// The percentiles to compute (each between 0.0 and 1.0). + /// The percentiles of the sequence of values. + public static double[] Percentile(this IEnumerable source, IEnumerable percentiles) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes multiple percentiles of a sequence of values. + /// + /// A sequence of values to calculate the percentiles of. + /// The percentiles to compute (each between 0.0 and 1.0). + /// The percentiles of the sequence of values. + public static double?[] Percentile(this IEnumerable source, IEnumerable percentiles) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes multiple percentiles of a sequence of values. + /// + /// A sequence of values to calculate the percentiles of. + /// The percentiles to compute (each between 0.0 and 1.0). + /// The percentiles of the sequence of values. + public static double[] Percentile(this IEnumerable source, IEnumerable percentiles) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + + /// + /// Computes multiple percentiles of a sequence of values. + /// + /// A sequence of values to calculate the percentiles of. + /// The percentiles to compute (each between 0.0 and 1.0). + /// The percentiles of the sequence of values. + public static double?[] Percentile(this IEnumerable source, IEnumerable percentiles) + { + throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); + } + /// /// Computes the population standard deviation of a sequence of values. /// From fd1c493dd85ce745ef235441d699d654991390bf Mon Sep 17 00:00:00 2001 From: adelinowona Date: Tue, 29 Jul 2025 06:43:30 -0400 Subject: [PATCH 02/15] add setWindowFields extensions for median and percentile --- .../ISetWindowFieldsPartitionExtensions.cs | 271 +++++++++++++++++- 1 file changed, 270 insertions(+), 1 deletion(-) diff --git a/src/MongoDB.Driver/Linq/ISetWindowFieldsPartitionExtensions.cs b/src/MongoDB.Driver/Linq/ISetWindowFieldsPartitionExtensions.cs index ca8aade374b..efc42a0cd48 100644 --- a/src/MongoDB.Driver/Linq/ISetWindowFieldsPartitionExtensions.cs +++ b/src/MongoDB.Driver/Linq/ISetWindowFieldsPartitionExtensions.cs @@ -16,7 +16,6 @@ using System; using System.Collections.Generic; -using MongoDB.Bson; namespace MongoDB.Driver.Linq { @@ -879,6 +878,136 @@ public static TValue Max(this ISetWindowFieldsPartition throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); } + /// + /// Returns the median of the numeric values. Median ignores non-numeric values. + /// + /// The type of the input documents in the partition. + /// The partition. + /// The selector that selects a value from the input document. + /// The window boundaries. + /// The median of the selected values. + public static double Median(this ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window = null) + { + throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); + } + + /// + /// Returns the median of the numeric values. Median ignores non-numeric values. + /// + /// The type of the input documents in the partition. + /// The partition. + /// The selector that selects a value from the input document. + /// The window boundaries. + /// The median of the selected values. + public static double? Median(this ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window = null) + { + throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); + } + + /// + /// Returns the median of the numeric values. Median ignores non-numeric values. + /// + /// The type of the input documents in the partition. + /// The partition. + /// The selector that selects a value from the input document. + /// The window boundaries. + /// The median of the selected values. + public static double Median(this ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window = null) + { + throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); + } + + /// + /// Returns the median of the numeric values. Median ignores non-numeric values. + /// + /// The type of the input documents in the partition. + /// The partition. + /// The selector that selects a value from the input document. + /// The window boundaries. + /// The median of the selected values. + public static double? Median(this ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window = null) + { + throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); + } + + /// + /// Returns the median of the numeric values. Median ignores non-numeric values. + /// + /// The type of the input documents in the partition. + /// The partition. + /// The selector that selects a value from the input document. + /// The window boundaries. + /// The median of the selected values. + public static double Median(this ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window = null) + { + throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); + } + + /// + /// Returns the median of the numeric values. Median ignores non-numeric values. + /// + /// The type of the input documents in the partition. + /// The partition. + /// The selector that selects a value from the input document. + /// The window boundaries. + /// The median of the selected values. + public static double? Median(this ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window = null) + { + throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); + } + + /// + /// Returns the median of the numeric values. Median ignores non-numeric values. + /// + /// The type of the input documents in the partition. + /// The partition. + /// The selector that selects a value from the input document. + /// The window boundaries. + /// The median of the selected values. + public static double Median(this ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window = null) + { + throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); + } + + /// + /// Returns the median of the numeric values. Median ignores non-numeric values. + /// + /// The type of the input documents in the partition. + /// The partition. + /// The selector that selects a value from the input document. + /// The window boundaries. + /// The median of the selected values. + public static double? Median(this ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window = null) + { + throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); + } + + /// + /// Returns the median of the numeric values. Median ignores non-numeric values. + /// + /// The type of the input documents in the partition. + /// The partition. + /// The selector that selects a value from the input document. + /// The window boundaries. + /// The median of the selected values. + public static double Median(this ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window = null) + { + throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); + } + + /// + /// Returns the median of the numeric values. Median ignores non-numeric values. + /// + /// The type of the input documents in the partition. + /// The partition. + /// The selector that selects a value from the input document. + /// The window boundaries. + /// The median of the selected values. + public static double? Median(this ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window = null) + { + throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); + } + /// /// Returns the minimum value. /// @@ -893,6 +1022,146 @@ public static TValue Min(this ISetWindowFieldsPartition throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); } + /// + /// Returns the values at the given percentiles. Percentile ignores non-numeric values. + /// + /// The type of the input documents in the partition. + /// The partition. + /// The selector that selects a value from the input document. + /// The percentiles (between 0.0 and 1.0). + /// The window boundaries. + /// The values at the given percentiles. + public static double[] Percentile(this ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window = null) + { + throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); + } + + /// + /// Returns the values at the given percentiles. Percentile ignores non-numeric values. Percentile returns results in the same order as the given percentiles. + /// + /// The type of the input documents in the partition. + /// The partition. + /// The selector that selects a value from the input document. + /// The percentiles (between 0.0 and 1.0). + /// The window boundaries. + /// The values at the given percentiles. + public static double?[] Percentile(this ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window = null) + { + throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); + } + + /// + /// Returns the values at the given percentiles. Percentile ignores non-numeric values. Percentile returns results in the same order as the given percentiles. + /// + /// The type of the input documents in the partition. + /// The partition. + /// The selector that selects a value from the input document. + /// The percentiles (between 0.0 and 1.0). + /// The window boundaries. + /// The values at the given percentiles. + public static double[] Percentile(this ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window = null) + { + throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); + } + + /// + /// Returns the values at the given percentiles. Percentile ignores non-numeric values. Percentile returns results in the same order as the given percentiles. + /// + /// The type of the input documents in the partition. + /// The partition. + /// The selector that selects a value from the input document. + /// The percentiles (between 0.0 and 1.0). + /// The window boundaries. + /// The values at the given percentiles. + public static double?[] Percentile(this ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window = null) + { + throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); + } + + /// + /// Returns the values at the given percentiles. Percentile ignores non-numeric values. Percentile returns results in the same order as the given percentiles. + /// + /// The type of the input documents in the partition. + /// The partition. + /// The selector that selects a value from the input document. + /// The percentiles (between 0.0 and 1.0). + /// The window boundaries. + /// The values at the given percentiles. + public static double[] Percentile(this ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window = null) + { + throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); + } + + /// + /// Returns the values at the given percentiles. Percentile ignores non-numeric values. Percentile returns results in the same order as the given percentiles. + /// + /// The type of the input documents in the partition. + /// The partition. + /// The selector that selects a value from the input document. + /// The percentiles (between 0.0 and 1.0). + /// The window boundaries. + /// The values at the given percentiles. + public static double?[] Percentile(this ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window = null) + { + throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); + } + + /// + /// Returns the values at the given percentiles. Percentile ignores non-numeric values. Percentile returns results in the same order as the given percentiles. + /// + /// The type of the input documents in the partition. + /// The partition. + /// The selector that selects a value from the input document. + /// The percentiles (between 0.0 and 1.0). + /// The window boundaries. + /// The values at the given percentiles. + public static double[] Percentile(this ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window = null) + { + throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); + } + + /// + /// Returns the values at the given percentiles. Percentile ignores non-numeric values. Percentile returns results in the same order as the given percentiles. + /// + /// The type of the input documents in the partition. + /// The partition. + /// The selector that selects a value from the input document. + /// The percentiles (between 0.0 and 1.0). + /// The window boundaries. + /// The values at the given percentiles. + public static double?[] Percentile(this ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window = null) + { + throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); + } + + /// + /// Returns the values at the given percentiles. Percentile ignores non-numeric values. Percentile returns results in the same order as the given percentiles. + /// + /// The type of the input documents in the partition. + /// The partition. + /// The selector that selects a value from the input document. + /// The percentiles (between 0.0 and 1.0). + /// The window boundaries. + /// The values at the given percentiles. + public static double[] Percentile(this ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window = null) + { + throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); + } + + /// + /// Returns the values at the given percentiles. Percentile ignores non-numeric values. Percentile returns results in the same order as the given percentiles. + /// + /// The type of the input documents in the partition. + /// The partition. + /// The selector that selects a value from the input document. + /// The percentiles (between 0.0 and 1.0). + /// The window boundaries. + /// The values at the given percentiles. + public static double?[] Percentile(this ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window = null) + { + throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); + } + /// /// Returns a sequence of values. /// From 2ade9cdeda6903f1789863cc61b40816ce32f228 Mon Sep 17 00:00:00 2001 From: adelinowona Date: Tue, 29 Jul 2025 06:45:27 -0400 Subject: [PATCH 03/15] Add AstExpression support for median and percentile accumulators/window expressions --- .../Linq3Implementation/Ast/AstNodeType.cs | 5 ++ .../AstComplexAccumulatorExpression.cs | 66 +++++++++++++++++ .../AstComplexAccumulatorOperator.cs | 38 ++++++++++ .../Ast/Expressions/AstExpression.cs | 25 +++++++ .../Ast/Expressions/AstMedianExpression.cs | 64 ++++++++++++++++ .../Expressions/AstMedianWindowExpression.cs | 69 +++++++++++++++++ .../Expressions/AstPercentileExpression.cs | 68 +++++++++++++++++ .../AstPercentileWindowExpression.cs | 74 +++++++++++++++++++ .../Ast/Visitors/AstNodeVisitor.cs | 38 ++++++++++ 9 files changed, 447 insertions(+) create mode 100644 src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstComplexAccumulatorExpression.cs create mode 100644 src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstComplexAccumulatorOperator.cs create mode 100644 src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstMedianExpression.cs create mode 100644 src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstMedianWindowExpression.cs create mode 100644 src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstPercentileExpression.cs create mode 100644 src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstPercentileWindowExpression.cs diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/AstNodeType.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/AstNodeType.cs index 2b8ce448dd3..bc46e23f5a7 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/AstNodeType.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/AstNodeType.cs @@ -31,6 +31,7 @@ internal enum AstNodeType BucketStage, CollStatsStage, ComparisonFilterOperation, + ComplexAccumulatorExpression, ComputedArrayExpression, ComputedDocumentExpression, ComputedField, @@ -93,6 +94,8 @@ internal enum AstNodeType MatchesEverythingFilter, MatchesNothingFilter, MatchStage, + MedianExpression, + MedianWindowExpression, MergeStage, ModFilterOperation, NaryExpression, @@ -104,6 +107,8 @@ internal enum AstNodeType NullaryWindowExpression, OrFilter, OutStage, + PercentileExpression, + PercentileWindowExpression, PickAccumulatorExpression, PickExpression, Pipeline, diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstComplexAccumulatorExpression.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstComplexAccumulatorExpression.cs new file mode 100644 index 00000000000..d825dac5f1d --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstComplexAccumulatorExpression.cs @@ -0,0 +1,66 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Collections.Generic; +using MongoDB.Bson; +using MongoDB.Driver.Core.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Ast.Visitors; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions +{ + internal sealed class AstComplexAccumulatorExpression : AstAccumulatorExpression + { + private readonly AstComplexAccumulatorOperator _operator; + private readonly Dictionary _args; + + public AstComplexAccumulatorExpression(AstComplexAccumulatorOperator @operator, Dictionary args) + { + _operator = @operator; + _args = Ensure.IsNotNull(args, nameof(args)); + } + + public Dictionary Args => _args; + + public override AstNodeType NodeType => AstNodeType.ComplexAccumulatorExpression; + + public override AstNode Accept(AstNodeVisitor visitor) + { + return visitor.VisitComplexAccumulatorExpression(this); + } + + public override BsonValue Render() + { + var document = new BsonDocument(); + + // Add all accumulator parameters + foreach (var kvp in _args) + { + document[kvp.Key] = kvp.Value.Render(); + } + + return new BsonDocument(_operator.Render(), document); + } + + public AstComplexAccumulatorExpression Update(Dictionary args) + { + if (ReferenceEquals(args, _args)) + { + return this; + } + + return new AstComplexAccumulatorExpression(_operator, args); + } + } +} \ No newline at end of file diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstComplexAccumulatorOperator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstComplexAccumulatorOperator.cs new file mode 100644 index 00000000000..01f34827135 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstComplexAccumulatorOperator.cs @@ -0,0 +1,38 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions +{ + internal enum AstComplexAccumulatorOperator + { + Median, + Percentile + } + + internal static class AstComplexAccumulatorOperatorExtensions + { + public static string Render(this AstComplexAccumulatorOperator @operator) + { + return @operator switch + { + AstComplexAccumulatorOperator.Median => "$median", + AstComplexAccumulatorOperator.Percentile => "$percentile", + _ => throw new InvalidOperationException($"Unexpected complex accumulator operator: {@operator}.") + }; + } + } +} \ No newline at end of file diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpression.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpression.cs index 90512554fc6..b5f6381f9ff 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpression.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpression.cs @@ -209,6 +209,11 @@ public static AstExpression Comparison(AstBinaryOperator comparisonOperator, Ast }; } + public static AstComplexAccumulatorExpression ComplexAccumulator(AstComplexAccumulatorOperator @operator, Dictionary args) + { + return new AstComplexAccumulatorExpression(@operator, args); + } + public static AstExpression ComputedArray(IEnumerable items) { return new AstComputedArrayExpression(items); @@ -597,6 +602,16 @@ public static AstExpression Max(AstExpression arg1, AstExpression arg2) return new AstNaryExpression(AstNaryOperator.Max, [arg1, arg2]); } + public static AstExpression Median(AstExpression input) + { + return new AstMedianExpression(input); + } + + public static AstMedianWindowExpression MedianWindowExpression(AstExpression input, AstWindow window) + { + return new AstMedianWindowExpression(input, window); + } + public static AstExpression Min(AstExpression array) { return new AstUnaryExpression(AstUnaryOperator.Min, array); @@ -653,6 +668,16 @@ public static AstExpression Or(params AstExpression[] args) return new AstNaryExpression(AstNaryOperator.Or, flattenedArgs); } + public static AstPercentileExpression Percentile(AstExpression input, AstExpression percentiles) + { + return new AstPercentileExpression(input, percentiles); + } + + public static AstPercentileWindowExpression PercentileWindowExpression(AstExpression input, AstExpression percentiles, AstWindow window) + { + return new AstPercentileWindowExpression(input, percentiles, window); + } + public static AstExpression PickExpression(AstPickOperator @operator, AstExpression source, AstSortFields sortBy, AstVarExpression @as, AstExpression selector, AstExpression n) { return new AstPickExpression(@operator, source, sortBy, @as, selector, n); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstMedianExpression.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstMedianExpression.cs new file mode 100644 index 00000000000..2bbbf0112ec --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstMedianExpression.cs @@ -0,0 +1,64 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using MongoDB.Bson; +using MongoDB.Driver.Core.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Ast.Visitors; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions +{ + internal sealed class AstMedianExpression : AstExpression + { + private readonly AstExpression _input; + + public AstMedianExpression(AstExpression input) + { + _input = Ensure.IsNotNull(input, nameof(input)); + } + + public AstExpression Input => _input; + + public override AstNodeType NodeType => AstNodeType.MedianExpression; + + public override AstNode Accept(AstNodeVisitor visitor) + { + return visitor.VisitMedianExpression(this); + } + + public override BsonValue Render() + { + return new BsonDocument + { + { + "$median", new BsonDocument + { + { "input", _input.Render() }, + { "method", "approximate" } // server requires this parameter but currently only allows this value + } + } + }; + } + + public AstMedianExpression Update(AstExpression input) + { + if (input == _input) + { + return this; + } + return new AstMedianExpression(input); + } + + } +} \ No newline at end of file diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstMedianWindowExpression.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstMedianWindowExpression.cs new file mode 100644 index 00000000000..fbe71c3f278 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstMedianWindowExpression.cs @@ -0,0 +1,69 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using MongoDB.Bson; +using MongoDB.Driver.Core.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Ast.Visitors; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions +{ + internal sealed class AstMedianWindowExpression : AstWindowExpression + { + private readonly AstExpression _input; + private readonly AstWindow _window; + + public AstMedianWindowExpression(AstExpression input, AstWindow window) + { + _input = Ensure.IsNotNull(input, nameof(input)); + _window = window; + } + + public AstExpression Input => _input; + + public AstWindow Window => _window; + + public override AstNodeType NodeType => AstNodeType.MedianWindowExpression; + + public override AstNode Accept(AstNodeVisitor visitor) + { + return visitor.VisitMedianWindowExpression(this); + } + + public override BsonValue Render() + { + return new BsonDocument + { + { + "$median", new BsonDocument + { + { "input", _input.Render() }, + { "method", "approximate" } // server requires this parameter but currently only allows this value + } + }, + { "window", _window?.Render(), _window != null } + }; + } + + public AstMedianWindowExpression Update(AstExpression input, AstWindow window) + { + if (input == _input && window == _window) + { + return this; + } + + return new AstMedianWindowExpression(input, window); + } + } +} \ No newline at end of file diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstPercentileExpression.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstPercentileExpression.cs new file mode 100644 index 00000000000..aab4920143d --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstPercentileExpression.cs @@ -0,0 +1,68 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using MongoDB.Bson; +using MongoDB.Driver.Core.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Ast.Visitors; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions +{ + internal sealed class AstPercentileExpression : AstExpression + { + private readonly AstExpression _input; + private readonly AstExpression _percentiles; + + public AstPercentileExpression(AstExpression input, AstExpression percentiles) + { + _input = Ensure.IsNotNull(input, nameof(input)); + _percentiles = Ensure.IsNotNull(percentiles, nameof(percentiles)); + } + + public AstExpression Input => _input; + + public AstExpression Percentiles => _percentiles; + + public override AstNodeType NodeType => AstNodeType.PercentileExpression; + + public override AstNode Accept(AstNodeVisitor visitor) + { + return visitor.VisitPercentileExpression(this); + } + + public override BsonValue Render() + { + return new BsonDocument + { + { + "$percentile", new BsonDocument + { + { "input", _input.Render() }, + { "p", _percentiles.Render() }, + { "method", "approximate" } // server requires this parameter but currently only allows this value + } + } + }; + } + + public AstPercentileExpression Update(AstExpression input, AstExpression percentiles) + { + if (input == _input && percentiles == _percentiles) + { + return this; + } + return new AstPercentileExpression(input, percentiles); + } + } +} \ No newline at end of file diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstPercentileWindowExpression.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstPercentileWindowExpression.cs new file mode 100644 index 00000000000..055bcfb8ef3 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstPercentileWindowExpression.cs @@ -0,0 +1,74 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using MongoDB.Bson; +using MongoDB.Driver.Core.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Ast.Visitors; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions +{ + internal sealed class AstPercentileWindowExpression : AstWindowExpression + { + private readonly AstExpression _input; + private readonly AstExpression _percentiles; + private readonly AstWindow _window; + + public AstPercentileWindowExpression(AstExpression input, AstExpression percentiles, AstWindow window) + { + _input = Ensure.IsNotNull(input, nameof(input)); + _percentiles = Ensure.IsNotNull(percentiles, nameof(percentiles)); + _window = window; + } + + public AstExpression Input => _input; + + public AstExpression Percentiles => _percentiles; + + public AstWindow Window => _window; + + public override AstNodeType NodeType => AstNodeType.PercentileWindowExpression; + + public override AstNode Accept(AstNodeVisitor visitor) + { + return visitor.VisitPercentileWindowExpression(this); + } + + public override BsonValue Render() + { + return new BsonDocument + { + { + "$percentile", new BsonDocument + { + { "input", _input.Render() }, + { "p", _percentiles.Render() }, + { "method", "approximate" } // server requires this parameter but currently only allows this value + } + }, + { "window", _window?.Render(), _window != null } + }; + } + + public AstPercentileWindowExpression Update(AstExpression input, AstExpression percentiles, AstWindow window) + { + if (input == _input && percentiles == _percentiles && window == _window) + { + return this; + } + + return new AstPercentileWindowExpression(input, percentiles, window); + } + } +} \ No newline at end of file diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Visitors/AstNodeVisitor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Visitors/AstNodeVisitor.cs index 1222d0060e9..35fc5cd922f 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Visitors/AstNodeVisitor.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Visitors/AstNodeVisitor.cs @@ -199,6 +199,24 @@ public virtual AstNode VisitComparisonFilterOperation(AstComparisonFilterOperati return node; } + public virtual AstNode VisitComplexAccumulatorExpression(AstComplexAccumulatorExpression node) + { + Dictionary newArgs = null; + foreach (var kvp in node.Args) + { + var oldArg = kvp.Value; + var newArg = VisitAndConvert(oldArg); + + if (newArg != oldArg) + { + newArgs ??= new Dictionary(node.Args); + newArgs[kvp.Key] = newArg; + } + } + + return newArgs != null ? node.Update(newArgs) : node; + } + public virtual AstNode VisitComputedArrayExpression(AstComputedArrayExpression node) { return node.Update(VisitAndConvert(node.Items)); @@ -504,6 +522,16 @@ public virtual AstNode VisitMatchStage(AstMatchStage node) return node.Update(VisitAndConvert(node.Filter)); } + public virtual AstNode VisitMedianExpression(AstMedianExpression node) + { + return node.Update(VisitAndConvert(node.Input)); + } + + public virtual AstNode VisitMedianWindowExpression(AstMedianWindowExpression node) + { + return node.Update(VisitAndConvert(node.Input), node.Window); + } + public virtual AstNode VisitMergeStage(AstMergeStage node) { return node.Update(VisitAndConvert(node.Let)); @@ -559,6 +587,16 @@ public virtual AstNode VisitOutStage(AstOutStage node) return node; } + public virtual AstNode VisitPercentileExpression(AstPercentileExpression node) + { + return node.Update(VisitAndConvert(node.Input), VisitAndConvert(node.Percentiles)); + } + + public virtual AstNode VisitPercentileWindowExpression(AstPercentileWindowExpression node) + { + return node.Update(VisitAndConvert(node.Input), VisitAndConvert(node.Percentiles), node.Window); + } + public virtual AstNode VisitPickAccumulatorExpression(AstPickAccumulatorExpression node) { return node.Update(node.Operator, node.SortBy, VisitAndConvert(node.Selector), VisitAndConvert(node.N)); From 2e8f5908cff186d54898c1b9ebf1427148e87b29 Mon Sep 17 00:00:00 2001 From: adelinowona Date: Tue, 29 Jul 2025 06:46:17 -0400 Subject: [PATCH 04/15] Implement translators for median and percentile methods --- ...essionToAggregationExpressionTranslator.cs | 3 + ...MethodToAggregationExpressionTranslator.cs | 76 ++++++++++++++++++ ...MethodToAggregationExpressionTranslator.cs | 79 +++++++++++++++++++ ...MethodToAggregationExpressionTranslator.cs | 30 ++++++- 4 files changed, 186 insertions(+), 2 deletions(-) create mode 100644 src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslator.cs create mode 100644 src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslator.cs diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs index ae0c1f93193..e30deb4149e 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs @@ -64,8 +64,10 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC case "IsNullOrWhiteSpace": return IsNullOrWhiteSpaceMethodToAggregationExpressionTranslator.Translate(context, expression); case "IsSubsetOf": return IsSubsetOfMethodToAggregationExpressionTranslator.Translate(context, expression); case "Locf": return LocfMethodToAggregationExpressionTranslator.Translate(context, expression); + case "Median": return MedianMethodToAggregationExpressionTranslator.Translate(context, expression); case "OfType": return OfTypeMethodToAggregationExpressionTranslator.Translate(context, expression); case "Parse": return ParseMethodToAggregationExpressionTranslator.Translate(context, expression); + case "Percentile": return PercentileMethodToAggregationExpressionTranslator.Translate(context, expression); case "Pow": return PowMethodToAggregationExpressionTranslator.Translate(context, expression); case "Push": return PushMethodToAggregationExpressionTranslator.Translate(context, expression); case "Range": return RangeMethodToAggregationExpressionTranslator.Translate(context, expression); @@ -93,6 +95,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC case "Union": return UnionMethodToAggregationExpressionTranslator.Translate(context, expression); case "Zip": return ZipMethodToAggregationExpressionTranslator.Translate(context, expression); + case "Acos": case "Acosh": case "Asin": diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslator.cs new file mode 100644 index 00000000000..08ab0efd886 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslator.cs @@ -0,0 +1,76 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; +using System.Reflection; +using MongoDB.Bson.Serialization; +using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators +{ + internal class MedianMethodToAggregationExpressionTranslator + { + public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) + { + var method = expression.Method; + var arguments = expression.Arguments; + + if (IsMedianMethod(method)) + { + if (arguments.Count is 1 or 2) + { + var sourceExpression = arguments[0]; + var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); + NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation); + + var inputAst = sourceTranslation.Ast; + + // Median(source, selector) + if (arguments.Count == 2) + { + var selectorLambda = (LambdaExpression)arguments[1]; + var selectorParameter = selectorLambda.Parameters[0]; + var selectorParameterSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); + var selectorParameterSymbol = context.CreateSymbol(selectorParameter, selectorParameterSerializer); + var selectorContext = context.WithSymbol(selectorParameterSymbol); + var selectorTranslation = ExpressionToAggregationExpressionTranslator.Translate(selectorContext, selectorLambda.Body); + + inputAst = AstExpression.Map( + input: sourceTranslation.Ast, + @as: selectorParameterSymbol.Var, + @in: selectorTranslation.Ast); + } + + var ast = AstExpression.Median(inputAst); + var serializer = BsonSerializer.LookupSerializer(expression.Type); + return new TranslatedExpression(expression, ast, serializer); + } + } + + if (WindowMethodToAggregationExpressionTranslator.CanTranslate(expression)) + { + return WindowMethodToAggregationExpressionTranslator.Translate(context, expression); + } + + throw new ExpressionNotSupportedException(expression); + } + + private static bool IsMedianMethod(MethodInfo methodInfo) + { + return methodInfo.DeclaringType == typeof(MongoEnumerable) && methodInfo.Name == "Median"; + } + } +} \ No newline at end of file diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslator.cs new file mode 100644 index 00000000000..8e064834779 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslator.cs @@ -0,0 +1,79 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; +using System.Reflection; +using MongoDB.Bson.Serialization; +using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators +{ + internal class PercentileMethodToAggregationExpressionTranslator + { + public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) + { + var method = expression.Method; + var arguments = expression.Arguments; + + if (IsPercentileMethod(method)) + { + if (arguments.Count is 2 or 3) + { + var sourceExpression = arguments[0]; + var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); + NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation); + + AstExpression inputAst = sourceTranslation.Ast; + + // handle selector + if (arguments.Count == 3) + { + var selectorLambda = (LambdaExpression)arguments[1]; + var selectorParameter = selectorLambda.Parameters[0]; + var selectorParameterSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); + var selectorParameterSymbol = context.CreateSymbol(selectorParameter, selectorParameterSerializer); + var selectorContext = context.WithSymbol(selectorParameterSymbol); + var selectorTranslation = ExpressionToAggregationExpressionTranslator.Translate(selectorContext, selectorLambda.Body); + + inputAst = AstExpression.Map( + input: sourceTranslation.Ast, + @as: selectorParameterSymbol.Var, + @in: selectorTranslation.Ast); + } + + var percentilesExpression = arguments[arguments.Count - 1]; + var percentilesTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, percentilesExpression); + + var ast = AstExpression.Percentile(inputAst, percentilesTranslation.Ast); + var serializer = BsonSerializer.LookupSerializer(expression.Type); + return new TranslatedExpression(expression, ast, serializer); + } + } + + if (WindowMethodToAggregationExpressionTranslator.CanTranslate(expression)) + { + return WindowMethodToAggregationExpressionTranslator.Translate(context, expression); + } + + throw new ExpressionNotSupportedException(expression); + } + + private static bool IsPercentileMethod(MethodInfo methodInfo) + { + return methodInfo.DeclaringType == typeof(MongoEnumerable) && methodInfo.Name == "Percentile"; + } + } +} \ No newline at end of file diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslator.cs index 9b54a5fdd18..d274c3a7ae1 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslator.cs @@ -255,7 +255,7 @@ internal static class WindowMethodToAggregationExpressionTranslator public static bool CanTranslate(MethodCallExpression expression) { - return expression.Method.IsOneOf(__windowMethods); + return IsQuantileMethod(expression.Method) || expression.Method.IsOneOf(__windowMethods); } public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) @@ -264,7 +264,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var parameters = method.GetParameters(); var arguments = expression.Arguments.ToArray(); - if (method.IsOneOf(__windowMethods)) + if ( IsQuantileMethod(method) || method.IsOneOf(__windowMethods)) { var partitionExpression = arguments[0]; var partitionTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, partitionExpression); @@ -339,6 +339,27 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC return new TranslatedExpression(expression, ast, serializer); } + if (IsQuantileMethod(method)) + { + ThrowIfSelectorTranslationIsNull(selectorTranslation); + AstExpression ast; + + if (method.Name == "Percentile") + { + // Get the percentiles parameter + var percentilesExpression = GetArgument(parameters, "percentiles", arguments); + var percentilesTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, percentilesExpression); + ast = AstExpression.PercentileWindowExpression(selectorTranslation.Ast, percentilesTranslation.Ast, window); + } + else + { + ast = AstExpression.MedianWindowExpression(selectorTranslation.Ast, window); + } + + var serializer = BsonSerializer.LookupSerializer(method.ReturnType); + return new TranslatedExpression(expression, ast, serializer); + } + if (method.IsOneOf(__shiftMethods)) { ThrowIfSelectorTranslationIsNull(selectorTranslation); @@ -440,6 +461,11 @@ private static bool HasArgument(ParameterInfo[] parameters, string return false; } + private static bool IsQuantileMethod(MethodInfo method) + { + return method.DeclaringType == typeof(ISetWindowFieldsPartitionExtensions) && method.Name is "Median" or "Percentile"; + } + private static void ThrowIfSelectorTranslationIsNull(TranslatedExpression selectTranslation) { if (selectTranslation == null) From c0d00b305df3ef3b8549085fe7067ae7b0221fcf Mon Sep 17 00:00:00 2001 From: adelinowona Date: Tue, 29 Jul 2025 06:46:53 -0400 Subject: [PATCH 05/15] Update grouping pipeline optimizer --- .../AstGroupingPipelineOptimizer.cs | 177 +++++++++++------- 1 file changed, 109 insertions(+), 68 deletions(-) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupingPipelineOptimizer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupingPipelineOptimizer.cs index 5967215ee25..15f8048e403 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupingPipelineOptimizer.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupingPipelineOptimizer.cs @@ -404,32 +404,93 @@ unaryExpression.Arg is AstGetFieldExpression innerMostGetFieldExpression && public override AstNode VisitMapExpression(AstMapExpression node) { // { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } => { __agg0 : { $push : f(x => element) } } + "$__agg0" - if (node.Input is AstGetFieldExpression mapInputGetFieldExpression && - mapInputGetFieldExpression.FieldName.IsStringConstant("_elements") && - mapInputGetFieldExpression.Input.IsRootVar()) + if (IsElementsField(node.Input)) { var rewrittenArg = (AstExpression)AstNodeReplacer.Replace(node.In, (node.As, _element)); var accumulatorExpression = AstExpression.UnaryAccumulator(AstUnaryAccumulatorOperator.Push, rewrittenArg); - var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression); - return AstExpression.GetField(AstExpression.RootVar, accumulatorFieldName); + return CreateOptimizedExpression(accumulatorExpression); } return base.VisitMapExpression(node); } + public override AstNode VisitMedianExpression(AstMedianExpression node) + { + // { $median : { input: { $getField : { input : "$$ROOT", field : "_elements" } }, method: "approximate" } } => { __agg0 : { $median : { input: element, method: "approximate" } } } + "$__agg0" + if (IsElementsField(node.Input)) + { + var accumulator = AstExpression.ComplexAccumulator( + AstComplexAccumulatorOperator.Median, + new Dictionary + { + ["input"] = _element, + ["method"] = "approximate" + }); + return CreateOptimizedExpression(accumulator); + } + + // { $median : { input: { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } }, method: "approximate" } } + // => { __agg0 : { $median : { input: f(x => element), method: "approximate" } } } + "$__agg0" + if (IsMappedElementsField(node.Input, out var mapExpression, out var rewrittenArg)) + { + var accumulator = AstExpression.ComplexAccumulator( + AstComplexAccumulatorOperator.Median, + new Dictionary + { + ["input"] = rewrittenArg, + ["method"] = "approximate" + }); + return CreateOptimizedExpression(accumulator); + } + + return base.VisitMedianExpression(node); + } + + public override AstNode VisitPercentileExpression(AstPercentileExpression node) + { + // { $percentile : { input: { $getField : { input : "$$ROOT", field : "_elements" } }, p: [...], method: "approximate" } } + // => { __agg0 : { $percentile : { input: element, p: [...], method: "approximate" } } } + "$__agg0" + if (IsElementsField(node.Input)) + { + var accumulator = AstExpression.ComplexAccumulator( + AstComplexAccumulatorOperator.Percentile, + new Dictionary + { + ["input"] = _element, + ["p"] = node.Percentiles, + ["method"] = "approximate" + }); + return CreateOptimizedExpression(accumulator); + } + + // { $percentile : { input: { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } }, p: [...], method: "approximate" } } + // => { __agg0 : { $percentile : { input: f(x => element), p: [...], method: "approximate" } } } + "$__agg0" + if (IsMappedElementsField(node.Input, out var mapExpression, out var rewrittenArg)) + { + var accumulator = AstExpression.ComplexAccumulator( + AstComplexAccumulatorOperator.Percentile, + new Dictionary + { + ["input"] = rewrittenArg, + ["p"] = node.Percentiles, + ["method"] = "approximate" + }); + return CreateOptimizedExpression(accumulator); + } + + return base.VisitPercentileExpression(node); + } + public override AstNode VisitPickExpression(AstPickExpression node) { // { $pickOperator : { source : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", sortBy : s, selector : f(x) } } // => { __agg0 : { $pickAccumulatorOperator : { sortBy : s, selector : f(x => element) } } } + "$__agg0" - if (node.Source is AstGetFieldExpression getFieldExpression && - getFieldExpression.Input.IsRootVar() && - getFieldExpression.FieldName.IsStringConstant("_elements")) + if (IsElementsField(node.Source)) { var @operator = node.Operator.ToAccumulatorOperator(); var rewrittenSelector = (AstExpression)AstNodeReplacer.Replace(node.Selector, (node.As, _element)); var accumulatorExpression = new AstPickAccumulatorExpression(@operator, node.SortBy, rewrittenSelector, node.N); - var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression); - return AstExpression.GetField(AstExpression.RootVar, accumulatorFieldName); + return CreateOptimizedExpression(accumulatorExpression); } return base.VisitPickExpression(node); @@ -437,80 +498,60 @@ public override AstNode VisitPickExpression(AstPickExpression node) public override AstNode VisitUnaryExpression(AstUnaryExpression node) { - if (TryOptimizeSizeOfElements(out var optimizedExpression)) + // { $size : "$_elements" } => { __agg0 : { $sum : 1 } } + "$__agg0" + if (node.Operator == AstUnaryOperator.Size) { - return optimizedExpression; + if (node.Arg is AstGetFieldExpression argGetFieldExpression && + argGetFieldExpression.FieldName.IsStringConstant("_elements")) + { + var accumulatorExpression = AstExpression.UnaryAccumulator(AstUnaryAccumulatorOperator.Sum, 1); + return CreateOptimizedExpression(accumulatorExpression); + } } - if (TryOptimizeAccumulatorOfElements(out optimizedExpression)) + // { $accumulator : { $getField : { input : "$$ROOT", field : "_elements" } } } => { __agg0 : { $accumulator : element } } + "$__agg0" + if (node.Operator.IsAccumulator(out var accumulatorOperator) && IsElementsField(node.Arg)) { - return optimizedExpression; + var accumulatorExpression = AstExpression.UnaryAccumulator(accumulatorOperator, _element); + return CreateOptimizedExpression(accumulatorExpression); } - if (TryOptimizeAccumulatorOfMappedElements(out optimizedExpression)) + // { $accumulator : { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } } => { __agg0 : { $accumulator : f(x => element) } } + "$__agg0" + if (node.Operator.IsAccumulator(out accumulatorOperator) && + IsMappedElementsField(node.Arg, out var mapExpression, out var rewrittenArg)) { - return optimizedExpression; + var accumulatorExpression = AstExpression.UnaryAccumulator(accumulatorOperator, rewrittenArg); + return CreateOptimizedExpression(accumulatorExpression); } return base.VisitUnaryExpression(node); + } - bool TryOptimizeSizeOfElements(out AstExpression optimizedExpression) - { - // { $size : "$_elements" } => { __agg0 : { $sum : 1 } } + "$__agg0" - if (node.Operator == AstUnaryOperator.Size) - { - if (node.Arg is AstGetFieldExpression argGetFieldExpression && - argGetFieldExpression.FieldName.IsStringConstant("_elements")) - { - var accumulatorExpression = AstExpression.UnaryAccumulator(AstUnaryAccumulatorOperator.Sum, 1); - var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression); - optimizedExpression = AstExpression.GetField(AstExpression.RootVar, accumulatorFieldName); - return true; - } - } - - optimizedExpression = null; - return false; - } + private bool IsElementsField(AstExpression expression) + { + return expression is AstGetFieldExpression getFieldExpression && + getFieldExpression.FieldName.IsStringConstant("_elements") && + getFieldExpression.Input.IsRootVar(); + } - bool TryOptimizeAccumulatorOfElements(out AstExpression optimizedExpression) + private bool IsMappedElementsField(AstExpression expression, out AstMapExpression mapExpression, out AstExpression rewrittenArg) + { + if (expression is AstMapExpression map && IsElementsField(map.Input)) { - // { $accumulator : { $getField : { input : "$$ROOT", field : "_elements" } } } => { __agg0 : { $accumulator : element } } + "$__agg0" - if (node.Operator.IsAccumulator(out var accumulatorOperator) && - node.Arg is AstGetFieldExpression getFieldExpression && - getFieldExpression.FieldName.IsStringConstant("_elements") && - getFieldExpression.Input.IsRootVar()) - { - var accumulatorExpression = AstExpression.UnaryAccumulator(accumulatorOperator, _element); - var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression); - optimizedExpression = AstExpression.GetField(AstExpression.RootVar, accumulatorFieldName); - return true; - } - - optimizedExpression = null; - return false; - + mapExpression = map; + rewrittenArg = (AstExpression)AstNodeReplacer.Replace(map.In, (map.As, _element)); + return true; } - bool TryOptimizeAccumulatorOfMappedElements(out AstExpression optimizedExpression) - { - // { $accumulator : { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } } => { __agg0 : { $accumulator : f(x => element) } } + "$__agg0" - if (node.Operator.IsAccumulator(out var accumulatorOperator) && - node.Arg is AstMapExpression mapExpression && - mapExpression.Input is AstGetFieldExpression mapInputGetFieldExpression && - mapInputGetFieldExpression.FieldName.IsStringConstant("_elements") && - mapInputGetFieldExpression.Input.IsRootVar()) - { - var rewrittenArg = (AstExpression)AstNodeReplacer.Replace(mapExpression.In, (mapExpression.As, _element)); - var accumulatorExpression = AstExpression.UnaryAccumulator(accumulatorOperator, rewrittenArg); - var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression); - optimizedExpression = AstExpression.GetField(AstExpression.RootVar, accumulatorFieldName); - return true; - } + mapExpression = null; + rewrittenArg = null; + return false; + } - optimizedExpression = null; - return false; - } + private AstExpression CreateOptimizedExpression(AstAccumulatorExpression accumulator) + { + var fieldName = _accumulators.AddAccumulatorExpression(accumulator); + return AstExpression.GetField(AstExpression.RootVar, fieldName); } } From a0d526db7e7b06824132c537fffa4e828cc424c3 Mon Sep 17 00:00:00 2001 From: adelinowona Date: Tue, 29 Jul 2025 06:47:03 -0400 Subject: [PATCH 06/15] add tests --- src/MongoDB.Driver/Core/Misc/Feature.cs | 6 + ...essionToAggregationExpressionTranslator.cs | 1 - ...dToAggregationExpressionTranslatorTests.cs | 460 ++++++++++++++++++ ...dToAggregationExpressionTranslatorTests.cs | 377 ++++++++++++++ ...dToAggregationExpressionTranslatorTests.cs | 400 +++++++++++++++ .../AggregateGroupTranslatorTests.cs | 73 ++- 6 files changed, 1314 insertions(+), 3 deletions(-) create mode 100644 tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslatorTests.cs create mode 100644 tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslatorTests.cs diff --git a/src/MongoDB.Driver/Core/Misc/Feature.cs b/src/MongoDB.Driver/Core/Misc/Feature.cs index 47bf7bb65f9..ba779ca092d 100644 --- a/src/MongoDB.Driver/Core/Misc/Feature.cs +++ b/src/MongoDB.Driver/Core/Misc/Feature.cs @@ -83,6 +83,7 @@ public class Feature private static readonly Feature __loookupConciseSyntax = new Feature("LoookupConciseSyntax", WireVersion.Server50); private static readonly Feature __loookupDocuments= new Feature("LoookupDocuments", WireVersion.Server60); private static readonly Feature __mmapV1StorageEngine = new Feature("MmapV1StorageEngine", WireVersion.Zero, WireVersion.Server42); + private static readonly Feature __percentileOperator = new Feature("PercentileOperator", WireVersion.Server70); private static readonly Feature __pickAccumulatorsNewIn52 = new Feature("PickAccumulatorsNewIn52", WireVersion.Server52); private static readonly Feature __rankFusionStage = new Feature("RankFusionStage", WireVersion.Server81); private static readonly Feature __regexMatch = new Feature("RegexMatch", WireVersion.Server42); @@ -401,6 +402,11 @@ public class Feature [Obsolete("This feature was removed in server version 4.2. As such, this property will be removed in a later release.")] public static Feature MmapV1StorageEngine => __mmapV1StorageEngine; + /// + /// Gets the $percentile operator added in 7.0 + /// + public static Feature PercentileOperator => __percentileOperator; + /// /// Gets the pick accumulators new in 5.2 feature. /// diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs index e30deb4149e..4dbbe32fdd7 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs @@ -95,7 +95,6 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC case "Union": return UnionMethodToAggregationExpressionTranslator.Translate(context, expression); case "Zip": return ZipMethodToAggregationExpressionTranslator.Translate(context, expression); - case "Acos": case "Acosh": case "Asin": diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslatorTests.cs new file mode 100644 index 00000000000..1921a09bcb2 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslatorTests.cs @@ -0,0 +1,460 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using System.Linq; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; +using MongoDB.Driver.Core.Misc; +using MongoDB.Driver.Linq; +using MongoDB.Driver.TestHelpers; +using MongoDB.TestHelpers.XunitExtensions; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators +{ + public class MedianMethodToAggregationExpressionTranslatorTests : LinqIntegrationTest + { + public MedianMethodToAggregationExpressionTranslatorTests(ClassFixture fixture) + : base(fixture, server => server.Supports(Feature.PercentileOperator)) // median and percentile were added in the same server version + { + } + + [Theory] + [ParameterAttributeData] + public void Median_with_decimals_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.Decimals.AsQueryable().Median()) : + collection.AsQueryable().Select(x => x.Decimals.Median()); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $median : { input : '$Decimals', method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results.Should().Equal(1.0, 1.0, 2.0); + } + + [Theory] + [ParameterAttributeData] + public void Median_with_decimals_selector_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.Decimals.AsQueryable().Median(y => y * 2.0M)) : + collection.AsQueryable().Select(x => x.Decimals.Median(y => y * 2.0M)); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $median : { input : { $map : { input : '$Decimals', as : 'y', in : { $multiply : ['$$y', NumberDecimal(2)] } } }, method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results.Should().Equal(2.0, 2.0, 4.0); + } + + [Theory] + [ParameterAttributeData] + public void Median_with_doubles_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.Doubles.AsQueryable().Median()) : + collection.AsQueryable().Select(x => x.Doubles.Median()); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $median : { input : '$Doubles', method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results.Should().Equal(1.0, 1.0, 2.0); + } + + [Theory] + [ParameterAttributeData] + public void Median_with_doubles_selector_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.Doubles.AsQueryable().Median(y => y * 2.0)) : + collection.AsQueryable().Select(x => x.Doubles.Median(y => y * 2.0)); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $median : { input : { $map : { input : '$Doubles', as : 'y', in : { $multiply : ['$$y', 2.0] } } }, method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results.Should().Equal(2.0, 2.0, 4.0); + } + + [Theory] + [ParameterAttributeData] + public void Median_with_floats_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.Floats.AsQueryable().Median()) : + collection.AsQueryable().Select(x => x.Floats.Median()); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $median : { input : '$Floats', method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results.Should().Equal(1.0, 1.0, 2.0); + } + + [Theory] + [ParameterAttributeData] + public void Median_with_floats_selector_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.Floats.AsQueryable().Median(y => y * 2.0F)) : + collection.AsQueryable().Select(x => x.Floats.Median(y => y * 2.0F)); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $median : { input : { $map : { input : '$Floats', as : 'y', in : { $multiply : ['$$y', 2.0] } } }, method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results.Should().Equal(2.0, 2.0, 4.0); + } + + [Theory] + [ParameterAttributeData] + public void Median_with_ints_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.Ints.AsQueryable().Median()) : + collection.AsQueryable().Select(x => x.Ints.Median()); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $median : { input : '$Ints', method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results.Should().Equal(1, 1, 2); + } + + [Theory] + [ParameterAttributeData] + public void Median_with_ints_selector_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.Ints.AsQueryable().Median(y => y * 2)) : + collection.AsQueryable().Select(x => x.Ints.Median(y => y * 2)); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $median : { input : { $map : { input : '$Ints', as : 'y', in : { $multiply : ['$$y', 2] } } }, method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results.Should().Equal(2, 2, 4); + } + + [Theory] + [ParameterAttributeData] + public void Median_with_longs_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.Longs.AsQueryable().Median()) : + collection.AsQueryable().Select(x => x.Longs.Median()); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $median : { input : '$Longs', method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results.Should().Equal(1, 1, 2); + } + + [Theory] + [ParameterAttributeData] + public void Median_with_longs_selector_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.Longs.AsQueryable().Median(y => y * 2L)) : + collection.AsQueryable().Select(x => x.Longs.Median(y => y * 2L)); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $median : { input : { $map : { input : '$Longs', as : 'y', in : { $multiply : ['$$y', NumberLong(2)] } } }, method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results.Should().Equal(2, 2, 4); + } + + [Theory] + [ParameterAttributeData] + public void Median_with_nullable_decimals_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.NullableDecimals.AsQueryable().Median()) : + collection.AsQueryable().Select(x => x.NullableDecimals.Median()); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $median : { input : '$NullableDecimals', method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results.Should().Equal(null, null, 2.0); + } + + [Theory] + [ParameterAttributeData] + public void Median_with_nullable_decimals_selector_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.NullableDecimals.AsQueryable().Median(y => y * 2.0M)) : + collection.AsQueryable().Select(x => x.NullableDecimals.Median(y => y * 2.0M)); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $median : { input : { $map : { input : '$NullableDecimals', as : 'y', in : { $multiply : ['$$y', NumberDecimal(2)] } } }, method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results.Should().Equal(null, null, 4.0); + } + + [Theory] + [ParameterAttributeData] + public void Median_with_nullable_doubles_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.NullableDoubles.AsQueryable().Median()) : + collection.AsQueryable().Select(x => x.NullableDoubles.Median()); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $median : { input : '$NullableDoubles', method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results.Should().Equal(null, null, 2.0); + } + + [Theory] + [ParameterAttributeData] + public void Median_with_nullable_doubles_selector_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.NullableDoubles.AsQueryable().Median(y => y * 2.0)) : + collection.AsQueryable().Select(x => x.NullableDoubles.Median(y => y * 2.0)); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $median : { input : { $map : { input : '$NullableDoubles', as : 'y', in : { $multiply : ['$$y', 2.0] } } }, method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results.Should().Equal(null, null, 4.0); + } + + [Theory] + [ParameterAttributeData] + public void Median_with_nullable_floats_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.NullableFloats.AsQueryable().Median()) : + collection.AsQueryable().Select(x => x.NullableFloats.Median()); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $median : { input : '$NullableFloats', method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results.Should().Equal(null, null, 2.0); + } + + [Theory] + [ParameterAttributeData] + public void Median_with_nullable_floats_selector_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.NullableFloats.AsQueryable().Median(y => y * 2.0F)) : + collection.AsQueryable().Select(x => x.NullableFloats.Median(y => y * 2.0F)); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $median : { input : { $map : { input : '$NullableFloats', as : 'y', in : { $multiply : ['$$y', 2.0] } } }, method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results.Should().Equal(null, null, 4.0); + } + + [Theory] + [ParameterAttributeData] + public void Median_with_nullable_ints_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.NullableInts.AsQueryable().Median()) : + collection.AsQueryable().Select(x => x.NullableInts.Median()); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $median : { input : '$NullableInts', method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results.Should().Equal(null, null, 2); + } + + [Theory] + [ParameterAttributeData] + public void Median_with_nullable_ints_selector_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.NullableInts.AsQueryable().Median(y => y * 2)) : + collection.AsQueryable().Select(x => x.NullableInts.Median(y => y * 2)); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $median : { input : { $map : { input : '$NullableInts', as : 'y', in : { $multiply : ['$$y', 2] } } }, method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results.Should().Equal(null, null, 4); + } + + [Theory] + [ParameterAttributeData] + public void Median_with_nullable_longs_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.NullableLongs.AsQueryable().Median()) : + collection.AsQueryable().Select(x => x.NullableLongs.Median()); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $median : { input : '$NullableLongs', method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results.Should().Equal(null, null, 2); + } + + [Theory] + [ParameterAttributeData] + public void Median_with_nullable_longs_selector_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.NullableLongs.AsQueryable().Median(y => y * 2L)) : + collection.AsQueryable().Select(x => x.NullableLongs.Median(y => y * 2L)); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $median : { input : { $map : { input : '$NullableLongs', as : 'y', in : { $multiply : ['$$y', NumberLong(2)] } } }, method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results.Should().Equal(null, null, 4); + } + + public class C + { + public int Id { get; set; } + [BsonRepresentation(BsonType.Decimal128)] public decimal[] Decimals { get; set; } + public double[] Doubles { get; set; } + public float[] Floats { get; set; } + public int[] Ints { get; set; } + public long[] Longs { get; set; } + [BsonRepresentation(BsonType.Decimal128)] public decimal?[] NullableDecimals { get; set; } + public double?[] NullableDoubles { get; set; } + public float?[] NullableFloats { get; set; } + public int?[] NullableInts { get; set; } + public long?[] NullableLongs { get; set; } + } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => + [ + new() + { + Id = 1, + Decimals = [1.0M], + Doubles = [1.0], + Floats = [1.0F], + Ints = [1], + Longs = [1L], + NullableDecimals = [], + NullableDoubles = [], + NullableFloats = [], + NullableInts = [], + NullableLongs = [] + }, + new() + { + Id = 2, + Decimals = [1.0M, 2.0M], + Doubles = [1.0, 2.0], + Floats = [1.0F, 2.0F], + Ints = [1, 2], + Longs = [1L, 2L], + NullableDecimals = [null], + NullableDoubles = [null], + NullableFloats = [null], + NullableInts = [null], + NullableLongs = [null] + }, + new() + { + Id = 3, + Decimals = [1.0M, 2.0M, 3.0M], + Doubles = [1.0, 2.0, 3.0], + Floats = [1.0F, 2.0F, 3.0F], + Ints = [1, 2, 3], + Longs = [1L, 2L, 3L], + NullableDecimals = [null, 1.0M, 2.0M, 3.0M], + NullableDoubles = [null, 1.0, 2.0, 3.0], + NullableFloats = [null, 1.0F, 2.0F, 3.0F], + NullableInts = [null, 1, 2, 3], + NullableLongs = [null, 1L, 2L, 3L] + } + ]; + } + } +} \ No newline at end of file diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslatorTests.cs new file mode 100644 index 00000000000..bcc939d9f7f --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslatorTests.cs @@ -0,0 +1,377 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using System.Linq; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; +using MongoDB.Driver.Core.Misc; +using MongoDB.Driver.Linq; +using MongoDB.Driver.TestHelpers; +using MongoDB.TestHelpers.XunitExtensions; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators +{ + public class PercentileMethodToAggregationExpressionTranslatorTests : LinqIntegrationTest + { + public PercentileMethodToAggregationExpressionTranslatorTests(ClassFixture fixture) + : base(fixture, server => server.Supports(Feature.PercentileOperator)) + { + } + + [Theory] + [ParameterAttributeData] + public void Percentile_with_decimals_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.Decimals.AsQueryable().Percentile(new[] { 0.5 })) : + collection.AsQueryable().Select(x => x.Decimals.Percentile(new[] { 0.5 })); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $percentile : { input : '$Decimals', p : [0.5], method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results[0].Should().Equal(1.0); + results[1].Should().Equal(1.0); + results[2].Should().Equal(2.0); + } + + [Theory] + [ParameterAttributeData] + public void Percentile_with_decimals_multiple_percentiles_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.Decimals.AsQueryable().Percentile(new[] { 0.25, 0.75 })) : + collection.AsQueryable().Select(x => x.Decimals.Percentile(new[] { 0.25, 0.75 })); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $percentile : { input : '$Decimals', p : [0.25, 0.75], method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results[0].Should().Equal(1.0, 1.0); + results[1].Should().Equal(1.0, 2.0); + results[2].Should().Equal(1.0, 3.0); + } + + [Theory] + [ParameterAttributeData] + public void Percentile_with_decimals_selector_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.Decimals.AsQueryable().Percentile(y => y * 2.0M, new[] { 0.5 })) : + collection.AsQueryable().Select(x => x.Decimals.Percentile(y => y * 2.0M, new[] { 0.5 })); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $percentile : { input : { $map : { input : '$Decimals', as : 'y', in : { $multiply : ['$$y', NumberDecimal(2)] } } }, p : [0.5], method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results[0].Should().Equal(2.0); + results[1].Should().Equal(2.0); + results[2].Should().Equal(4.0); + } + + [Theory] + [ParameterAttributeData] + public void Percentile_with_doubles_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.Doubles.AsQueryable().Percentile(new[] { 0.5 })) : + collection.AsQueryable().Select(x => x.Doubles.Percentile(new[] { 0.5 })); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $percentile : { input : '$Doubles', p : [0.5], method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results[0].Should().Equal(1.0); + results[1].Should().Equal(1.0); + results[2].Should().Equal(2.0); + } + + [Theory] + [ParameterAttributeData] + public void Percentile_with_doubles_selector_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.Doubles.AsQueryable().Percentile(y => y * 2.0, new[] { 0.5 })) : + collection.AsQueryable().Select(x => x.Doubles.Percentile(y => y * 2.0, new[] { 0.5 })); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $percentile : { input : { $map : { input : '$Doubles', as : 'y', in : { $multiply : ['$$y', 2.0] } } }, p : [0.5], method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results[0].Should().Equal(2.0); + results[1].Should().Equal(2.0); + results[2].Should().Equal(4.0); + } + + [Theory] + [ParameterAttributeData] + public void Percentile_with_floats_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.Floats.AsQueryable().Percentile(new[] { 0.5 })) : + collection.AsQueryable().Select(x => x.Floats.Percentile(new[] { 0.5 })); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $percentile : { input : '$Floats', p : [0.5], method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results[0].Should().Equal(1.0); + results[1].Should().Equal(1.0); + results[2].Should().Equal(2.0); + } + + [Theory] + [ParameterAttributeData] + public void Percentile_with_floats_selector_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.Floats.AsQueryable().Percentile(y => y * 2.0F, new[] { 0.5 })) : + collection.AsQueryable().Select(x => x.Floats.Percentile(y => y * 2.0F, new[] { 0.5 })); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $percentile : { input : { $map : { input : '$Floats', as : 'y', in : { $multiply : ['$$y', 2.0] } } }, p : [0.5], method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results[0].Should().Equal(2.0); + results[1].Should().Equal(2.0); + results[2].Should().Equal(4.0); + } + + [Theory] + [ParameterAttributeData] + public void Percentile_with_ints_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.Ints.AsQueryable().Percentile(new[] { 0.5 })) : + collection.AsQueryable().Select(x => x.Ints.Percentile(new[] { 0.5 })); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $percentile : { input : '$Ints', p : [0.5], method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results[0].Should().Equal(1.0); + results[1].Should().Equal(1.0); + results[2].Should().Equal(2.0); + } + + [Theory] + [ParameterAttributeData] + public void Percentile_with_ints_selector_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.Ints.AsQueryable().Percentile(y => y * 2, new[] { 0.5 })) : + collection.AsQueryable().Select(x => x.Ints.Percentile(y => y * 2, new[] { 0.5 })); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $percentile : { input : { $map : { input : '$Ints', as : 'y', in : { $multiply : ['$$y', 2] } } }, p : [0.5], method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results[0].Should().Equal(2.0); + results[1].Should().Equal(2.0); + results[2].Should().Equal(4.0); + } + + [Theory] + [ParameterAttributeData] + public void Percentile_with_longs_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.Longs.AsQueryable().Percentile(new[] { 0.5 })) : + collection.AsQueryable().Select(x => x.Longs.Percentile(new[] { 0.5 })); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $percentile : { input : '$Longs', p : [0.5], method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results[0].Should().Equal(1.0); + results[1].Should().Equal(1.0); + results[2].Should().Equal(2.0); + } + + [Theory] + [ParameterAttributeData] + public void Percentile_with_longs_selector_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.Longs.AsQueryable().Percentile(y => y * 2L, new[] { 0.5 })) : + collection.AsQueryable().Select(x => x.Longs.Percentile(y => y * 2L, new[] { 0.5 })); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $percentile : { input : { $map : { input : '$Longs', as : 'y', in : { $multiply : ['$$y', NumberLong(2)] } } }, p : [0.5], method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results[0].Should().Equal(2.0); + results[1].Should().Equal(2.0); + results[2].Should().Equal(4.0); + } + + [Theory] + [ParameterAttributeData] + public void Percentile_with_nullable_decimals_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.NullableDecimals.AsQueryable().Percentile(new[] { 0.5 })) : + collection.AsQueryable().Select(x => x.NullableDecimals.Percentile(new[] { 0.5 })); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $percentile : { input : '$NullableDecimals', p : [0.5], method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results[0].Should().Equal((double?)null); + results[1].Should().Equal((double?)null); + results[2].Should().Equal(2.0); + } + + [Theory] + [ParameterAttributeData] + public void Percentile_with_nullable_decimals_selector_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.NullableDecimals.AsQueryable().Percentile(y => y * 2.0M, new[] { 0.5 })) : + collection.AsQueryable().Select(x => x.NullableDecimals.Percentile(y => y * 2.0M, new[] { 0.5 })); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $percentile : { input : { $map : { input : '$NullableDecimals', as : 'y', in : { $multiply : ['$$y', NumberDecimal(2)] } } }, p : [0.5], method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results[0].Should().Equal((double?)null); + results[1].Should().Equal((double?)null); + results[2].Should().Equal(4.0); + } + + [Fact] + public void Percentile_with_list_input_should_work() + { + var collection = Fixture.Collection; + var percentiles = new List { 0.25, 0.5, 0.75 }; + + var queryable = collection.AsQueryable().Select(x => x.Doubles.Percentile(percentiles)); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $percentile : { input : '$Doubles', p : [0.25, 0.5, 0.75], method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results[0].Should().Equal(1.0, 1.0, 1.0); + results[1].Should().Equal(1.0, 1.0, 2.0); + results[2].Should().Equal(1.0, 2.0, 3.0); + } + + public class C + { + public int Id { get; set; } + [BsonRepresentation(BsonType.Decimal128)] public decimal[] Decimals { get; set; } + public double[] Doubles { get; set; } + public float[] Floats { get; set; } + public int[] Ints { get; set; } + public long[] Longs { get; set; } + [BsonRepresentation(BsonType.Decimal128)] public decimal?[] NullableDecimals { get; set; } + public double?[] NullableDoubles { get; set; } + public float?[] NullableFloats { get; set; } + public int?[] NullableInts { get; set; } + public long?[] NullableLongs { get; set; } + } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => + [ + new() + { + Id = 1, + Decimals = [1.0M], + Doubles = [1.0], + Floats = [1.0F], + Ints = [1], + Longs = [1L], + NullableDecimals = [], + NullableDoubles = [], + NullableFloats = [], + NullableInts = [], + NullableLongs = [] + }, + new() + { + Id = 2, + Decimals = [1.0M, 2.0M], + Doubles = [1.0, 2.0], + Floats = [1.0F, 2.0F], + Ints = [1, 2], + Longs = [1L, 2L], + NullableDecimals = [null], + NullableDoubles = [null], + NullableFloats = [null], + NullableInts = [null], + NullableLongs = [null] + }, + new() + { + Id = 3, + Decimals = [1.0M, 2.0M, 3.0M], + Doubles = [1.0, 2.0, 3.0], + Floats = [1.0F, 2.0F, 3.0F], + Ints = [1, 2, 3], + Longs = [1L, 2L, 3L], + NullableDecimals = [null, 1.0M, 2.0M, 3.0M], + NullableDoubles = [null, 1.0, 2.0, 3.0], + NullableFloats = [null, 1.0F, 2.0F, 3.0F], + NullableInts = [null, 1, 2, 3], + NullableLongs = [null, 1L, 2L, 3L] + } + ]; + } + } +} \ No newline at end of file diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslatorTests.cs index 59cdd9c6539..ac5ffee705e 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslatorTests.cs @@ -1488,6 +1488,184 @@ public void Translate_should_return_expected_result_for_Max() } } + [Fact] + public void Translate_should_return_expected_result_for_Median_with_Decimal() + { + var collection = Fixture.Collection; + + var aggregate = collection.Aggregate() + .SetWindowFields(output: p => new { Result = p.Median(x => x.DecimalField, null) }); + + var stages = Translate(collection, aggregate); + var expectedStages = new[] { "{ $setWindowFields : { output : { Result : { $median : { input : '$DecimalField', method : 'approximate' } } } } }" }; + AssertStages(stages, expectedStages); + + var results = aggregate.ToList(); + foreach (var result in results) + { + result["Result"].AsDouble.Should().Be(2.0); + } + } + + [Fact] + public void Translate_should_return_expected_result_for_Median_with_Double() + { + var collection = Fixture.Collection; + + var aggregate = collection.Aggregate() + .SetWindowFields(output: p => new { Result = p.Median(x => x.DoubleField, null) }); + + var stages = Translate(collection, aggregate); + var expectedStages = new[] { "{ $setWindowFields : { output : { Result : { $median : { input : '$DoubleField', method : 'approximate' } } } } }" }; + AssertStages(stages, expectedStages); + + var results = aggregate.ToList(); + foreach (var result in results) + { + result["Result"].AsDouble.Should().Be(2.0); + } + } + + [Fact] + public void Translate_should_return_expected_result_for_Median_with_Int32() + { + var collection = Fixture.Collection; + + var aggregate = collection.Aggregate() + .SetWindowFields(output: p => new { Result = p.Median(x => x.Int32Field, null) }); + + var stages = Translate(collection, aggregate); + var expectedStages = new[] { "{ $setWindowFields : { output : { Result : { $median : { input : '$Int32Field', method : 'approximate' } } } } }" }; + AssertStages(stages, expectedStages); + + var results = aggregate.ToList(); + foreach (var result in results) + { + result["Result"].AsDouble.Should().Be(2.0); + } + } + + [Fact] + public void Translate_should_return_expected_result_for_Median_with_Int64() + { + var collection = Fixture.Collection; + + var aggregate = collection.Aggregate() + .SetWindowFields(output: p => new { Result = p.Median(x => x.Int64Field, null) }); + + var stages = Translate(collection, aggregate); + var expectedStages = new[] { "{ $setWindowFields : { output : { Result : { $median : { input : '$Int64Field', method : 'approximate' } } } } }" }; + AssertStages(stages, expectedStages); + + var results = aggregate.ToList(); + foreach (var result in results) + { + result["Result"].AsDouble.Should().Be(2.0); + } + } + + [Fact] + public void Translate_should_return_expected_result_for_Median_with_nullable_Decimal() + { + var collection = Fixture.Collection; + + var aggregate = collection.Aggregate() + .SetWindowFields(output: p => new { Result = p.Median(x => x.NullableDecimalField, null) }); + + var stages = Translate(collection, aggregate); + var expectedStages = new[] { "{ $setWindowFields : { output : { Result : { $median : { input : '$NullableDecimalField', method : 'approximate' } } } } }" }; + AssertStages(stages, expectedStages); + + var results = aggregate.ToList(); + foreach (var result in results) + { + result["Result"].AsDouble.Should().Be(1.0); + } + } + + [Fact] + public void Translate_should_return_expected_result_for_Median_with_nullable_Double() + { + var collection = Fixture.Collection; + + var aggregate = collection.Aggregate() + .SetWindowFields(output: p => new { Result = p.Median(x => x.NullableDoubleField, null) }); + + var stages = Translate(collection, aggregate); + var expectedStages = new[] { "{ $setWindowFields : { output : { Result : { $median : { input : '$NullableDoubleField', method : 'approximate' } } } } }" }; + AssertStages(stages, expectedStages); + + var results = aggregate.ToList(); + foreach (var result in results) + { + result["Result"].AsDouble.Should().Be(1.0); + } + } + + [Fact] + public void Translate_should_return_expected_result_for_Median_with_nullable_Int32() + { + var collection = Fixture.Collection; + + var aggregate = collection.Aggregate() + .SetWindowFields(output: p => new { Result = p.Median(x => x.NullableInt32Field, null) }); + + var stages = Translate(collection, aggregate); + var expectedStages = new[] { "{ $setWindowFields : { output : { Result : { $median : { input : '$NullableInt32Field', method : 'approximate' } } } } }" }; + AssertStages(stages, expectedStages); + + var results = aggregate.ToList(); + foreach (var result in results) + { + result["Result"].AsDouble.Should().Be(1.0); + } + } + + [Fact] + public void Translate_should_return_expected_result_for_Median_with_nullable_Int64() + { + var collection = Fixture.Collection; + + var aggregate = collection.Aggregate() + .SetWindowFields(output: p => new { Result = p.Median(x => x.NullableInt64Field, null) }); + + var stages = Translate(collection, aggregate); + var expectedStages = new[] { "{ $setWindowFields : { output : { Result : { $median : { input : '$NullableInt64Field', method : 'approximate' } } } } }" }; + AssertStages(stages, expectedStages); + + var results = aggregate.ToList(); + foreach (var result in results) + { + result["Result"].AsDouble.Should().Be(1.0); + } + } + + [Fact] + public void Translate_should_return_expected_result_for_Median_with_window() + { + var collection = Fixture.Collection; + + var aggregate = collection.Aggregate() + .SetWindowFields( + partitionBy: x => 1, + sortBy: Builders.Sort.Ascending(x => x.Id), + output: p => new { + Result = p.Median(x => x.Int32Field, DocumentsWindow.Create(-1, 1)) + }); + + var stages = Translate(collection, aggregate); + var expectedStages = new[] + { + "{ $setWindowFields : { partitionBy : 1, sortBy : { _id : 1 }, output : { Result : { $median : { input : '$Int32Field', method : 'approximate' }, window : { documents : [-1, 1] } } } } }" + }; + AssertStages(stages, expectedStages); + + var results = aggregate.ToList(); + results[0]["Result"].AsDouble.Should().Be(1.0); + results[1]["Result"].AsDouble.Should().Be(2.0); + results[2]["Result"].AsDouble.Should().Be(2.0); + } + [Fact] public void Translate_should_return_expected_result_for_Min() { @@ -1507,6 +1685,228 @@ public void Translate_should_return_expected_result_for_Min() } } + [Fact] + public void Translate_should_return_expected_result_for_Percentile_with_Decimal() + { + var collection = Fixture.Collection; + + var aggregate = collection.Aggregate() + .SetWindowFields(output: p => new { Result = p.Percentile(x => x.DecimalField, new[] { 0.5 }, null) }); + + var stages = Translate(collection, aggregate); + var expectedStages = new[] { "{ $setWindowFields : { output : { Result : { $percentile : { input : '$DecimalField', p : [0.5], method : 'approximate' } } } } }" }; + AssertStages(stages, expectedStages); + + var results = aggregate.ToList(); + foreach (var result in results) + { + result["Result"].AsBsonArray[0].AsDouble.Should().Be(2.0); + } + } + + [Fact] + public void Translate_should_return_expected_result_for_Percentile_with_Double() + { + var collection = Fixture.Collection; + + var aggregate = collection.Aggregate() + .SetWindowFields(output: p => new { Result = p.Percentile(x => x.DoubleField, new[] { 0.5 }, null) }); + + var stages = Translate(collection, aggregate); + var expectedStages = new[] { "{ $setWindowFields : { output : { Result : { $percentile : { input : '$DoubleField', p : [0.5], method : 'approximate' } } } } }" }; + AssertStages(stages, expectedStages); + + var results = aggregate.ToList(); + foreach (var result in results) + { + result["Result"].AsBsonArray[0].AsDouble.Should().Be(2.0); + } + } + + [Fact] + public void Translate_should_return_expected_result_for_Percentile_with_Int32() + { + var collection = Fixture.Collection; + + var aggregate = collection.Aggregate() + .SetWindowFields(output: p => new { Result = p.Percentile(x => x.Int32Field, new[] { 0.5 }, null) }); + + var stages = Translate(collection, aggregate); + var expectedStages = new[] { "{ $setWindowFields : { output : { Result : { $percentile : { input : '$Int32Field', p : [0.5], method : 'approximate' } } } } }" }; + AssertStages(stages, expectedStages); + + var results = aggregate.ToList(); + foreach (var result in results) + { + result["Result"].AsBsonArray[0].AsDouble.Should().Be(2.0); + } + } + + [Fact] + public void Translate_should_return_expected_result_for_Percentile_with_Int64() + { + var collection = Fixture.Collection; + + var aggregate = collection.Aggregate() + .SetWindowFields(output: p => new { Result = p.Percentile(x => x.Int64Field, new[] { 0.5 }, null) }); + + var stages = Translate(collection, aggregate); + var expectedStages = new[] { "{ $setWindowFields : { output : { Result : { $percentile : { input : '$Int64Field', p : [0.5], method : 'approximate' } } } } }" }; + AssertStages(stages, expectedStages); + + var results = aggregate.ToList(); + foreach (var result in results) + { + result["Result"].AsBsonArray[0].AsDouble.Should().Be(2.0); + } + } + + [Fact] + public void Translate_should_return_expected_result_for_Percentile_with_nullable_Decimal() + { + var collection = Fixture.Collection; + + var aggregate = collection.Aggregate() + .SetWindowFields(output: p => new { Result = p.Percentile(x => x.NullableDecimalField, new[] { 0.5 }, null) }); + + var stages = Translate(collection, aggregate); + var expectedStages = new[] { "{ $setWindowFields : { output : { Result : { $percentile : { input : '$NullableDecimalField', p : [0.5], method : 'approximate' } } } } }" }; + AssertStages(stages, expectedStages); + + var results = aggregate.ToList(); + foreach (var result in results) + { + result["Result"].AsBsonArray[0].AsDouble.Should().Be(1.0); + } + } + + [Fact] + public void Translate_should_return_expected_result_for_Percentile_with_nullable_Double() + { + var collection = Fixture.Collection; + + var aggregate = collection.Aggregate() + .SetWindowFields(output: p => new { Result = p.Percentile(x => x.NullableDoubleField, new[] { 0.5 }, null) }); + + var stages = Translate(collection, aggregate); + var expectedStages = new[] { "{ $setWindowFields : { output : { Result : { $percentile : { input : '$NullableDoubleField', p : [0.5], method : 'approximate' } } } } }" }; + AssertStages(stages, expectedStages); + + var results = aggregate.ToList(); + foreach (var result in results) + { + result["Result"].AsBsonArray[0].AsDouble.Should().Be(1.0); + } + } + + [Fact] + public void Translate_should_return_expected_result_for_Percentile_with_nullable_Int32() + { + var collection = Fixture.Collection; + + var aggregate = collection.Aggregate() + .SetWindowFields(output: p => new { Result = p.Percentile(x => x.NullableInt32Field, new[] { 0.5 }, null) }); + + var stages = Translate(collection, aggregate); + var expectedStages = new[] { "{ $setWindowFields : { output : { Result : { $percentile : { input : '$NullableInt32Field', p : [0.5], method : 'approximate' } } } } }" }; + AssertStages(stages, expectedStages); + + var results = aggregate.ToList(); + foreach (var result in results) + { + result["Result"].AsBsonArray[0].AsDouble.Should().Be(1.0); + } + } + + [Fact] + public void Translate_should_return_expected_result_for_Percentile_with_nullable_Int64() + { + var collection = Fixture.Collection; + + var aggregate = collection.Aggregate() + .SetWindowFields(output: p => new { Result = p.Percentile(x => x.NullableInt64Field, new[] { 0.5 }, null) }); + + var stages = Translate(collection, aggregate); + var expectedStages = new[] { "{ $setWindowFields : { output : { Result : { $percentile : { input : '$NullableInt64Field', p : [0.5], method : 'approximate' } } } } }" }; + AssertStages(stages, expectedStages); + + var results = aggregate.ToList(); + foreach (var result in results) + { + result["Result"].AsBsonArray[0].AsDouble.Should().Be(1.0); + } + } + + [Fact] + public void Translate_should_return_expected_result_for_Percentile_with_multiple_percentiles() + { + var collection = Fixture.Collection; + + var aggregate = collection.Aggregate() + .SetWindowFields(output: p => new { Result = p.Percentile(x => x.Int32Field, new[] { 0.25, 0.75 }, null) }); + + var stages = Translate(collection, aggregate); + var expectedStages = new[] { "{ $setWindowFields : { output : { Result : { $percentile : { input : '$Int32Field', p : [0.25, 0.75], method : 'approximate' } } } } }" }; + AssertStages(stages, expectedStages); + + var results = aggregate.ToList(); + foreach (var result in results) + { + var array = result["Result"].AsBsonArray; + array[0].AsDouble.Should().Be(1.0); + array[1].AsDouble.Should().Be(3.0); + } + } + + [Fact] + public void Translate_should_return_expected_result_for_Percentile_with_window() + { + var collection = Fixture.Collection; + + var aggregate = collection.Aggregate() + .SetWindowFields( + partitionBy: x => 1, + sortBy: Builders.Sort.Ascending(x => x.Id), + output: p => new { + Result = p.Percentile(x => x.Int32Field, new[] { 0.5 }, DocumentsWindow.Create(-1, 1)) + }); + + var stages = Translate(collection, aggregate); + var expectedStages = new[] + { + "{ $setWindowFields : { partitionBy : 1, sortBy : { _id : 1 }, output : { Result : { $percentile : { input : '$Int32Field', p : [0.5], method : 'approximate' }, window : { documents : [-1, 1] } } } } }" + }; + AssertStages(stages, expectedStages); + + var results = aggregate.ToList(); + results[0]["Result"].AsBsonArray[0].AsDouble.Should().Be(1.0); + results[1]["Result"].AsBsonArray[0].AsDouble.Should().Be(2.0); + results[2]["Result"].AsBsonArray[0].AsDouble.Should().Be(2.0); + } + + [Fact] + public void Translate_should_return_expected_result_for_Percentile_with_List_input() + { + var collection = Fixture.Collection; + var percentiles = new List { 0.25, 0.5, 0.75 }; + + var aggregate = collection.Aggregate() + .SetWindowFields(output: p => new { Result = p.Percentile(x => x.Int32Field, percentiles, null) }); + + var stages = Translate(collection, aggregate); + var expectedStages = new[] { "{ $setWindowFields : { output : { Result : { $percentile : { input : '$Int32Field', p : [0.25, 0.5, 0.75], method : 'approximate' } } } } }" }; + AssertStages(stages, expectedStages); + + var results = aggregate.ToList(); + foreach (var result in results) + { + var array = result["Result"].AsBsonArray; + array[0].AsDouble.Should().Be(1.0); + array[1].AsDouble.Should().Be(2.0); + array[2].AsDouble.Should().Be(3.0); + } + } + [Fact] public void Translate_should_return_expected_result_for_Push() { diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs index 45bbf7067af..f4e0979708b 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs @@ -20,6 +20,7 @@ using FluentAssertions; using MongoDB.Bson; using MongoDB.Bson.Serialization; +using MongoDB.Driver.Core.Misc; using MongoDB.Driver.Core.TestHelpers.XunitExtensions; using MongoDB.Driver.Linq; using MongoDB.Driver.Linq.Linq3Implementation.Ast; @@ -164,6 +165,66 @@ public void Should_translate_count_with_a_predicate() result.Value.Result.Should().Be(1); } + [Fact] + public void Should_translate_median_with_embedded_projector() + { + RequireServer.Check().Supports(Feature.PercentileOperator); + + var result = Group(x => x.A, g => new { Result = g.Median(x=> x.C.E.F) }); + + AssertStages( + result.Stages, + "{ $group : { _id : '$A', __agg0 : { $median : { input : '$C.E.F', method : 'approximate' } } } }", + "{ $project : { Result : '$__agg0', _id : 0 } }"); + + result.Value.Result.Should().Be(111); + } + + [Fact] + public void Should_translate_median_with_selected_projector() + { + RequireServer.Check().Supports(Feature.PercentileOperator); + + var result = Group(x => x.A, g => new { Result = g.Select(x => x.C.E.F).Median() }); + + AssertStages( + result.Stages, + "{ $group : { _id : '$A', __agg0 : { $median : { input : '$C.E.F', method : 'approximate' } } } }", + "{ $project : { Result : '$__agg0', _id : 0 } }"); + + result.Value.Result.Should().Be(111); + } + + [Fact] + public void Should_translate_percentile_with_embedded_projector() + { + RequireServer.Check().Supports(Feature.PercentileOperator); + + var result = Group(x => x.A, g => new { Result = g.Percentile(x => x.C.E.F, new[] { 0.5 }) }); + + AssertStages( + result.Stages, + "{ $group : { _id : '$A', __agg0 : { $percentile : { input : '$C.E.F', p : [0.5], method : 'approximate' } } } }", + "{ $project : { Result : '$__agg0', _id : 0 } }"); + + result.Value.Result.Should().Equal(111.0); + } + + [Fact] + public void Should_translate_percentile_with_selected_projector() + { + RequireServer.Check().Supports(Feature.PercentileOperator); + + var result = Group(x => x.A, g => new { Result = g.Select(x => x.C.E.F).Percentile(new[] { 0.5 }) }); + + AssertStages( + result.Stages, + "{ $group : { _id : '$A', __agg0 : { $percentile : { input : '$C.E.F', p : [0.5], method : 'approximate' } } } }", + "{ $project : { Result : '$__agg0', _id : 0 } }"); + + result.Value.Result.Should().Equal(111.0); + } + [Fact] public void Should_translate_where_with_a_predicate_and_count() { @@ -481,7 +542,9 @@ public void Should_translate_complex_selector() First = g.First().B, Last = g.Last().K, Min = g.Min(x => x.C.E.F + x.C.E.H), - Max = g.Max(x => x.C.E.F + x.C.E.H) + Max = g.Max(x => x.C.E.F + x.C.E.H), + Median = g.Median(x => x.C.E.F + x.C.E.H), + Percentile = g.Percentile(x => x.C.E.F + x.C.E.H, new[] { 0.95 }) }); AssertStages( @@ -495,7 +558,9 @@ public void Should_translate_complex_selector() __agg2 : { $first : '$B' }, __agg3 : { $last : '$K' }, __agg4 : { $min : { $add : ['$C.E.F', '$C.E.H'] } }, - __agg5 : { $max : { $add : ['$C.E.F', '$C.E.H'] } } + __agg5 : { $max : { $add : ['$C.E.F', '$C.E.H'] } }, + __agg6 : { $median : { input : { $add : ['$C.E.F', '$C.E.H'] }, method : 'approximate' } }, + __agg7 : { $percentile : { input : { $add : ['$C.E.F', '$C.E.H'] }, p : [0.95], method : 'approximate' } } } }", @" @@ -507,6 +572,8 @@ public void Should_translate_complex_selector() Last : '$__agg3', Min : '$__agg4', Max : '$__agg5', + Median : '$__agg6', + Percentile : '$__agg7', _id : 0 } }"); @@ -517,6 +584,8 @@ public void Should_translate_complex_selector() result.Value.Last.Should().Be(false); result.Value.Min.Should().Be(333); result.Value.Max.Should().Be(333); + result.Value.Median.Should().Be(333); + result.Value.Percentile.Should().Equal(333); } [Fact] From 4bc77a6edeadef52dd18ae1b5679a93824312912 Mon Sep 17 00:00:00 2001 From: adelinowona Date: Tue, 29 Jul 2025 11:18:53 -0400 Subject: [PATCH 07/15] fix some tests failures --- ...dToAggregationExpressionTranslatorTests.cs | 40 +++++++++++++++++ .../AggregateGroupTranslatorTests.cs | 44 +++++++++++++++---- 2 files changed, 75 insertions(+), 9 deletions(-) diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslatorTests.cs index ac5ffee705e..31a511b5a67 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslatorTests.cs @@ -1491,6 +1491,8 @@ public void Translate_should_return_expected_result_for_Max() [Fact] public void Translate_should_return_expected_result_for_Median_with_Decimal() { + RequireServer.Check().Supports(Feature.PercentileOperator); + var collection = Fixture.Collection; var aggregate = collection.Aggregate() @@ -1510,6 +1512,8 @@ public void Translate_should_return_expected_result_for_Median_with_Decimal() [Fact] public void Translate_should_return_expected_result_for_Median_with_Double() { + RequireServer.Check().Supports(Feature.PercentileOperator); + var collection = Fixture.Collection; var aggregate = collection.Aggregate() @@ -1529,6 +1533,8 @@ public void Translate_should_return_expected_result_for_Median_with_Double() [Fact] public void Translate_should_return_expected_result_for_Median_with_Int32() { + RequireServer.Check().Supports(Feature.PercentileOperator); + var collection = Fixture.Collection; var aggregate = collection.Aggregate() @@ -1548,6 +1554,8 @@ public void Translate_should_return_expected_result_for_Median_with_Int32() [Fact] public void Translate_should_return_expected_result_for_Median_with_Int64() { + RequireServer.Check().Supports(Feature.PercentileOperator); + var collection = Fixture.Collection; var aggregate = collection.Aggregate() @@ -1567,6 +1575,8 @@ public void Translate_should_return_expected_result_for_Median_with_Int64() [Fact] public void Translate_should_return_expected_result_for_Median_with_nullable_Decimal() { + RequireServer.Check().Supports(Feature.PercentileOperator); + var collection = Fixture.Collection; var aggregate = collection.Aggregate() @@ -1586,6 +1596,8 @@ public void Translate_should_return_expected_result_for_Median_with_nullable_Dec [Fact] public void Translate_should_return_expected_result_for_Median_with_nullable_Double() { + RequireServer.Check().Supports(Feature.PercentileOperator); + var collection = Fixture.Collection; var aggregate = collection.Aggregate() @@ -1605,6 +1617,8 @@ public void Translate_should_return_expected_result_for_Median_with_nullable_Dou [Fact] public void Translate_should_return_expected_result_for_Median_with_nullable_Int32() { + RequireServer.Check().Supports(Feature.PercentileOperator); + var collection = Fixture.Collection; var aggregate = collection.Aggregate() @@ -1624,6 +1638,8 @@ public void Translate_should_return_expected_result_for_Median_with_nullable_Int [Fact] public void Translate_should_return_expected_result_for_Median_with_nullable_Int64() { + RequireServer.Check().Supports(Feature.PercentileOperator); + var collection = Fixture.Collection; var aggregate = collection.Aggregate() @@ -1643,6 +1659,8 @@ public void Translate_should_return_expected_result_for_Median_with_nullable_Int [Fact] public void Translate_should_return_expected_result_for_Median_with_window() { + RequireServer.Check().Supports(Feature.PercentileOperator); + var collection = Fixture.Collection; var aggregate = collection.Aggregate() @@ -1688,6 +1706,8 @@ public void Translate_should_return_expected_result_for_Min() [Fact] public void Translate_should_return_expected_result_for_Percentile_with_Decimal() { + RequireServer.Check().Supports(Feature.PercentileOperator); + var collection = Fixture.Collection; var aggregate = collection.Aggregate() @@ -1707,6 +1727,8 @@ public void Translate_should_return_expected_result_for_Percentile_with_Decimal( [Fact] public void Translate_should_return_expected_result_for_Percentile_with_Double() { + RequireServer.Check().Supports(Feature.PercentileOperator); + var collection = Fixture.Collection; var aggregate = collection.Aggregate() @@ -1726,6 +1748,8 @@ public void Translate_should_return_expected_result_for_Percentile_with_Double() [Fact] public void Translate_should_return_expected_result_for_Percentile_with_Int32() { + RequireServer.Check().Supports(Feature.PercentileOperator); + var collection = Fixture.Collection; var aggregate = collection.Aggregate() @@ -1745,6 +1769,8 @@ public void Translate_should_return_expected_result_for_Percentile_with_Int32() [Fact] public void Translate_should_return_expected_result_for_Percentile_with_Int64() { + RequireServer.Check().Supports(Feature.PercentileOperator); + var collection = Fixture.Collection; var aggregate = collection.Aggregate() @@ -1764,6 +1790,8 @@ public void Translate_should_return_expected_result_for_Percentile_with_Int64() [Fact] public void Translate_should_return_expected_result_for_Percentile_with_nullable_Decimal() { + RequireServer.Check().Supports(Feature.PercentileOperator); + var collection = Fixture.Collection; var aggregate = collection.Aggregate() @@ -1783,6 +1811,8 @@ public void Translate_should_return_expected_result_for_Percentile_with_nullable [Fact] public void Translate_should_return_expected_result_for_Percentile_with_nullable_Double() { + RequireServer.Check().Supports(Feature.PercentileOperator); + var collection = Fixture.Collection; var aggregate = collection.Aggregate() @@ -1802,6 +1832,8 @@ public void Translate_should_return_expected_result_for_Percentile_with_nullable [Fact] public void Translate_should_return_expected_result_for_Percentile_with_nullable_Int32() { + RequireServer.Check().Supports(Feature.PercentileOperator); + var collection = Fixture.Collection; var aggregate = collection.Aggregate() @@ -1821,6 +1853,8 @@ public void Translate_should_return_expected_result_for_Percentile_with_nullable [Fact] public void Translate_should_return_expected_result_for_Percentile_with_nullable_Int64() { + RequireServer.Check().Supports(Feature.PercentileOperator); + var collection = Fixture.Collection; var aggregate = collection.Aggregate() @@ -1840,6 +1874,8 @@ public void Translate_should_return_expected_result_for_Percentile_with_nullable [Fact] public void Translate_should_return_expected_result_for_Percentile_with_multiple_percentiles() { + RequireServer.Check().Supports(Feature.PercentileOperator); + var collection = Fixture.Collection; var aggregate = collection.Aggregate() @@ -1861,6 +1897,8 @@ public void Translate_should_return_expected_result_for_Percentile_with_multiple [Fact] public void Translate_should_return_expected_result_for_Percentile_with_window() { + RequireServer.Check().Supports(Feature.PercentileOperator); + var collection = Fixture.Collection; var aggregate = collection.Aggregate() @@ -1887,6 +1925,8 @@ public void Translate_should_return_expected_result_for_Percentile_with_window() [Fact] public void Translate_should_return_expected_result_for_Percentile_with_List_input() { + RequireServer.Check().Supports(Feature.PercentileOperator); + var collection = Fixture.Collection; var percentiles = new List { 0.25, 0.5, 0.75 }; diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs index f4e0979708b..317b9068ef1 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs @@ -542,9 +542,7 @@ public void Should_translate_complex_selector() First = g.First().B, Last = g.Last().K, Min = g.Min(x => x.C.E.F + x.C.E.H), - Max = g.Max(x => x.C.E.F + x.C.E.H), - Median = g.Median(x => x.C.E.F + x.C.E.H), - Percentile = g.Percentile(x => x.C.E.F + x.C.E.H, new[] { 0.95 }) + Max = g.Max(x => x.C.E.F + x.C.E.H) }); AssertStages( @@ -558,9 +556,7 @@ public void Should_translate_complex_selector() __agg2 : { $first : '$B' }, __agg3 : { $last : '$K' }, __agg4 : { $min : { $add : ['$C.E.F', '$C.E.H'] } }, - __agg5 : { $max : { $add : ['$C.E.F', '$C.E.H'] } }, - __agg6 : { $median : { input : { $add : ['$C.E.F', '$C.E.H'] }, method : 'approximate' } }, - __agg7 : { $percentile : { input : { $add : ['$C.E.F', '$C.E.H'] }, p : [0.95], method : 'approximate' } } + __agg5 : { $max : { $add : ['$C.E.F', '$C.E.H'] } } } }", @" @@ -571,9 +567,7 @@ public void Should_translate_complex_selector() First : '$__agg2', Last : '$__agg3', Min : '$__agg4', - Max : '$__agg5', - Median : '$__agg6', - Percentile : '$__agg7', + Max : '$__agg5' _id : 0 } }"); @@ -584,6 +578,38 @@ public void Should_translate_complex_selector() result.Value.Last.Should().Be(false); result.Value.Min.Should().Be(333); result.Value.Max.Should().Be(333); + } + + [Fact] + public void Should_translate_complex_selector_with_quantile_methods() + { + RequireServer.Check().Supports(Feature.PercentileOperator); + + var result = Group(x => x.A, g => new + { + Median = g.Median(x => x.C.E.F + x.C.E.H), + Percentile = g.Percentile(x => x.C.E.F + x.C.E.H, new[] { 0.95 }) + }); + + AssertStages( + result.Stages, + @" + { + $group : { + _id : '$A', + __agg0 : { $median : { input : { $add : ['$C.E.F', '$C.E.H'] }, method : 'approximate' } }, + __agg1 : { $percentile : { input : { $add : ['$C.E.F', '$C.E.H'] }, p : [0.95], method : 'approximate' } } + } + }", + @" + { + $project : { + Median : '$__agg0', + Percentile : '$__agg1', + _id : 0 + } + }"); + result.Value.Median.Should().Be(333); result.Value.Percentile.Should().Equal(333); } From 534f6504a8d5adf16525bf6a67a0847059507d32 Mon Sep 17 00:00:00 2001 From: adelinowona Date: Tue, 29 Jul 2025 16:48:24 -0400 Subject: [PATCH 08/15] rearrange the methods in MongoEnumerable.cs --- src/MongoDB.Driver/Linq/MongoEnumerable.cs | 200 ++++++++++----------- 1 file changed, 100 insertions(+), 100 deletions(-) diff --git a/src/MongoDB.Driver/Linq/MongoEnumerable.cs b/src/MongoDB.Driver/Linq/MongoEnumerable.cs index f3ba9c7d4c5..f54586bff8c 100644 --- a/src/MongoDB.Driver/Linq/MongoEnumerable.cs +++ b/src/MongoDB.Driver/Linq/MongoEnumerable.cs @@ -239,11 +239,9 @@ public static IEnumerable MaxN( /// /// Computes the median of a sequence of values. /// - /// The type of the elements in the source sequence. - /// A sequence of values to calculate the median of. - /// A transform function to apply to each element. + /// The sequence of values. /// The median value. - public static double Median(this IEnumerable source, Func selector) + public static double Median(this IEnumerable source) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -251,11 +249,9 @@ public static double Median(this IEnumerable source, Func /// Computes the median of a sequence of values. /// - /// The type of the elements in the source sequence. - /// A sequence of values to calculate the median of. - /// A transform function to apply to each element. + /// The sequence of values. /// The median value. - public static double? Median(this IEnumerable source, Func selector) + public static double? Median(this IEnumerable source) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -263,11 +259,9 @@ public static double Median(this IEnumerable source, Func /// Computes the median of a sequence of values. /// - /// The type of the elements in the source sequence. - /// A sequence of values to calculate the median of. - /// A transform function to apply to each element. + /// The sequence of values. /// The median value. - public static double Median(this IEnumerable source, Func selector) + public static double Median(this IEnumerable source) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -275,11 +269,9 @@ public static double Median(this IEnumerable source, Func /// Computes the median of a sequence of values. /// - /// The type of the elements in the source sequence. - /// A sequence of values to calculate the median of. - /// A transform function to apply to each element. + /// The sequence of values. /// The median value. - public static double? Median(this IEnumerable source, Func selector) + public static double? Median(this IEnumerable source) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -287,11 +279,9 @@ public static double Median(this IEnumerable source, Func /// Computes the median of a sequence of values. /// - /// The type of the elements in the source sequence. - /// A sequence of values to calculate the median of. - /// A transform function to apply to each element. + /// The sequence of values. /// The median value. - public static double Median(this IEnumerable source, Func selector) + public static double Median(this IEnumerable source) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -299,11 +289,9 @@ public static double Median(this IEnumerable source, Func /// Computes the median of a sequence of values. /// - /// The type of the elements in the source sequence. - /// A sequence of values to calculate the median of. - /// A transform function to apply to each element. + /// The sequence of values. /// The median value. - public static double? Median(this IEnumerable source, Func selector) + public static double? Median(this IEnumerable source) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -311,11 +299,9 @@ public static double Median(this IEnumerable source, Func /// Computes the median of a sequence of values. /// - /// The type of the elements in the source sequence. - /// A sequence of values to calculate the median of. - /// A transform function to apply to each element. + /// The sequence of values. /// The median value. - public static double Median(this IEnumerable source, Func selector) + public static double Median(this IEnumerable source) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -323,11 +309,9 @@ public static double Median(this IEnumerable source, Func /// Computes the median of a sequence of values. /// - /// The type of the elements in the source sequence. - /// A sequence of values to calculate the median of. - /// A transform function to apply to each element. + /// The sequence of values. /// The median value. - public static double? Median(this IEnumerable source, Func selector) + public static double? Median(this IEnumerable source) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -335,11 +319,9 @@ public static double Median(this IEnumerable source, Func /// Computes the median of a sequence of values. /// - /// The type of the elements in the source sequence. - /// A sequence of values to calculate the median of. - /// A transform function to apply to each element. + /// The sequence of values. /// The median value. - public static double Median(this IEnumerable source, Func selector) + public static double Median(this IEnumerable source) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -347,11 +329,9 @@ public static double Median(this IEnumerable source, Func /// Computes the median of a sequence of values. /// - /// The type of the elements in the source sequence. - /// A sequence of values to calculate the median of. - /// A transform function to apply to each element. + /// The sequence of values. /// The median value. - public static double? Median(this IEnumerable source, Func selector) + public static double? Median(this IEnumerable source) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -359,9 +339,11 @@ public static double Median(this IEnumerable source, Func /// Computes the median of a sequence of values. /// - /// The sequence of values. + /// The type of the elements in the source sequence. + /// A sequence of values to calculate the median of. + /// A transform function to apply to each element. /// The median value. - public static double Median(this IEnumerable source) + public static double Median(this IEnumerable source, Func selector) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -369,9 +351,11 @@ public static double Median(this IEnumerable source) /// /// Computes the median of a sequence of values. /// - /// The sequence of values. + /// The type of the elements in the source sequence. + /// A sequence of values to calculate the median of. + /// A transform function to apply to each element. /// The median value. - public static double? Median(this IEnumerable source) + public static double? Median(this IEnumerable source, Func selector) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -379,9 +363,11 @@ public static double Median(this IEnumerable source) /// /// Computes the median of a sequence of values. /// - /// The sequence of values. + /// The type of the elements in the source sequence. + /// A sequence of values to calculate the median of. + /// A transform function to apply to each element. /// The median value. - public static double Median(this IEnumerable source) + public static double Median(this IEnumerable source, Func selector) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -389,9 +375,11 @@ public static double Median(this IEnumerable source) /// /// Computes the median of a sequence of values. /// - /// The sequence of values. + /// The type of the elements in the source sequence. + /// A sequence of values to calculate the median of. + /// A transform function to apply to each element. /// The median value. - public static double? Median(this IEnumerable source) + public static double? Median(this IEnumerable source, Func selector) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -399,9 +387,11 @@ public static double Median(this IEnumerable source) /// /// Computes the median of a sequence of values. /// - /// The sequence of values. + /// The type of the elements in the source sequence. + /// A sequence of values to calculate the median of. + /// A transform function to apply to each element. /// The median value. - public static double Median(this IEnumerable source) + public static double Median(this IEnumerable source, Func selector) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -409,9 +399,11 @@ public static double Median(this IEnumerable source) /// /// Computes the median of a sequence of values. /// - /// The sequence of values. + /// The type of the elements in the source sequence. + /// A sequence of values to calculate the median of. + /// A transform function to apply to each element. /// The median value. - public static double? Median(this IEnumerable source) + public static double? Median(this IEnumerable source, Func selector) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -419,9 +411,11 @@ public static double Median(this IEnumerable source) /// /// Computes the median of a sequence of values. /// - /// The sequence of values. + /// The type of the elements in the source sequence. + /// A sequence of values to calculate the median of. + /// A transform function to apply to each element. /// The median value. - public static double Median(this IEnumerable source) + public static double Median(this IEnumerable source, Func selector) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -429,9 +423,11 @@ public static double Median(this IEnumerable source) /// /// Computes the median of a sequence of values. /// - /// The sequence of values. + /// The type of the elements in the source sequence. + /// A sequence of values to calculate the median of. + /// A transform function to apply to each element. /// The median value. - public static double? Median(this IEnumerable source) + public static double? Median(this IEnumerable source, Func selector) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -439,9 +435,11 @@ public static double Median(this IEnumerable source) /// /// Computes the median of a sequence of values. /// - /// The sequence of values. + /// The type of the elements in the source sequence. + /// A sequence of values to calculate the median of. + /// A transform function to apply to each element. /// The median value. - public static double Median(this IEnumerable source) + public static double Median(this IEnumerable source, Func selector) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -449,9 +447,11 @@ public static double Median(this IEnumerable source) /// /// Computes the median of a sequence of values. /// - /// The sequence of values. + /// The type of the elements in the source sequence. + /// A sequence of values to calculate the median of. + /// A transform function to apply to each element. /// The median value. - public static double? Median(this IEnumerable source) + public static double? Median(this IEnumerable source, Func selector) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -496,12 +496,10 @@ public static IEnumerable MinN( /// /// Computes multiple percentiles of a sequence of values. /// - /// The type of the elements of . /// A sequence of values to calculate the percentiles of. - /// A transform function to apply to each element. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) + public static double[] Percentile(this IEnumerable source, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -509,12 +507,10 @@ public static double[] Percentile(this IEnumerable source, Fun /// /// Computes multiple percentiles of a sequence of values. /// - /// The type of the elements of . /// A sequence of values to calculate the percentiles of. - /// A transform function to apply to each element. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double?[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) + public static double?[] Percentile(this IEnumerable source, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -522,12 +518,10 @@ public static double[] Percentile(this IEnumerable source, Fun /// /// Computes multiple percentiles of a sequence of values. /// - /// The type of the elements of . /// A sequence of values to calculate the percentiles of. - /// A transform function to apply to each element. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) + public static double[] Percentile(this IEnumerable source, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -535,12 +529,10 @@ public static double[] Percentile(this IEnumerable source, Fun /// /// Computes multiple percentiles of a sequence of values. /// - /// The type of the elements of . /// A sequence of values to calculate the percentiles of. - /// A transform function to apply to each element. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double?[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) + public static double?[] Percentile(this IEnumerable source, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -548,12 +540,10 @@ public static double[] Percentile(this IEnumerable source, Fun /// /// Computes multiple percentiles of a sequence of values. /// - /// The type of the elements of . /// A sequence of values to calculate the percentiles of. - /// A transform function to apply to each element. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) + public static double[] Percentile(this IEnumerable source, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -561,12 +551,10 @@ public static double[] Percentile(this IEnumerable source, Fun /// /// Computes multiple percentiles of a sequence of values. /// - /// The type of the elements of . /// A sequence of values to calculate the percentiles of. - /// A transform function to apply to each element. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double?[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) + public static double?[] Percentile(this IEnumerable source, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -574,12 +562,10 @@ public static double[] Percentile(this IEnumerable source, Fun /// /// Computes multiple percentiles of a sequence of values. /// - /// The type of the elements of . /// A sequence of values to calculate the percentiles of. - /// A transform function to apply to each element. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) + public static double[] Percentile(this IEnumerable source, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -587,12 +573,10 @@ public static double[] Percentile(this IEnumerable source, Fun /// /// Computes multiple percentiles of a sequence of values. /// - /// The type of the elements of . /// A sequence of values to calculate the percentiles of. - /// A transform function to apply to each element. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double?[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) + public static double?[] Percentile(this IEnumerable source, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -600,12 +584,10 @@ public static double[] Percentile(this IEnumerable source, Fun /// /// Computes multiple percentiles of a sequence of values. /// - /// The type of the elements of . /// A sequence of values to calculate the percentiles of. - /// A transform function to apply to each element. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) + public static double[] Percentile(this IEnumerable source, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -613,12 +595,10 @@ public static double[] Percentile(this IEnumerable source, Fun /// /// Computes multiple percentiles of a sequence of values. /// - /// The type of the elements of . /// A sequence of values to calculate the percentiles of. - /// A transform function to apply to each element. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double?[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) + public static double?[] Percentile(this IEnumerable source, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -626,10 +606,12 @@ public static double[] Percentile(this IEnumerable source, Fun /// /// Computes multiple percentiles of a sequence of values. /// + /// The type of the elements of . /// A sequence of values to calculate the percentiles of. + /// A transform function to apply to each element. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double[] Percentile(this IEnumerable source, IEnumerable percentiles) + public static double[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -637,10 +619,12 @@ public static double[] Percentile(this IEnumerable source, IEnumerable /// Computes multiple percentiles of a sequence of values. /// + /// The type of the elements of . /// A sequence of values to calculate the percentiles of. + /// A transform function to apply to each element. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double?[] Percentile(this IEnumerable source, IEnumerable percentiles) + public static double?[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -648,10 +632,12 @@ public static double[] Percentile(this IEnumerable source, IEnumerable /// Computes multiple percentiles of a sequence of values. /// + /// The type of the elements of . /// A sequence of values to calculate the percentiles of. + /// A transform function to apply to each element. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double[] Percentile(this IEnumerable source, IEnumerable percentiles) + public static double[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -659,10 +645,12 @@ public static double[] Percentile(this IEnumerable source, IEnumerable /// Computes multiple percentiles of a sequence of values. /// + /// The type of the elements of . /// A sequence of values to calculate the percentiles of. + /// A transform function to apply to each element. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double?[] Percentile(this IEnumerable source, IEnumerable percentiles) + public static double?[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -670,10 +658,12 @@ public static double[] Percentile(this IEnumerable source, IEnumerable /// Computes multiple percentiles of a sequence of values. /// + /// The type of the elements of . /// A sequence of values to calculate the percentiles of. + /// A transform function to apply to each element. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double[] Percentile(this IEnumerable source, IEnumerable percentiles) + public static double[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -681,10 +671,12 @@ public static double[] Percentile(this IEnumerable source, IEnumerable /// Computes multiple percentiles of a sequence of values. /// + /// The type of the elements of . /// A sequence of values to calculate the percentiles of. + /// A transform function to apply to each element. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double?[] Percentile(this IEnumerable source, IEnumerable percentiles) + public static double?[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -692,10 +684,12 @@ public static double[] Percentile(this IEnumerable source, IEnumerable /// Computes multiple percentiles of a sequence of values. /// + /// The type of the elements of . /// A sequence of values to calculate the percentiles of. + /// A transform function to apply to each element. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double[] Percentile(this IEnumerable source, IEnumerable percentiles) + public static double[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -703,10 +697,12 @@ public static double[] Percentile(this IEnumerable source, IEnumerable /// Computes multiple percentiles of a sequence of values. /// + /// The type of the elements of . /// A sequence of values to calculate the percentiles of. + /// A transform function to apply to each element. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double?[] Percentile(this IEnumerable source, IEnumerable percentiles) + public static double?[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -714,10 +710,12 @@ public static double[] Percentile(this IEnumerable source, IEnumerable /// Computes multiple percentiles of a sequence of values. /// + /// The type of the elements of . /// A sequence of values to calculate the percentiles of. + /// A transform function to apply to each element. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double[] Percentile(this IEnumerable source, IEnumerable percentiles) + public static double[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -725,10 +723,12 @@ public static double[] Percentile(this IEnumerable source, IEnumerable< /// /// Computes multiple percentiles of a sequence of values. /// + /// The type of the elements of . /// A sequence of values to calculate the percentiles of. + /// A transform function to apply to each element. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double?[] Percentile(this IEnumerable source, IEnumerable percentiles) + public static double?[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } From 3951103d017ddd97138bd7d176aa543cc48cd40e Mon Sep 17 00:00:00 2001 From: adelinowona Date: Tue, 29 Jul 2025 17:39:31 -0400 Subject: [PATCH 09/15] add median feature and undo accidental change --- src/MongoDB.Driver/Core/Misc/Feature.cs | 6 ++++++ .../MedianMethodToAggregationExpressionTranslatorTests.cs | 2 +- .../Translators/AggregateGroupTranslatorTests.cs | 6 +++--- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/MongoDB.Driver/Core/Misc/Feature.cs b/src/MongoDB.Driver/Core/Misc/Feature.cs index ba779ca092d..d2a3d86e88a 100644 --- a/src/MongoDB.Driver/Core/Misc/Feature.cs +++ b/src/MongoDB.Driver/Core/Misc/Feature.cs @@ -83,6 +83,7 @@ public class Feature private static readonly Feature __loookupConciseSyntax = new Feature("LoookupConciseSyntax", WireVersion.Server50); private static readonly Feature __loookupDocuments= new Feature("LoookupDocuments", WireVersion.Server60); private static readonly Feature __mmapV1StorageEngine = new Feature("MmapV1StorageEngine", WireVersion.Zero, WireVersion.Server42); + private static readonly Feature __medianOperator = new Feature("MedianOperator", WireVersion.Server70); private static readonly Feature __percentileOperator = new Feature("PercentileOperator", WireVersion.Server70); private static readonly Feature __pickAccumulatorsNewIn52 = new Feature("PickAccumulatorsNewIn52", WireVersion.Server52); private static readonly Feature __rankFusionStage = new Feature("RankFusionStage", WireVersion.Server81); @@ -402,6 +403,11 @@ public class Feature [Obsolete("This feature was removed in server version 4.2. As such, this property will be removed in a later release.")] public static Feature MmapV1StorageEngine => __mmapV1StorageEngine; + /// + /// Gets the $median operator added in 7.0 + /// + public static Feature MedianOperator => __medianOperator; + /// /// Gets the $percentile operator added in 7.0 /// diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslatorTests.cs index 1921a09bcb2..0aaaefce86b 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslatorTests.cs @@ -29,7 +29,7 @@ namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Translators.ExpressionTo public class MedianMethodToAggregationExpressionTranslatorTests : LinqIntegrationTest { public MedianMethodToAggregationExpressionTranslatorTests(ClassFixture fixture) - : base(fixture, server => server.Supports(Feature.PercentileOperator)) // median and percentile were added in the same server version + : base(fixture, server => server.Supports(Feature.MedianOperator)) { } diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs index 317b9068ef1..a8f7428079b 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs @@ -168,7 +168,7 @@ public void Should_translate_count_with_a_predicate() [Fact] public void Should_translate_median_with_embedded_projector() { - RequireServer.Check().Supports(Feature.PercentileOperator); + RequireServer.Check().Supports(Feature.MedianOperator); var result = Group(x => x.A, g => new { Result = g.Median(x=> x.C.E.F) }); @@ -183,7 +183,7 @@ public void Should_translate_median_with_embedded_projector() [Fact] public void Should_translate_median_with_selected_projector() { - RequireServer.Check().Supports(Feature.PercentileOperator); + RequireServer.Check().Supports(Feature.MedianOperator); var result = Group(x => x.A, g => new { Result = g.Select(x => x.C.E.F).Median() }); @@ -567,7 +567,7 @@ public void Should_translate_complex_selector() First : '$__agg2', Last : '$__agg3', Min : '$__agg4', - Max : '$__agg5' + Max : '$__agg5', _id : 0 } }"); From f9b312a4e05aae7a9519f9c07661321224522bae Mon Sep 17 00:00:00 2001 From: adelinowona Date: Tue, 29 Jul 2025 19:17:25 -0400 Subject: [PATCH 10/15] use reflection pattern to recognize methods in translators --- .../Reflection/EnumerableMethod.cs | 120 ++++++++++++++++++ .../Reflection/WindowMethod.cs | 61 +++++++++ ...MethodToAggregationExpressionTranslator.cs | 90 ++++++++----- ...MethodToAggregationExpressionTranslator.cs | 94 +++++++++----- ...MethodToAggregationExpressionTranslator.cs | 55 ++++++-- 5 files changed, 350 insertions(+), 70 deletions(-) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs index 0ae3e99ca4a..c3327365d26 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs @@ -112,6 +112,26 @@ internal static class EnumerableMethod private static readonly MethodInfo __maxSingle; private static readonly MethodInfo __maxSingleWithSelector; private static readonly MethodInfo __maxWithSelector; + private static readonly MethodInfo __medianDecimal; + private static readonly MethodInfo __medianDecimalWithSelector; + private static readonly MethodInfo __medianDouble; + private static readonly MethodInfo __medianDoubleWithSelector; + private static readonly MethodInfo __medianInt32; + private static readonly MethodInfo __medianInt32WithSelector; + private static readonly MethodInfo __medianInt64; + private static readonly MethodInfo __medianInt64WithSelector; + private static readonly MethodInfo __medianNullableDecimal; + private static readonly MethodInfo __medianNullableDecimalWithSelector; + private static readonly MethodInfo __medianNullableDouble; + private static readonly MethodInfo __medianNullableDoubleWithSelector; + private static readonly MethodInfo __medianNullableInt32; + private static readonly MethodInfo __medianNullableInt32WithSelector; + private static readonly MethodInfo __medianNullableInt64; + private static readonly MethodInfo __medianNullableInt64WithSelector; + private static readonly MethodInfo __medianNullableSingle; + private static readonly MethodInfo __medianNullableSingleWithSelector; + private static readonly MethodInfo __medianSingle; + private static readonly MethodInfo __medianSingleWithSelector; private static readonly MethodInfo __min; private static readonly MethodInfo __minDecimal; private static readonly MethodInfo __minDecimalWithSelector; @@ -139,6 +159,26 @@ internal static class EnumerableMethod private static readonly MethodInfo __ofType; private static readonly MethodInfo __orderBy; private static readonly MethodInfo __orderByDescending; + private static readonly MethodInfo __percentileDecimal; + private static readonly MethodInfo __percentileDecimalWithSelector; + private static readonly MethodInfo __percentileDouble; + private static readonly MethodInfo __percentileDoubleWithSelector; + private static readonly MethodInfo __percentileInt32; + private static readonly MethodInfo __percentileInt32WithSelector; + private static readonly MethodInfo __percentileInt64; + private static readonly MethodInfo __percentileInt64WithSelector; + private static readonly MethodInfo __percentileNullableDecimal; + private static readonly MethodInfo __percentileNullableDecimalWithSelector; + private static readonly MethodInfo __percentileNullableDouble; + private static readonly MethodInfo __percentileNullableDoubleWithSelector; + private static readonly MethodInfo __percentileNullableInt32; + private static readonly MethodInfo __percentileNullableInt32WithSelector; + private static readonly MethodInfo __percentileNullableInt64; + private static readonly MethodInfo __percentileNullableInt64WithSelector; + private static readonly MethodInfo __percentileNullableSingle; + private static readonly MethodInfo __percentileNullableSingleWithSelector; + private static readonly MethodInfo __percentileSingle; + private static readonly MethodInfo __percentileSingleWithSelector; private static readonly MethodInfo __prepend; private static readonly MethodInfo __range; private static readonly MethodInfo __repeat; @@ -279,6 +319,26 @@ static EnumerableMethod() __maxSingle = ReflectionInfo.Method((IEnumerable source) => source.Max()); __maxSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Max(selector)); __maxWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Max(selector)); + __medianDecimal = ReflectionInfo.Method((IEnumerable source) => source.Median()); + __medianDecimalWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); + __medianDouble = ReflectionInfo.Method((IEnumerable source) => source.Median()); + __medianDoubleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); + __medianInt32 = ReflectionInfo.Method((IEnumerable source) => source.Median()); + __medianInt32WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); + __medianInt64 = ReflectionInfo.Method((IEnumerable source) => source.Median()); + __medianInt64WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); + __medianNullableDecimal = ReflectionInfo.Method((IEnumerable source) => source.Median()); + __medianNullableDecimalWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); + __medianNullableDouble = ReflectionInfo.Method((IEnumerable source) => source.Median()); + __medianNullableDoubleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); + __medianNullableInt32 = ReflectionInfo.Method((IEnumerable source) => source.Median()); + __medianNullableInt32WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); + __medianNullableInt64 = ReflectionInfo.Method((IEnumerable source) => source.Median()); + __medianNullableInt64WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); + __medianNullableSingle = ReflectionInfo.Method((IEnumerable source) => source.Median()); + __medianNullableSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); + __medianSingle = ReflectionInfo.Method((IEnumerable source) => source.Median()); + __medianSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); __min = ReflectionInfo.Method((IEnumerable source) => source.Min()); __minDecimal = ReflectionInfo.Method((IEnumerable source) => source.Min()); __minDecimalWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Min(selector)); @@ -306,6 +366,26 @@ static EnumerableMethod() __ofType = ReflectionInfo.Method((IEnumerable source) => source.OfType()); __orderBy = ReflectionInfo.Method((IEnumerable source, Func keySelector) => source.OrderBy(keySelector)); __orderByDescending = ReflectionInfo.Method((IEnumerable source, Func keySelector) => source.OrderByDescending(keySelector)); + __percentileDecimal = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); + __percentileDecimalWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); + __percentileDouble = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); + __percentileDoubleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); + __percentileInt32 = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); + __percentileInt32WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); + __percentileInt64 = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); + __percentileInt64WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); + __percentileNullableDecimal = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); + __percentileNullableDecimalWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); + __percentileNullableDouble = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); + __percentileNullableDoubleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); + __percentileNullableInt32 = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); + __percentileNullableInt32WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); + __percentileNullableInt64 = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); + __percentileNullableInt64WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); + __percentileNullableSingle = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); + __percentileNullableSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); + __percentileSingle = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); + __percentileSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); __prepend = ReflectionInfo.Method((IEnumerable source, object element) => source.Prepend(element)); __range = ReflectionInfo.Method((int start, int count) => Enumerable.Range(start, count)); __repeat = ReflectionInfo.Method((object element, int count) => Enumerable.Repeat(element, count)); @@ -445,6 +525,26 @@ static EnumerableMethod() public static MethodInfo MaxSingle => __maxSingle; public static MethodInfo MaxSingleWithSelector => __maxSingleWithSelector; public static MethodInfo MaxWithSelector => __maxWithSelector; + public static MethodInfo MedianDecimal => __medianDecimal; + public static MethodInfo MedianDecimalWithSelector => __medianDecimalWithSelector; + public static MethodInfo MedianDouble => __medianDouble; + public static MethodInfo MedianDoubleWithSelector => __medianDoubleWithSelector; + public static MethodInfo MedianInt32 => __medianInt32; + public static MethodInfo MedianInt32WithSelector => __medianInt32WithSelector; + public static MethodInfo MedianInt64 => __medianInt64; + public static MethodInfo MedianInt64WithSelector => __medianInt64WithSelector; + public static MethodInfo MedianNullableDecimal => __medianNullableDecimal; + public static MethodInfo MedianNullableDecimalWithSelector => __medianNullableDecimalWithSelector; + public static MethodInfo MedianNullableDouble => __medianNullableDouble; + public static MethodInfo MedianNullableDoubleWithSelector => __medianNullableDoubleWithSelector; + public static MethodInfo MedianNullableInt32 => __medianNullableInt32; + public static MethodInfo MedianNullableInt32WithSelector => __medianNullableInt32WithSelector; + public static MethodInfo MedianNullableInt64 => __medianNullableInt64; + public static MethodInfo MedianNullableInt64WithSelector => __medianNullableInt64WithSelector; + public static MethodInfo MedianNullableSingle => __medianNullableSingle; + public static MethodInfo MedianNullableSingleWithSelector => __medianNullableSingleWithSelector; + public static MethodInfo MedianSingle => __medianSingle; + public static MethodInfo MedianSingleWithSelector => __medianSingleWithSelector; public static MethodInfo Min => __min; public static MethodInfo MinDecimal => __minDecimal; public static MethodInfo MinDecimalWithSelector => __minDecimalWithSelector; @@ -472,6 +572,26 @@ static EnumerableMethod() public static MethodInfo OfType => __ofType; public static MethodInfo OrderBy => __orderBy; public static MethodInfo OrderByDescending => __orderByDescending; + public static MethodInfo PercentileDecimal => __percentileDecimal; + public static MethodInfo PercentileDecimalWithSelector => __percentileDecimalWithSelector; + public static MethodInfo PercentileDouble => __percentileDouble; + public static MethodInfo PercentileDoubleWithSelector => __percentileDoubleWithSelector; + public static MethodInfo PercentileInt32 => __percentileInt32; + public static MethodInfo PercentileInt32WithSelector => __percentileInt32WithSelector; + public static MethodInfo PercentileInt64 => __percentileInt64; + public static MethodInfo PercentileInt64WithSelector => __percentileInt64WithSelector; + public static MethodInfo PercentileNullableDecimal => __percentileNullableDecimal; + public static MethodInfo PercentileNullableDecimalWithSelector => __percentileNullableDecimalWithSelector; + public static MethodInfo PercentileNullableDouble => __percentileNullableDouble; + public static MethodInfo PercentileNullableDoubleWithSelector => __percentileNullableDoubleWithSelector; + public static MethodInfo PercentileNullableInt32 => __percentileNullableInt32; + public static MethodInfo PercentileNullableInt32WithSelector => __percentileNullableInt32WithSelector; + public static MethodInfo PercentileNullableInt64 => __percentileNullableInt64; + public static MethodInfo PercentileNullableInt64WithSelector => __percentileNullableInt64WithSelector; + public static MethodInfo PercentileNullableSingle => __percentileNullableSingle; + public static MethodInfo PercentileNullableSingleWithSelector => __percentileNullableSingleWithSelector; + public static MethodInfo PercentileSingle => __percentileSingle; + public static MethodInfo PercentileSingleWithSelector => __percentileSingleWithSelector; public static MethodInfo Prepend => __prepend; public static MethodInfo Range => __range; public static MethodInfo Repeat => __repeat; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/WindowMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/WindowMethod.cs index 374caf1f787..693e8762269 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/WindowMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/WindowMethod.cs @@ -14,6 +14,7 @@ */ using System; +using System.Collections.Generic; using System.Reflection; namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection @@ -84,7 +85,27 @@ internal static class WindowMethod private static readonly MethodInfo __last; private static readonly MethodInfo __locf; private static readonly MethodInfo __max; + private static readonly MethodInfo __medianWithDecimal; + private static readonly MethodInfo __medianWithDouble; + private static readonly MethodInfo __medianWithInt32; + private static readonly MethodInfo __medianWithInt64; + private static readonly MethodInfo __medianWithNullableDecimal; + private static readonly MethodInfo __medianWithNullableDouble; + private static readonly MethodInfo __medianWithNullableInt32; + private static readonly MethodInfo __medianWithNullableInt64; + private static readonly MethodInfo __medianWithNullableSingle; + private static readonly MethodInfo __medianWithSingle; private static readonly MethodInfo __min; + private static readonly MethodInfo __percentileWithDecimal; + private static readonly MethodInfo __percentileWithDouble; + private static readonly MethodInfo __percentileWithInt32; + private static readonly MethodInfo __percentileWithInt64; + private static readonly MethodInfo __percentileWithNullableDecimal; + private static readonly MethodInfo __percentileWithNullableDouble; + private static readonly MethodInfo __percentileWithNullableInt32; + private static readonly MethodInfo __percentileWithNullableInt64; + private static readonly MethodInfo __percentileWithNullableSingle; + private static readonly MethodInfo __percentileWithSingle; private static readonly MethodInfo __push; private static readonly MethodInfo __rank; private static readonly MethodInfo __shift; @@ -186,7 +207,27 @@ static WindowMethod() __last = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window) => partition.Last(selector, window)); __locf = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window) => partition.Locf(selector, window)); __max = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window) => partition.Max(selector, window)); + __medianWithDecimal = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window) => partition.Median(selector, window)); + __medianWithDouble = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window) => partition.Median(selector, window)); + __medianWithInt32 = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window) => partition.Median(selector, window)); + __medianWithInt64 = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window) => partition.Median(selector, window)); + __medianWithNullableDecimal = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window) => partition.Median(selector, window)); + __medianWithNullableDouble = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window) => partition.Median(selector, window)); + __medianWithNullableInt32 = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window) => partition.Median(selector, window)); + __medianWithNullableInt64 = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window) => partition.Median(selector, window)); + __medianWithNullableSingle = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window) => partition.Median(selector, window)); + __medianWithSingle = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window) => partition.Median(selector, window)); __min = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window) => partition.Min(selector, window)); + __percentileWithDecimal = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window) => partition.Percentile(selector, percentiles, window)); + __percentileWithDouble = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window) => partition.Percentile(selector, percentiles, window)); + __percentileWithInt32 = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window) => partition.Percentile(selector, percentiles, window)); + __percentileWithInt64 = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window) => partition.Percentile(selector, percentiles, window)); + __percentileWithNullableDecimal = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window) => partition.Percentile(selector, percentiles, window)); + __percentileWithNullableDouble = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window) => partition.Percentile(selector, percentiles, window)); + __percentileWithNullableInt32 = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window) => partition.Percentile(selector, percentiles, window)); + __percentileWithNullableInt64 = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window) => partition.Percentile(selector, percentiles, window)); + __percentileWithNullableSingle = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window) => partition.Percentile(selector, percentiles, window)); + __percentileWithSingle = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window) => partition.Percentile(selector, percentiles, window)); __push = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window) => partition.Push(selector, window)); __rank = ReflectionInfo.Method((ISetWindowFieldsPartition partition) => partition.Rank()); __shift = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, int by) => partition.Shift(selector, by)); @@ -287,7 +328,27 @@ static WindowMethod() public static MethodInfo Last => __last; public static MethodInfo Locf => __locf; public static MethodInfo Max => __max; + public static MethodInfo MedianWithDecimal => __medianWithDecimal; + public static MethodInfo MedianWithDouble => __medianWithDouble; + public static MethodInfo MedianWithInt32 => __medianWithInt32; + public static MethodInfo MedianWithInt64 => __medianWithInt64; + public static MethodInfo MedianWithNullableDecimal => __medianWithNullableDecimal; + public static MethodInfo MedianWithNullableDouble => __medianWithNullableDouble; + public static MethodInfo MedianWithNullableInt32 => __medianWithNullableInt32; + public static MethodInfo MedianWithNullableInt64 => __medianWithNullableInt64; + public static MethodInfo MedianWithNullableSingle => __medianWithNullableSingle; + public static MethodInfo MedianWithSingle => __medianWithSingle; public static MethodInfo Min => __min; + public static MethodInfo PercentileWithDecimal => __percentileWithDecimal; + public static MethodInfo PercentileWithDouble => __percentileWithDouble; + public static MethodInfo PercentileWithInt32 => __percentileWithInt32; + public static MethodInfo PercentileWithInt64 => __percentileWithInt64; + public static MethodInfo PercentileWithNullableDecimal => __percentileWithNullableDecimal; + public static MethodInfo PercentileWithNullableDouble => __percentileWithNullableDouble; + public static MethodInfo PercentileWithNullableInt32 => __percentileWithNullableInt32; + public static MethodInfo PercentileWithNullableInt64 => __percentileWithNullableInt64; + public static MethodInfo PercentileWithNullableSingle => __percentileWithNullableSingle; + public static MethodInfo PercentileWithSingle => __percentileWithSingle; public static MethodInfo Push => __push; public static MethodInfo Rank => __rank; public static MethodInfo Shift => __shift; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslator.cs index 08ab0efd886..fcede33cd25 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslator.cs @@ -18,46 +18,81 @@ using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Reflection; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators { internal class MedianMethodToAggregationExpressionTranslator { + private static readonly MethodInfo[] __medianMethods = + { + EnumerableMethod.MedianDecimal, + EnumerableMethod.MedianDecimalWithSelector, + EnumerableMethod.MedianDouble, + EnumerableMethod.MedianDoubleWithSelector, + EnumerableMethod.MedianInt32, + EnumerableMethod.MedianInt32WithSelector, + EnumerableMethod.MedianInt64, + EnumerableMethod.MedianInt64WithSelector, + EnumerableMethod.MedianNullableDecimal, + EnumerableMethod.MedianNullableDecimalWithSelector, + EnumerableMethod.MedianNullableDouble, + EnumerableMethod.MedianNullableDoubleWithSelector, + EnumerableMethod.MedianNullableInt32, + EnumerableMethod.MedianNullableInt32WithSelector, + EnumerableMethod.MedianNullableInt64, + EnumerableMethod.MedianNullableInt64WithSelector, + EnumerableMethod.MedianNullableSingle, + EnumerableMethod.MedianNullableSingleWithSelector, + EnumerableMethod.MedianSingle, + EnumerableMethod.MedianSingleWithSelector + }; + + private static readonly MethodInfo[] __medianWithSelectorMethods = + { + EnumerableMethod.MedianDecimalWithSelector, + EnumerableMethod.MedianDoubleWithSelector, + EnumerableMethod.MedianInt32WithSelector, + EnumerableMethod.MedianInt64WithSelector, + EnumerableMethod.MedianNullableDecimalWithSelector, + EnumerableMethod.MedianNullableDoubleWithSelector, + EnumerableMethod.MedianNullableInt32WithSelector, + EnumerableMethod.MedianNullableInt64WithSelector, + EnumerableMethod.MedianNullableSingleWithSelector, + EnumerableMethod.MedianSingleWithSelector + }; + public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (IsMedianMethod(method)) + if (method.IsOneOf(__medianMethods)) { - if (arguments.Count is 1 or 2) - { - var sourceExpression = arguments[0]; - var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); - NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation); - - var inputAst = sourceTranslation.Ast; + var sourceExpression = arguments[0]; + var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); + NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation); - // Median(source, selector) - if (arguments.Count == 2) - { - var selectorLambda = (LambdaExpression)arguments[1]; - var selectorParameter = selectorLambda.Parameters[0]; - var selectorParameterSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); - var selectorParameterSymbol = context.CreateSymbol(selectorParameter, selectorParameterSerializer); - var selectorContext = context.WithSymbol(selectorParameterSymbol); - var selectorTranslation = ExpressionToAggregationExpressionTranslator.Translate(selectorContext, selectorLambda.Body); + var inputAst = sourceTranslation.Ast; - inputAst = AstExpression.Map( - input: sourceTranslation.Ast, - @as: selectorParameterSymbol.Var, - @in: selectorTranslation.Ast); - } + if (method.IsOneOf(__medianWithSelectorMethods)) + { + var selectorLambda = (LambdaExpression)arguments[1]; + var selectorParameter = selectorLambda.Parameters[0]; + var selectorParameterSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); + var selectorParameterSymbol = context.CreateSymbol(selectorParameter, selectorParameterSerializer); + var selectorContext = context.WithSymbol(selectorParameterSymbol); + var selectorTranslation = ExpressionToAggregationExpressionTranslator.Translate(selectorContext, selectorLambda.Body); - var ast = AstExpression.Median(inputAst); - var serializer = BsonSerializer.LookupSerializer(expression.Type); - return new TranslatedExpression(expression, ast, serializer); + inputAst = AstExpression.Map( + input: sourceTranslation.Ast, + @as: selectorParameterSymbol.Var, + @in: selectorTranslation.Ast); } + + var ast = AstExpression.Median(inputAst); + var serializer = BsonSerializer.LookupSerializer(expression.Type); + return new TranslatedExpression(expression, ast, serializer); } if (WindowMethodToAggregationExpressionTranslator.CanTranslate(expression)) @@ -67,10 +102,5 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC throw new ExpressionNotSupportedException(expression); } - - private static bool IsMedianMethod(MethodInfo methodInfo) - { - return methodInfo.DeclaringType == typeof(MongoEnumerable) && methodInfo.Name == "Median"; - } } } \ No newline at end of file diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslator.cs index 8e064834779..36cec69813d 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslator.cs @@ -18,49 +18,84 @@ using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Reflection; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators { internal class PercentileMethodToAggregationExpressionTranslator { + private static readonly MethodInfo[] __percentileMethods = + { + EnumerableMethod.PercentileDecimal, + EnumerableMethod.PercentileDecimalWithSelector, + EnumerableMethod.PercentileDouble, + EnumerableMethod.PercentileDoubleWithSelector, + EnumerableMethod.PercentileInt32, + EnumerableMethod.PercentileInt32WithSelector, + EnumerableMethod.PercentileInt64, + EnumerableMethod.PercentileInt64WithSelector, + EnumerableMethod.PercentileNullableDecimal, + EnumerableMethod.PercentileNullableDecimalWithSelector, + EnumerableMethod.PercentileNullableDouble, + EnumerableMethod.PercentileNullableDoubleWithSelector, + EnumerableMethod.PercentileNullableInt32, + EnumerableMethod.PercentileNullableInt32WithSelector, + EnumerableMethod.PercentileNullableInt64, + EnumerableMethod.PercentileNullableInt64WithSelector, + EnumerableMethod.PercentileNullableSingle, + EnumerableMethod.PercentileNullableSingleWithSelector, + EnumerableMethod.PercentileSingle, + EnumerableMethod.PercentileSingleWithSelector + }; + + private static readonly MethodInfo[] __percentileWithSelectorMethods = + { + EnumerableMethod.PercentileDecimalWithSelector, + EnumerableMethod.PercentileDoubleWithSelector, + EnumerableMethod.PercentileInt32WithSelector, + EnumerableMethod.PercentileInt64WithSelector, + EnumerableMethod.PercentileNullableDecimalWithSelector, + EnumerableMethod.PercentileNullableDoubleWithSelector, + EnumerableMethod.PercentileNullableInt32WithSelector, + EnumerableMethod.PercentileNullableInt64WithSelector, + EnumerableMethod.PercentileNullableSingleWithSelector, + EnumerableMethod.PercentileSingleWithSelector + }; + public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (IsPercentileMethod(method)) + if (method.IsOneOf(__percentileMethods)) { - if (arguments.Count is 2 or 3) - { - var sourceExpression = arguments[0]; - var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); - NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation); + var sourceExpression = arguments[0]; + var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); + NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation); - AstExpression inputAst = sourceTranslation.Ast; + var inputAst = sourceTranslation.Ast; - // handle selector - if (arguments.Count == 3) - { - var selectorLambda = (LambdaExpression)arguments[1]; - var selectorParameter = selectorLambda.Parameters[0]; - var selectorParameterSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); - var selectorParameterSymbol = context.CreateSymbol(selectorParameter, selectorParameterSerializer); - var selectorContext = context.WithSymbol(selectorParameterSymbol); - var selectorTranslation = ExpressionToAggregationExpressionTranslator.Translate(selectorContext, selectorLambda.Body); + if (method.IsOneOf(__percentileWithSelectorMethods)) + { + var selectorLambda = (LambdaExpression)arguments[1]; + var selectorParameter = selectorLambda.Parameters[0]; + var selectorParameterSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); + var selectorParameterSymbol = context.CreateSymbol(selectorParameter, selectorParameterSerializer); + var selectorContext = context.WithSymbol(selectorParameterSymbol); + var selectorTranslation = ExpressionToAggregationExpressionTranslator.Translate(selectorContext, selectorLambda.Body); - inputAst = AstExpression.Map( - input: sourceTranslation.Ast, - @as: selectorParameterSymbol.Var, - @in: selectorTranslation.Ast); - } + inputAst = AstExpression.Map( + input: sourceTranslation.Ast, + @as: selectorParameterSymbol.Var, + @in: selectorTranslation.Ast); + } - var percentilesExpression = arguments[arguments.Count - 1]; - var percentilesTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, percentilesExpression); + var percentilesExpression = arguments[arguments.Count - 1]; + var percentilesTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, percentilesExpression); - var ast = AstExpression.Percentile(inputAst, percentilesTranslation.Ast); - var serializer = BsonSerializer.LookupSerializer(expression.Type); - return new TranslatedExpression(expression, ast, serializer); - } + var ast = AstExpression.Percentile(inputAst, percentilesTranslation.Ast); + var serializer = BsonSerializer.LookupSerializer(expression.Type); + return new TranslatedExpression(expression, ast, serializer); } if (WindowMethodToAggregationExpressionTranslator.CanTranslate(expression)) @@ -70,10 +105,5 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC throw new ExpressionNotSupportedException(expression); } - - private static bool IsPercentileMethod(MethodInfo methodInfo) - { - return methodInfo.DeclaringType == typeof(MongoEnumerable) && methodInfo.Name == "Percentile"; - } } } \ No newline at end of file diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslator.cs index d274c3a7ae1..faf192c14f7 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslator.cs @@ -94,7 +94,27 @@ internal static class WindowMethodToAggregationExpressionTranslator WindowMethod.Last, WindowMethod.Locf, WindowMethod.Max, + WindowMethod.MedianWithDecimal, + WindowMethod.MedianWithDouble, + WindowMethod.MedianWithInt32, + WindowMethod.MedianWithInt64, + WindowMethod.MedianWithNullableDecimal, + WindowMethod.MedianWithNullableDouble, + WindowMethod.MedianWithNullableInt32, + WindowMethod.MedianWithNullableInt64, + WindowMethod.MedianWithNullableSingle, + WindowMethod.MedianWithSingle, WindowMethod.Min, + WindowMethod.PercentileWithDecimal, + WindowMethod.PercentileWithDouble, + WindowMethod.PercentileWithInt32, + WindowMethod.PercentileWithInt64, + WindowMethod.PercentileWithNullableDecimal, + WindowMethod.PercentileWithNullableDouble, + WindowMethod.PercentileWithNullableInt32, + WindowMethod.PercentileWithNullableInt64, + WindowMethod.PercentileWithNullableSingle, + WindowMethod.PercentileWithSingle, WindowMethod.Push, WindowMethod.Rank, WindowMethod.Shift, @@ -253,9 +273,33 @@ internal static class WindowMethodToAggregationExpressionTranslator WindowMethod.ShiftWithDefaultValue }; + private static readonly MethodInfo[] __quantileMethods = + [ + WindowMethod.MedianWithDecimal, + WindowMethod.MedianWithDouble, + WindowMethod.MedianWithInt32, + WindowMethod.MedianWithInt64, + WindowMethod.MedianWithNullableDecimal, + WindowMethod.MedianWithNullableDouble, + WindowMethod.MedianWithNullableInt32, + WindowMethod.MedianWithNullableInt64, + WindowMethod.MedianWithNullableSingle, + WindowMethod.MedianWithSingle, + WindowMethod.PercentileWithDecimal, + WindowMethod.PercentileWithDouble, + WindowMethod.PercentileWithInt32, + WindowMethod.PercentileWithInt64, + WindowMethod.PercentileWithNullableDecimal, + WindowMethod.PercentileWithNullableDouble, + WindowMethod.PercentileWithNullableInt32, + WindowMethod.PercentileWithNullableInt64, + WindowMethod.PercentileWithNullableSingle, + WindowMethod.PercentileWithSingle + ]; + public static bool CanTranslate(MethodCallExpression expression) { - return IsQuantileMethod(expression.Method) || expression.Method.IsOneOf(__windowMethods); + return expression.Method.IsOneOf(__windowMethods); } public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) @@ -264,7 +308,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var parameters = method.GetParameters(); var arguments = expression.Arguments.ToArray(); - if ( IsQuantileMethod(method) || method.IsOneOf(__windowMethods)) + if (method.IsOneOf(__windowMethods)) { var partitionExpression = arguments[0]; var partitionTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, partitionExpression); @@ -339,7 +383,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC return new TranslatedExpression(expression, ast, serializer); } - if (IsQuantileMethod(method)) + if (method.IsOneOf(__quantileMethods)) { ThrowIfSelectorTranslationIsNull(selectorTranslation); AstExpression ast; @@ -461,11 +505,6 @@ private static bool HasArgument(ParameterInfo[] parameters, string return false; } - private static bool IsQuantileMethod(MethodInfo method) - { - return method.DeclaringType == typeof(ISetWindowFieldsPartitionExtensions) && method.Name is "Median" or "Percentile"; - } - private static void ThrowIfSelectorTranslationIsNull(TranslatedExpression selectTranslation) { if (selectTranslation == null) From 1494e7e62a0e161e443a12ccb9980f0375c0a54c Mon Sep 17 00:00:00 2001 From: adelinowona Date: Tue, 29 Jul 2025 19:37:38 -0400 Subject: [PATCH 11/15] fix complex accumulator visitor bug --- .../AstComplexAccumulatorExpression.cs | 8 ++++---- .../Ast/Visitors/AstNodeVisitor.cs | 19 ++++++++++++++++++- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstComplexAccumulatorExpression.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstComplexAccumulatorExpression.cs index d825dac5f1d..8838cfe68bd 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstComplexAccumulatorExpression.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstComplexAccumulatorExpression.cs @@ -23,15 +23,15 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions internal sealed class AstComplexAccumulatorExpression : AstAccumulatorExpression { private readonly AstComplexAccumulatorOperator _operator; - private readonly Dictionary _args; + private readonly IReadOnlyDictionary _args; - public AstComplexAccumulatorExpression(AstComplexAccumulatorOperator @operator, Dictionary args) + public AstComplexAccumulatorExpression(AstComplexAccumulatorOperator @operator, IReadOnlyDictionary args) { _operator = @operator; _args = Ensure.IsNotNull(args, nameof(args)); } - public Dictionary Args => _args; + public IReadOnlyDictionary Args => _args; public override AstNodeType NodeType => AstNodeType.ComplexAccumulatorExpression; @@ -53,7 +53,7 @@ public override BsonValue Render() return new BsonDocument(_operator.Render(), document); } - public AstComplexAccumulatorExpression Update(Dictionary args) + public AstComplexAccumulatorExpression Update(IReadOnlyDictionary args) { if (ReferenceEquals(args, _args)) { diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Visitors/AstNodeVisitor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Visitors/AstNodeVisitor.cs index 35fc5cd922f..e74b893daeb 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Visitors/AstNodeVisitor.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Visitors/AstNodeVisitor.cs @@ -209,9 +209,26 @@ public virtual AstNode VisitComplexAccumulatorExpression(AstComplexAccumulatorEx if (newArg != oldArg) { - newArgs ??= new Dictionary(node.Args); + if (newArgs == null) + { + // First change detected - copy all processed entries + newArgs = new Dictionary(); + foreach (var processedKvp in node.Args) + { + if (processedKvp.Key == kvp.Key) + { + break; // Stop at current entry + } + newArgs[processedKvp.Key] = processedKvp.Value; + } + } newArgs[kvp.Key] = newArg; } + else if (newArgs != null) + { + // We're building a new dictionary, so add unchanged entries too + newArgs[kvp.Key] = oldArg; + } } return newArgs != null ? node.Update(newArgs) : node; From fb5c92d15b8e01ee2e64262cc9ef88e8854d71eb Mon Sep 17 00:00:00 2001 From: adelinowona Date: Thu, 31 Jul 2025 20:13:46 -0400 Subject: [PATCH 12/15] use separate classes for median and percentile accumulator ASTs --- .../Linq3Implementation/Ast/AstNodeType.cs | 3 +- .../AstComplexAccumulatorExpression.cs | 66 ----------------- .../AstComplexAccumulatorOperator.cs | 38 ---------- .../Ast/Expressions/AstExpression.cs | 15 ++-- .../AstMedianAccumulatorExpression.cs | 64 ++++++++++++++++ .../AstPercentileAccumulatorExpression.cs | 68 +++++++++++++++++ .../AstGroupingPipelineOptimizer.cs | 74 ++++++------------- .../Ast/Visitors/AstNodeVisitor.cs | 45 +++-------- 8 files changed, 177 insertions(+), 196 deletions(-) delete mode 100644 src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstComplexAccumulatorExpression.cs delete mode 100644 src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstComplexAccumulatorOperator.cs create mode 100644 src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstMedianAccumulatorExpression.cs create mode 100644 src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstPercentileAccumulatorExpression.cs diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/AstNodeType.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/AstNodeType.cs index bc46e23f5a7..147bd427729 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/AstNodeType.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/AstNodeType.cs @@ -31,7 +31,6 @@ internal enum AstNodeType BucketStage, CollStatsStage, ComparisonFilterOperation, - ComplexAccumulatorExpression, ComputedArrayExpression, ComputedDocumentExpression, ComputedField, @@ -95,6 +94,7 @@ internal enum AstNodeType MatchesNothingFilter, MatchStage, MedianExpression, + MedianAccumulatorExpression, MedianWindowExpression, MergeStage, ModFilterOperation, @@ -108,6 +108,7 @@ internal enum AstNodeType OrFilter, OutStage, PercentileExpression, + PercentileAccumulatorExpression, PercentileWindowExpression, PickAccumulatorExpression, PickExpression, diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstComplexAccumulatorExpression.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstComplexAccumulatorExpression.cs deleted file mode 100644 index 8838cfe68bd..00000000000 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstComplexAccumulatorExpression.cs +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright 2010-present MongoDB Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -using System.Collections.Generic; -using MongoDB.Bson; -using MongoDB.Driver.Core.Misc; -using MongoDB.Driver.Linq.Linq3Implementation.Ast.Visitors; - -namespace MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions -{ - internal sealed class AstComplexAccumulatorExpression : AstAccumulatorExpression - { - private readonly AstComplexAccumulatorOperator _operator; - private readonly IReadOnlyDictionary _args; - - public AstComplexAccumulatorExpression(AstComplexAccumulatorOperator @operator, IReadOnlyDictionary args) - { - _operator = @operator; - _args = Ensure.IsNotNull(args, nameof(args)); - } - - public IReadOnlyDictionary Args => _args; - - public override AstNodeType NodeType => AstNodeType.ComplexAccumulatorExpression; - - public override AstNode Accept(AstNodeVisitor visitor) - { - return visitor.VisitComplexAccumulatorExpression(this); - } - - public override BsonValue Render() - { - var document = new BsonDocument(); - - // Add all accumulator parameters - foreach (var kvp in _args) - { - document[kvp.Key] = kvp.Value.Render(); - } - - return new BsonDocument(_operator.Render(), document); - } - - public AstComplexAccumulatorExpression Update(IReadOnlyDictionary args) - { - if (ReferenceEquals(args, _args)) - { - return this; - } - - return new AstComplexAccumulatorExpression(_operator, args); - } - } -} \ No newline at end of file diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstComplexAccumulatorOperator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstComplexAccumulatorOperator.cs deleted file mode 100644 index 01f34827135..00000000000 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstComplexAccumulatorOperator.cs +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2010-present MongoDB Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -using System; - -namespace MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions -{ - internal enum AstComplexAccumulatorOperator - { - Median, - Percentile - } - - internal static class AstComplexAccumulatorOperatorExtensions - { - public static string Render(this AstComplexAccumulatorOperator @operator) - { - return @operator switch - { - AstComplexAccumulatorOperator.Median => "$median", - AstComplexAccumulatorOperator.Percentile => "$percentile", - _ => throw new InvalidOperationException($"Unexpected complex accumulator operator: {@operator}.") - }; - } - } -} \ No newline at end of file diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpression.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpression.cs index b5f6381f9ff..9e464cb2f8e 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpression.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpression.cs @@ -209,11 +209,6 @@ public static AstExpression Comparison(AstBinaryOperator comparisonOperator, Ast }; } - public static AstComplexAccumulatorExpression ComplexAccumulator(AstComplexAccumulatorOperator @operator, Dictionary args) - { - return new AstComplexAccumulatorExpression(@operator, args); - } - public static AstExpression ComputedArray(IEnumerable items) { return new AstComputedArrayExpression(items); @@ -607,6 +602,11 @@ public static AstExpression Median(AstExpression input) return new AstMedianExpression(input); } + public static AstMedianAccumulatorExpression MedianAccumulator(AstExpression input) + { + return new AstMedianAccumulatorExpression(input); + } + public static AstMedianWindowExpression MedianWindowExpression(AstExpression input, AstWindow window) { return new AstMedianWindowExpression(input, window); @@ -673,6 +673,11 @@ public static AstPercentileExpression Percentile(AstExpression input, AstExpress return new AstPercentileExpression(input, percentiles); } + public static AstPercentileAccumulatorExpression PercentileAccumulator(AstExpression input, AstExpression percentiles) + { + return new AstPercentileAccumulatorExpression(input, percentiles); + } + public static AstPercentileWindowExpression PercentileWindowExpression(AstExpression input, AstExpression percentiles, AstWindow window) { return new AstPercentileWindowExpression(input, percentiles, window); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstMedianAccumulatorExpression.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstMedianAccumulatorExpression.cs new file mode 100644 index 00000000000..754b1c175f6 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstMedianAccumulatorExpression.cs @@ -0,0 +1,64 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using MongoDB.Bson; +using MongoDB.Driver.Core.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Ast.Visitors; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions +{ + internal sealed class AstMedianAccumulatorExpression : AstAccumulatorExpression + { + private readonly AstExpression _input; + + public AstMedianAccumulatorExpression(AstExpression input) + { + _input = Ensure.IsNotNull(input, nameof(input)); + } + + public AstExpression Input => _input; + + public override AstNodeType NodeType => AstNodeType.MedianAccumulatorExpression; + + public override AstNode Accept(AstNodeVisitor visitor) + { + return visitor.VisitMedianAccumulatorExpression(this); + } + + public override BsonValue Render() + { + return new BsonDocument + { + { + "$median", new BsonDocument + { + { "input", _input.Render() }, + { "method", "approximate" } // server requires this parameter but currently only allows this value + } + } + }; + } + + public AstMedianAccumulatorExpression Update(AstExpression input) + { + if (input == _input) + { + return this; + } + return new AstMedianAccumulatorExpression(input); + } + + } +} \ No newline at end of file diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstPercentileAccumulatorExpression.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstPercentileAccumulatorExpression.cs new file mode 100644 index 00000000000..77cee275402 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstPercentileAccumulatorExpression.cs @@ -0,0 +1,68 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using MongoDB.Bson; +using MongoDB.Driver.Core.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Ast.Visitors; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions +{ + internal sealed class AstPercentileAccumulatorExpression : AstAccumulatorExpression + { + private readonly AstExpression _input; + private readonly AstExpression _percentiles; + + public AstPercentileAccumulatorExpression(AstExpression input, AstExpression percentiles) + { + _input = Ensure.IsNotNull(input, nameof(input)); + _percentiles = Ensure.IsNotNull(percentiles, nameof(percentiles)); + } + + public AstExpression Input => _input; + + public AstExpression Percentiles => _percentiles; + + public override AstNodeType NodeType => AstNodeType.PercentileAccumulatorExpression; + + public override AstNode Accept(AstNodeVisitor visitor) + { + return visitor.VisitPercentileAccumulatorExpression(this); + } + + public override BsonValue Render() + { + return new BsonDocument + { + { + "$percentile", new BsonDocument + { + { "input", _input.Render() }, + { "p", _percentiles.Render() }, + { "method", "approximate" } // server requires this parameter but currently only allows this value + } + } + }; + } + + public AstPercentileAccumulatorExpression Update(AstExpression input, AstExpression percentiles) + { + if (input == _input && percentiles == _percentiles) + { + return this; + } + return new AstPercentileAccumulatorExpression(input, percentiles); + } + } +} \ No newline at end of file diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupingPipelineOptimizer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupingPipelineOptimizer.cs index 15f8048e403..66763cabde2 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupingPipelineOptimizer.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupingPipelineOptimizer.cs @@ -404,11 +404,10 @@ unaryExpression.Arg is AstGetFieldExpression innerMostGetFieldExpression && public override AstNode VisitMapExpression(AstMapExpression node) { // { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } => { __agg0 : { $push : f(x => element) } } + "$__agg0" - if (IsElementsField(node.Input)) + if (IsMappedElementsField(node, out var rewrittenArg)) { - var rewrittenArg = (AstExpression)AstNodeReplacer.Replace(node.In, (node.As, _element)); var accumulatorExpression = AstExpression.UnaryAccumulator(AstUnaryAccumulatorOperator.Push, rewrittenArg); - return CreateOptimizedExpression(accumulatorExpression); + return CreateGetAccumulatorFieldExpression(accumulatorExpression); } return base.VisitMapExpression(node); @@ -419,28 +418,16 @@ public override AstNode VisitMedianExpression(AstMedianExpression node) // { $median : { input: { $getField : { input : "$$ROOT", field : "_elements" } }, method: "approximate" } } => { __agg0 : { $median : { input: element, method: "approximate" } } } + "$__agg0" if (IsElementsField(node.Input)) { - var accumulator = AstExpression.ComplexAccumulator( - AstComplexAccumulatorOperator.Median, - new Dictionary - { - ["input"] = _element, - ["method"] = "approximate" - }); - return CreateOptimizedExpression(accumulator); + var accumulator = AstExpression.MedianAccumulator(_element); + return CreateGetAccumulatorFieldExpression(accumulator); } // { $median : { input: { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } }, method: "approximate" } } // => { __agg0 : { $median : { input: f(x => element), method: "approximate" } } } + "$__agg0" - if (IsMappedElementsField(node.Input, out var mapExpression, out var rewrittenArg)) + if (IsMappedElementsField(node.Input, out var rewrittenArg)) { - var accumulator = AstExpression.ComplexAccumulator( - AstComplexAccumulatorOperator.Median, - new Dictionary - { - ["input"] = rewrittenArg, - ["method"] = "approximate" - }); - return CreateOptimizedExpression(accumulator); + var accumulator = AstExpression.MedianAccumulator(rewrittenArg); + return CreateGetAccumulatorFieldExpression(accumulator); } return base.VisitMedianExpression(node); @@ -452,30 +439,16 @@ public override AstNode VisitPercentileExpression(AstPercentileExpression node) // => { __agg0 : { $percentile : { input: element, p: [...], method: "approximate" } } } + "$__agg0" if (IsElementsField(node.Input)) { - var accumulator = AstExpression.ComplexAccumulator( - AstComplexAccumulatorOperator.Percentile, - new Dictionary - { - ["input"] = _element, - ["p"] = node.Percentiles, - ["method"] = "approximate" - }); - return CreateOptimizedExpression(accumulator); + var accumulator = AstExpression.PercentileAccumulator(_element, node.Percentiles); + return CreateGetAccumulatorFieldExpression(accumulator); } // { $percentile : { input: { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } }, p: [...], method: "approximate" } } // => { __agg0 : { $percentile : { input: f(x => element), p: [...], method: "approximate" } } } + "$__agg0" - if (IsMappedElementsField(node.Input, out var mapExpression, out var rewrittenArg)) + if (IsMappedElementsField(node.Input, out var rewrittenArg)) { - var accumulator = AstExpression.ComplexAccumulator( - AstComplexAccumulatorOperator.Percentile, - new Dictionary - { - ["input"] = rewrittenArg, - ["p"] = node.Percentiles, - ["method"] = "approximate" - }); - return CreateOptimizedExpression(accumulator); + var accumulator = AstExpression.PercentileAccumulator(rewrittenArg, node.Percentiles); + return CreateGetAccumulatorFieldExpression(accumulator); } return base.VisitPercentileExpression(node); @@ -490,7 +463,7 @@ public override AstNode VisitPickExpression(AstPickExpression node) var @operator = node.Operator.ToAccumulatorOperator(); var rewrittenSelector = (AstExpression)AstNodeReplacer.Replace(node.Selector, (node.As, _element)); var accumulatorExpression = new AstPickAccumulatorExpression(@operator, node.SortBy, rewrittenSelector, node.N); - return CreateOptimizedExpression(accumulatorExpression); + return CreateGetAccumulatorFieldExpression(accumulatorExpression); } return base.VisitPickExpression(node); @@ -505,7 +478,7 @@ public override AstNode VisitUnaryExpression(AstUnaryExpression node) argGetFieldExpression.FieldName.IsStringConstant("_elements")) { var accumulatorExpression = AstExpression.UnaryAccumulator(AstUnaryAccumulatorOperator.Sum, 1); - return CreateOptimizedExpression(accumulatorExpression); + return CreateGetAccumulatorFieldExpression(accumulatorExpression); } } @@ -513,15 +486,15 @@ public override AstNode VisitUnaryExpression(AstUnaryExpression node) if (node.Operator.IsAccumulator(out var accumulatorOperator) && IsElementsField(node.Arg)) { var accumulatorExpression = AstExpression.UnaryAccumulator(accumulatorOperator, _element); - return CreateOptimizedExpression(accumulatorExpression); + return CreateGetAccumulatorFieldExpression(accumulatorExpression); } // { $accumulator : { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } } => { __agg0 : { $accumulator : f(x => element) } } + "$__agg0" if (node.Operator.IsAccumulator(out accumulatorOperator) && - IsMappedElementsField(node.Arg, out var mapExpression, out var rewrittenArg)) + IsMappedElementsField(node.Arg, out var rewrittenArg)) { var accumulatorExpression = AstExpression.UnaryAccumulator(accumulatorOperator, rewrittenArg); - return CreateOptimizedExpression(accumulatorExpression); + return CreateGetAccumulatorFieldExpression(accumulatorExpression); } return base.VisitUnaryExpression(node); @@ -529,26 +502,25 @@ public override AstNode VisitUnaryExpression(AstUnaryExpression node) private bool IsElementsField(AstExpression expression) { - return expression is AstGetFieldExpression getFieldExpression && - getFieldExpression.FieldName.IsStringConstant("_elements") && - getFieldExpression.Input.IsRootVar(); + return + expression is AstGetFieldExpression getFieldExpression && + getFieldExpression.FieldName.IsStringConstant("_elements") && + getFieldExpression.Input.IsRootVar(); } - private bool IsMappedElementsField(AstExpression expression, out AstMapExpression mapExpression, out AstExpression rewrittenArg) + private bool IsMappedElementsField(AstExpression expression, out AstExpression rewrittenArg) { if (expression is AstMapExpression map && IsElementsField(map.Input)) { - mapExpression = map; rewrittenArg = (AstExpression)AstNodeReplacer.Replace(map.In, (map.As, _element)); return true; } - mapExpression = null; rewrittenArg = null; return false; } - private AstExpression CreateOptimizedExpression(AstAccumulatorExpression accumulator) + private AstExpression CreateGetAccumulatorFieldExpression(AstAccumulatorExpression accumulator) { var fieldName = _accumulators.AddAccumulatorExpression(accumulator); return AstExpression.GetField(AstExpression.RootVar, fieldName); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Visitors/AstNodeVisitor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Visitors/AstNodeVisitor.cs index e74b893daeb..83a50bd6f44 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Visitors/AstNodeVisitor.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Visitors/AstNodeVisitor.cs @@ -199,41 +199,6 @@ public virtual AstNode VisitComparisonFilterOperation(AstComparisonFilterOperati return node; } - public virtual AstNode VisitComplexAccumulatorExpression(AstComplexAccumulatorExpression node) - { - Dictionary newArgs = null; - foreach (var kvp in node.Args) - { - var oldArg = kvp.Value; - var newArg = VisitAndConvert(oldArg); - - if (newArg != oldArg) - { - if (newArgs == null) - { - // First change detected - copy all processed entries - newArgs = new Dictionary(); - foreach (var processedKvp in node.Args) - { - if (processedKvp.Key == kvp.Key) - { - break; // Stop at current entry - } - newArgs[processedKvp.Key] = processedKvp.Value; - } - } - newArgs[kvp.Key] = newArg; - } - else if (newArgs != null) - { - // We're building a new dictionary, so add unchanged entries too - newArgs[kvp.Key] = oldArg; - } - } - - return newArgs != null ? node.Update(newArgs) : node; - } - public virtual AstNode VisitComputedArrayExpression(AstComputedArrayExpression node) { return node.Update(VisitAndConvert(node.Items)); @@ -544,6 +509,11 @@ public virtual AstNode VisitMedianExpression(AstMedianExpression node) return node.Update(VisitAndConvert(node.Input)); } + public virtual AstNode VisitMedianAccumulatorExpression(AstMedianAccumulatorExpression node) + { + return node.Update(VisitAndConvert(node.Input)); + } + public virtual AstNode VisitMedianWindowExpression(AstMedianWindowExpression node) { return node.Update(VisitAndConvert(node.Input), node.Window); @@ -609,6 +579,11 @@ public virtual AstNode VisitPercentileExpression(AstPercentileExpression node) return node.Update(VisitAndConvert(node.Input), VisitAndConvert(node.Percentiles)); } + public virtual AstNode VisitPercentileAccumulatorExpression(AstPercentileAccumulatorExpression node) + { + return node.Update(VisitAndConvert(node.Input), VisitAndConvert(node.Percentiles)); + } + public virtual AstNode VisitPercentileWindowExpression(AstPercentileWindowExpression node) { return node.Update(VisitAndConvert(node.Input), VisitAndConvert(node.Percentiles), node.Window); From cb03c991f73334c9170d0b05d31752221521d0b3 Mon Sep 17 00:00:00 2001 From: adelinowona Date: Thu, 31 Jul 2025 20:14:28 -0400 Subject: [PATCH 13/15] use StandardSerializers --- .../MedianMethodToAggregationExpressionTranslator.cs | 12 ++++++------ ...centileMethodToAggregationExpressionTranslator.cs | 11 ++++++----- .../WindowMethodToAggregationExpressionTranslator.cs | 2 +- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslator.cs index fcede33cd25..b8029504b32 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslator.cs @@ -15,17 +15,17 @@ using System.Linq.Expressions; using System.Reflection; -using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; using MongoDB.Driver.Linq.Linq3Implementation.Misc; using MongoDB.Driver.Linq.Linq3Implementation.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators { internal class MedianMethodToAggregationExpressionTranslator { private static readonly MethodInfo[] __medianMethods = - { + [ EnumerableMethod.MedianDecimal, EnumerableMethod.MedianDecimalWithSelector, EnumerableMethod.MedianDouble, @@ -46,10 +46,10 @@ internal class MedianMethodToAggregationExpressionTranslator EnumerableMethod.MedianNullableSingleWithSelector, EnumerableMethod.MedianSingle, EnumerableMethod.MedianSingleWithSelector - }; + ]; private static readonly MethodInfo[] __medianWithSelectorMethods = - { + [ EnumerableMethod.MedianDecimalWithSelector, EnumerableMethod.MedianDoubleWithSelector, EnumerableMethod.MedianInt32WithSelector, @@ -60,7 +60,7 @@ internal class MedianMethodToAggregationExpressionTranslator EnumerableMethod.MedianNullableInt64WithSelector, EnumerableMethod.MedianNullableSingleWithSelector, EnumerableMethod.MedianSingleWithSelector - }; + ]; public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { @@ -91,7 +91,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC } var ast = AstExpression.Median(inputAst); - var serializer = BsonSerializer.LookupSerializer(expression.Type); + var serializer = StandardSerializers.GetSerializer(expression.Type); return new TranslatedExpression(expression, ast, serializer); } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslator.cs index 36cec69813d..cdbbedb3cc6 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslator.cs @@ -19,13 +19,14 @@ using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; using MongoDB.Driver.Linq.Linq3Implementation.Misc; using MongoDB.Driver.Linq.Linq3Implementation.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators { internal class PercentileMethodToAggregationExpressionTranslator { private static readonly MethodInfo[] __percentileMethods = - { + [ EnumerableMethod.PercentileDecimal, EnumerableMethod.PercentileDecimalWithSelector, EnumerableMethod.PercentileDouble, @@ -46,10 +47,10 @@ internal class PercentileMethodToAggregationExpressionTranslator EnumerableMethod.PercentileNullableSingleWithSelector, EnumerableMethod.PercentileSingle, EnumerableMethod.PercentileSingleWithSelector - }; + ]; private static readonly MethodInfo[] __percentileWithSelectorMethods = - { + [ EnumerableMethod.PercentileDecimalWithSelector, EnumerableMethod.PercentileDoubleWithSelector, EnumerableMethod.PercentileInt32WithSelector, @@ -60,7 +61,7 @@ internal class PercentileMethodToAggregationExpressionTranslator EnumerableMethod.PercentileNullableInt64WithSelector, EnumerableMethod.PercentileNullableSingleWithSelector, EnumerableMethod.PercentileSingleWithSelector - }; + ]; public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { @@ -94,7 +95,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var percentilesTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, percentilesExpression); var ast = AstExpression.Percentile(inputAst, percentilesTranslation.Ast); - var serializer = BsonSerializer.LookupSerializer(expression.Type); + var serializer = StandardSerializers.GetSerializer(expression.Type); return new TranslatedExpression(expression, ast, serializer); } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslator.cs index faf192c14f7..f45cffc3e49 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslator.cs @@ -400,7 +400,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC ast = AstExpression.MedianWindowExpression(selectorTranslation.Ast, window); } - var serializer = BsonSerializer.LookupSerializer(method.ReturnType); + var serializer = StandardSerializers.GetSerializer(method.ReturnType); return new TranslatedExpression(expression, ast, serializer); } From a4a2288393d4d82e047c7999d877add6fd994c01 Mon Sep 17 00:00:00 2001 From: adelinowona Date: Mon, 4 Aug 2025 17:45:40 -0400 Subject: [PATCH 14/15] update return types and tests --- .../ISetWindowFieldsPartitionExtensions.cs | 16 +- src/MongoDB.Driver/Linq/MongoEnumerable.cs | 32 +-- ...dToAggregationExpressionTranslatorTests.cs | 16 +- ...dToAggregationExpressionTranslatorTests.cs | 206 ++++++++++++++++-- ...dToAggregationExpressionTranslatorTests.cs | 110 ++++++++-- 5 files changed, 314 insertions(+), 66 deletions(-) diff --git a/src/MongoDB.Driver/Linq/ISetWindowFieldsPartitionExtensions.cs b/src/MongoDB.Driver/Linq/ISetWindowFieldsPartitionExtensions.cs index efc42a0cd48..ecba3d5e81e 100644 --- a/src/MongoDB.Driver/Linq/ISetWindowFieldsPartitionExtensions.cs +++ b/src/MongoDB.Driver/Linq/ISetWindowFieldsPartitionExtensions.cs @@ -886,7 +886,7 @@ public static TValue Max(this ISetWindowFieldsPartition /// The selector that selects a value from the input document. /// The window boundaries. /// The median of the selected values. - public static double Median(this ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window = null) + public static decimal Median(this ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window = null) { throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); } @@ -899,7 +899,7 @@ public static double Median(this ISetWindowFieldsPartition parti /// The selector that selects a value from the input document. /// The window boundaries. /// The median of the selected values. - public static double? Median(this ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window = null) + public static decimal? Median(this ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window = null) { throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); } @@ -938,7 +938,7 @@ public static double Median(this ISetWindowFieldsPartition parti /// The selector that selects a value from the input document. /// The window boundaries. /// The median of the selected values. - public static double Median(this ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window = null) + public static float Median(this ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window = null) { throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); } @@ -951,7 +951,7 @@ public static double Median(this ISetWindowFieldsPartition parti /// The selector that selects a value from the input document. /// The window boundaries. /// The median of the selected values. - public static double? Median(this ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window = null) + public static float? Median(this ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window = null) { throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); } @@ -1031,7 +1031,7 @@ public static TValue Min(this ISetWindowFieldsPartition /// The percentiles (between 0.0 and 1.0). /// The window boundaries. /// The values at the given percentiles. - public static double[] Percentile(this ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window = null) + public static decimal[] Percentile(this ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window = null) { throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); } @@ -1045,7 +1045,7 @@ public static double[] Percentile(this ISetWindowFieldsPartition /// The percentiles (between 0.0 and 1.0). /// The window boundaries. /// The values at the given percentiles. - public static double?[] Percentile(this ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window = null) + public static decimal?[] Percentile(this ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window = null) { throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); } @@ -1087,7 +1087,7 @@ public static double[] Percentile(this ISetWindowFieldsPartition /// The percentiles (between 0.0 and 1.0). /// The window boundaries. /// The values at the given percentiles. - public static double[] Percentile(this ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window = null) + public static float[] Percentile(this ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window = null) { throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); } @@ -1101,7 +1101,7 @@ public static double[] Percentile(this ISetWindowFieldsPartition /// The percentiles (between 0.0 and 1.0). /// The window boundaries. /// The values at the given percentiles. - public static double?[] Percentile(this ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window = null) + public static float?[] Percentile(this ISetWindowFieldsPartition partition, Func selector, IEnumerable percentiles, SetWindowFieldsWindow window = null) { throw new InvalidOperationException("This method is only intended to be used with SetWindowFields."); } diff --git a/src/MongoDB.Driver/Linq/MongoEnumerable.cs b/src/MongoDB.Driver/Linq/MongoEnumerable.cs index f54586bff8c..eda95082fe0 100644 --- a/src/MongoDB.Driver/Linq/MongoEnumerable.cs +++ b/src/MongoDB.Driver/Linq/MongoEnumerable.cs @@ -241,7 +241,7 @@ public static IEnumerable MaxN( /// /// The sequence of values. /// The median value. - public static double Median(this IEnumerable source) + public static decimal Median(this IEnumerable source) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -251,7 +251,7 @@ public static double Median(this IEnumerable source) /// /// The sequence of values. /// The median value. - public static double? Median(this IEnumerable source) + public static decimal? Median(this IEnumerable source) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -281,7 +281,7 @@ public static double Median(this IEnumerable source) /// /// The sequence of values. /// The median value. - public static double Median(this IEnumerable source) + public static float Median(this IEnumerable source) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -291,7 +291,7 @@ public static double Median(this IEnumerable source) /// /// The sequence of values. /// The median value. - public static double? Median(this IEnumerable source) + public static float? Median(this IEnumerable source) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -343,7 +343,7 @@ public static double Median(this IEnumerable source) /// A sequence of values to calculate the median of. /// A transform function to apply to each element. /// The median value. - public static double Median(this IEnumerable source, Func selector) + public static decimal Median(this IEnumerable source, Func selector) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -355,7 +355,7 @@ public static double Median(this IEnumerable source, FuncA sequence of values to calculate the median of. /// A transform function to apply to each element. /// The median value. - public static double? Median(this IEnumerable source, Func selector) + public static decimal? Median(this IEnumerable source, Func selector) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -391,7 +391,7 @@ public static double Median(this IEnumerable source, FuncA sequence of values to calculate the median of. /// A transform function to apply to each element. /// The median value. - public static double Median(this IEnumerable source, Func selector) + public static float Median(this IEnumerable source, Func selector) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -403,7 +403,7 @@ public static double Median(this IEnumerable source, FuncA sequence of values to calculate the median of. /// A transform function to apply to each element. /// The median value. - public static double? Median(this IEnumerable source, Func selector) + public static float? Median(this IEnumerable source, Func selector) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -499,7 +499,7 @@ public static IEnumerable MinN( /// A sequence of values to calculate the percentiles of. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double[] Percentile(this IEnumerable source, IEnumerable percentiles) + public static decimal[] Percentile(this IEnumerable source, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -510,7 +510,7 @@ public static double[] Percentile(this IEnumerable source, IEnumerable< /// A sequence of values to calculate the percentiles of. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double?[] Percentile(this IEnumerable source, IEnumerable percentiles) + public static decimal?[] Percentile(this IEnumerable source, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -543,7 +543,7 @@ public static double[] Percentile(this IEnumerable source, IEnumerableA sequence of values to calculate the percentiles of. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double[] Percentile(this IEnumerable source, IEnumerable percentiles) + public static float[] Percentile(this IEnumerable source, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -554,7 +554,7 @@ public static double[] Percentile(this IEnumerable source, IEnumerableA sequence of values to calculate the percentiles of. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double?[] Percentile(this IEnumerable source, IEnumerable percentiles) + public static float?[] Percentile(this IEnumerable source, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -611,7 +611,7 @@ public static double[] Percentile(this IEnumerable source, IEnumerableA transform function to apply to each element. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) + public static decimal[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -624,7 +624,7 @@ public static double[] Percentile(this IEnumerable source, Fun /// A transform function to apply to each element. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double?[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) + public static decimal?[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -663,7 +663,7 @@ public static double[] Percentile(this IEnumerable source, Fun /// A transform function to apply to each element. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) + public static float[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } @@ -676,7 +676,7 @@ public static double[] Percentile(this IEnumerable source, Fun /// A transform function to apply to each element. /// The percentiles to compute (each between 0.0 and 1.0). /// The percentiles of the sequence of values. - public static double?[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) + public static float?[] Percentile(this IEnumerable source, Func selector, IEnumerable percentiles) { throw CustomLinqExtensionMethodHelper.CreateNotSupportedException(); } diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslatorTests.cs index 0aaaefce86b..2013568bb50 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslatorTests.cs @@ -48,7 +48,7 @@ public void Median_with_decimals_should_work( AssertStages(stages, "{ $project : { _v : { $median : { input : '$Decimals', method : 'approximate' } }, _id : 0 } }"); var results = queryable.ToList(); - results.Should().Equal(1.0, 1.0, 2.0); + results.Should().Equal(1.0M, 1.0M, 2.0M); } [Theory] @@ -66,7 +66,7 @@ public void Median_with_decimals_selector_should_work( AssertStages(stages, "{ $project : { _v : { $median : { input : { $map : { input : '$Decimals', as : 'y', in : { $multiply : ['$$y', NumberDecimal(2)] } } }, method : 'approximate' } }, _id : 0 } }"); var results = queryable.ToList(); - results.Should().Equal(2.0, 2.0, 4.0); + results.Should().Equal(2.0M, 2.0M, 4.0M); } [Theory] @@ -120,7 +120,7 @@ public void Median_with_floats_should_work( AssertStages(stages, "{ $project : { _v : { $median : { input : '$Floats', method : 'approximate' } }, _id : 0 } }"); var results = queryable.ToList(); - results.Should().Equal(1.0, 1.0, 2.0); + results.Should().Equal(1.0F, 1.0F, 2.0F); } [Theory] @@ -138,7 +138,7 @@ public void Median_with_floats_selector_should_work( AssertStages(stages, "{ $project : { _v : { $median : { input : { $map : { input : '$Floats', as : 'y', in : { $multiply : ['$$y', 2.0] } } }, method : 'approximate' } }, _id : 0 } }"); var results = queryable.ToList(); - results.Should().Equal(2.0, 2.0, 4.0); + results.Should().Equal(2.0F, 2.0F, 4.0F); } [Theory] @@ -228,7 +228,7 @@ public void Median_with_nullable_decimals_should_work( AssertStages(stages, "{ $project : { _v : { $median : { input : '$NullableDecimals', method : 'approximate' } }, _id : 0 } }"); var results = queryable.ToList(); - results.Should().Equal(null, null, 2.0); + results.Should().Equal(null, null, 2.0M); } [Theory] @@ -246,7 +246,7 @@ public void Median_with_nullable_decimals_selector_should_work( AssertStages(stages, "{ $project : { _v : { $median : { input : { $map : { input : '$NullableDecimals', as : 'y', in : { $multiply : ['$$y', NumberDecimal(2)] } } }, method : 'approximate' } }, _id : 0 } }"); var results = queryable.ToList(); - results.Should().Equal(null, null, 4.0); + results.Should().Equal(null, null, 4.0M); } [Theory] @@ -300,7 +300,7 @@ public void Median_with_nullable_floats_should_work( AssertStages(stages, "{ $project : { _v : { $median : { input : '$NullableFloats', method : 'approximate' } }, _id : 0 } }"); var results = queryable.ToList(); - results.Should().Equal(null, null, 2.0); + results.Should().Equal(null, null, 2.0F); } [Theory] @@ -318,7 +318,7 @@ public void Median_with_nullable_floats_selector_should_work( AssertStages(stages, "{ $project : { _v : { $median : { input : { $map : { input : '$NullableFloats', as : 'y', in : { $multiply : ['$$y', 2.0] } } }, method : 'approximate' } }, _id : 0 } }"); var results = queryable.ToList(); - results.Should().Equal(null, null, 4.0); + results.Should().Equal(null, null, 4.0F); } [Theory] diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslatorTests.cs index bcc939d9f7f..c0f1225c245 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslatorTests.cs @@ -48,9 +48,9 @@ public void Percentile_with_decimals_should_work( AssertStages(stages, "{ $project : { _v : { $percentile : { input : '$Decimals', p : [0.5], method : 'approximate' } }, _id : 0 } }"); var results = queryable.ToList(); - results[0].Should().Equal(1.0); - results[1].Should().Equal(1.0); - results[2].Should().Equal(2.0); + results[0].Should().Equal(1.0M); + results[1].Should().Equal(1.0M); + results[2].Should().Equal(2.0M); } [Theory] @@ -68,9 +68,9 @@ public void Percentile_with_decimals_multiple_percentiles_should_work( AssertStages(stages, "{ $project : { _v : { $percentile : { input : '$Decimals', p : [0.25, 0.75], method : 'approximate' } }, _id : 0 } }"); var results = queryable.ToList(); - results[0].Should().Equal(1.0, 1.0); - results[1].Should().Equal(1.0, 2.0); - results[2].Should().Equal(1.0, 3.0); + results[0].Should().Equal(1.0M, 1.0M); + results[1].Should().Equal(1.0M, 2.0M); + results[2].Should().Equal(1.0M, 3.0M); } [Theory] @@ -88,9 +88,9 @@ public void Percentile_with_decimals_selector_should_work( AssertStages(stages, "{ $project : { _v : { $percentile : { input : { $map : { input : '$Decimals', as : 'y', in : { $multiply : ['$$y', NumberDecimal(2)] } } }, p : [0.5], method : 'approximate' } }, _id : 0 } }"); var results = queryable.ToList(); - results[0].Should().Equal(2.0); - results[1].Should().Equal(2.0); - results[2].Should().Equal(4.0); + results[0].Should().Equal(2.0M); + results[1].Should().Equal(2.0M); + results[2].Should().Equal(4.0M); } [Theory] @@ -148,9 +148,9 @@ public void Percentile_with_floats_should_work( AssertStages(stages, "{ $project : { _v : { $percentile : { input : '$Floats', p : [0.5], method : 'approximate' } }, _id : 0 } }"); var results = queryable.ToList(); - results[0].Should().Equal(1.0); - results[1].Should().Equal(1.0); - results[2].Should().Equal(2.0); + results[0].Should().Equal(1.0F); + results[1].Should().Equal(1.0F); + results[2].Should().Equal(2.0F); } [Theory] @@ -168,9 +168,9 @@ public void Percentile_with_floats_selector_should_work( AssertStages(stages, "{ $project : { _v : { $percentile : { input : { $map : { input : '$Floats', as : 'y', in : { $multiply : ['$$y', 2.0] } } }, p : [0.5], method : 'approximate' } }, _id : 0 } }"); var results = queryable.ToList(); - results[0].Should().Equal(2.0); - results[1].Should().Equal(2.0); - results[2].Should().Equal(4.0); + results[0].Should().Equal(2.0F); + results[1].Should().Equal(2.0F); + results[2].Should().Equal(4.0F); } [Theory] @@ -268,9 +268,9 @@ public void Percentile_with_nullable_decimals_should_work( AssertStages(stages, "{ $project : { _v : { $percentile : { input : '$NullableDecimals', p : [0.5], method : 'approximate' } }, _id : 0 } }"); var results = queryable.ToList(); - results[0].Should().Equal((double?)null); - results[1].Should().Equal((double?)null); - results[2].Should().Equal(2.0); + results[0].Should().Equal((decimal?)null); + results[1].Should().Equal((decimal?)null); + results[2].Should().Equal(2.0M); } [Theory] @@ -287,19 +287,183 @@ public void Percentile_with_nullable_decimals_selector_should_work( var stages = Translate(collection, queryable); AssertStages(stages, "{ $project : { _v : { $percentile : { input : { $map : { input : '$NullableDecimals', as : 'y', in : { $multiply : ['$$y', NumberDecimal(2)] } } }, p : [0.5], method : 'approximate' } }, _id : 0 } }"); + var results = queryable.ToList(); + results[0].Should().Equal((decimal?)null); + results[1].Should().Equal((decimal?)null); + results[2].Should().Equal(4.0M); + } + + [Theory] + [ParameterAttributeData] + public void Percentile_with_nullable_doubles_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.NullableDoubles.AsQueryable().Percentile(new[] { 0.5 })) : + collection.AsQueryable().Select(x => x.NullableDoubles.Percentile(new[] { 0.5 })); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $percentile : { input : '$NullableDoubles', p : [0.5], method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results[0].Should().Equal((double?)null); + results[1].Should().Equal((double?)null); + results[2].Should().Equal(2.0); + } + + [Theory] + [ParameterAttributeData] + public void Percentile_with_nullable_doubles_selector_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.NullableDoubles.AsQueryable().Percentile(y => y * 2.0, new[] { 0.5 })) : + collection.AsQueryable().Select(x => x.NullableDoubles.Percentile(y => y * 2.0, new[] { 0.5 })); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $percentile : { input : { $map : { input : '$NullableDoubles', as : 'y', in : { $multiply : ['$$y', 2.0] } } }, p : [0.5], method : 'approximate' } }, _id : 0 } }"); + var results = queryable.ToList(); results[0].Should().Equal((double?)null); results[1].Should().Equal((double?)null); results[2].Should().Equal(4.0); } - [Fact] - public void Percentile_with_list_input_should_work() + [Theory] + [ParameterAttributeData] + public void Percentile_with_nullable_floats_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.NullableFloats.AsQueryable().Percentile(new[] { 0.5 })) : + collection.AsQueryable().Select(x => x.NullableFloats.Percentile(new[] { 0.5 })); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $percentile : { input : '$NullableFloats', p : [0.5], method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results[0].Should().Equal((float?)null); + results[1].Should().Equal((float?)null); + results[2].Should().Equal(2.0F); + } + + [Theory] + [ParameterAttributeData] + public void Percentile_with_nullable_floats_selector_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.NullableFloats.AsQueryable().Percentile(y => y * 2.0F, new[] { 0.5 })) : + collection.AsQueryable().Select(x => x.NullableFloats.Percentile(y => y * 2.0F, new[] { 0.5 })); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $percentile : { input : { $map : { input : '$NullableFloats', as : 'y', in : { $multiply : ['$$y', 2.0] } } }, p : [0.5], method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results[0].Should().Equal((float?)null); + results[1].Should().Equal((float?)null); + results[2].Should().Equal(4.0F); + } + + [Theory] + [ParameterAttributeData] + public void Percentile_with_nullable_ints_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.NullableInts.AsQueryable().Percentile(new[] { 0.5 })) : + collection.AsQueryable().Select(x => x.NullableInts.Percentile(new[] { 0.5 })); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $percentile : { input : '$NullableInts', p : [0.5], method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results[0].Should().Equal((double?)null); + results[1].Should().Equal((double?)null); + results[2].Should().Equal(2.0); + } + + [Theory] + [ParameterAttributeData] + public void Percentile_with_nullable_ints_selector_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.NullableInts.AsQueryable().Percentile(y => y * 2, new[] { 0.5 })) : + collection.AsQueryable().Select(x => x.NullableInts.Percentile(y => y * 2, new[] { 0.5 })); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $percentile : { input : { $map : { input : '$NullableInts', as : 'y', in : { $multiply : ['$$y', 2] } } }, p : [0.5], method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results[0].Should().Equal((double?)null); + results[1].Should().Equal((double?)null); + results[2].Should().Equal(4.0); + } + + [Theory] + [ParameterAttributeData] + public void Percentile_with_nullable_longs_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.NullableLongs.AsQueryable().Percentile(new[] { 0.5 })) : + collection.AsQueryable().Select(x => x.NullableLongs.Percentile(new[] { 0.5 })); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $percentile : { input : '$NullableLongs', p : [0.5], method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results[0].Should().Equal((double?)null); + results[1].Should().Equal((double?)null); + results[2].Should().Equal(2.0); + } + + [Theory] + [ParameterAttributeData] + public void Percentile_with_nullable_longs_selector_should_work( + [Values(false, true)] bool withNestedAsQueryable) + { + var collection = Fixture.Collection; + + var queryable = withNestedAsQueryable ? + collection.AsQueryable().Select(x => x.NullableLongs.AsQueryable().Percentile(y => y * 2L, new[] { 0.5 })) : + collection.AsQueryable().Select(x => x.NullableLongs.Percentile(y => y * 2L, new[] { 0.5 })); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $percentile : { input : { $map : { input : '$NullableLongs', as : 'y', in : { $multiply : ['$$y', NumberLong(2)] } } }, p : [0.5], method : 'approximate' } }, _id : 0 } }"); + + var results = queryable.ToList(); + results[0].Should().Equal((double?)null); + results[1].Should().Equal((double?)null); + results[2].Should().Equal(4.0); + } + + [Theory] + [ParameterAttributeData] + public void Percentile_with_list_input_should_work( + [Values(false, true)] bool withNestedAsQueryable) { var collection = Fixture.Collection; var percentiles = new List { 0.25, 0.5, 0.75 }; - var queryable = collection.AsQueryable().Select(x => x.Doubles.Percentile(percentiles)); + var queryable = withNestedAsQueryable + ? collection.AsQueryable().Select(x => x.Doubles.AsQueryable().Percentile(percentiles)) + : collection.AsQueryable().Select(x => x.Doubles.Percentile(percentiles)); var stages = Translate(collection, queryable); AssertStages(stages, "{ $project : { _v : { $percentile : { input : '$Doubles', p : [0.25, 0.5, 0.75], method : 'approximate' } }, _id : 0 } }"); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslatorTests.cs index 31a511b5a67..b34c3f77e3b 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslatorTests.cs @@ -1491,7 +1491,7 @@ public void Translate_should_return_expected_result_for_Max() [Fact] public void Translate_should_return_expected_result_for_Median_with_Decimal() { - RequireServer.Check().Supports(Feature.PercentileOperator); + RequireServer.Check().Supports(Feature.MedianOperator); var collection = Fixture.Collection; @@ -1505,14 +1505,14 @@ public void Translate_should_return_expected_result_for_Median_with_Decimal() var results = aggregate.ToList(); foreach (var result in results) { - result["Result"].AsDouble.Should().Be(2.0); + result["Result"].ToDecimal().Should().Be(2.0M); } } [Fact] public void Translate_should_return_expected_result_for_Median_with_Double() { - RequireServer.Check().Supports(Feature.PercentileOperator); + RequireServer.Check().Supports(Feature.MedianOperator); var collection = Fixture.Collection; @@ -1533,7 +1533,7 @@ public void Translate_should_return_expected_result_for_Median_with_Double() [Fact] public void Translate_should_return_expected_result_for_Median_with_Int32() { - RequireServer.Check().Supports(Feature.PercentileOperator); + RequireServer.Check().Supports(Feature.MedianOperator); var collection = Fixture.Collection; @@ -1554,7 +1554,7 @@ public void Translate_should_return_expected_result_for_Median_with_Int32() [Fact] public void Translate_should_return_expected_result_for_Median_with_Int64() { - RequireServer.Check().Supports(Feature.PercentileOperator); + RequireServer.Check().Supports(Feature.MedianOperator); var collection = Fixture.Collection; @@ -1575,7 +1575,7 @@ public void Translate_should_return_expected_result_for_Median_with_Int64() [Fact] public void Translate_should_return_expected_result_for_Median_with_nullable_Decimal() { - RequireServer.Check().Supports(Feature.PercentileOperator); + RequireServer.Check().Supports(Feature.MedianOperator); var collection = Fixture.Collection; @@ -1589,14 +1589,14 @@ public void Translate_should_return_expected_result_for_Median_with_nullable_Dec var results = aggregate.ToList(); foreach (var result in results) { - result["Result"].AsDouble.Should().Be(1.0); + result["Result"].ToDecimal().Should().Be(1.0M); } } [Fact] public void Translate_should_return_expected_result_for_Median_with_nullable_Double() { - RequireServer.Check().Supports(Feature.PercentileOperator); + RequireServer.Check().Supports(Feature.MedianOperator); var collection = Fixture.Collection; @@ -1617,7 +1617,7 @@ public void Translate_should_return_expected_result_for_Median_with_nullable_Dou [Fact] public void Translate_should_return_expected_result_for_Median_with_nullable_Int32() { - RequireServer.Check().Supports(Feature.PercentileOperator); + RequireServer.Check().Supports(Feature.MedianOperator); var collection = Fixture.Collection; @@ -1638,7 +1638,7 @@ public void Translate_should_return_expected_result_for_Median_with_nullable_Int [Fact] public void Translate_should_return_expected_result_for_Median_with_nullable_Int64() { - RequireServer.Check().Supports(Feature.PercentileOperator); + RequireServer.Check().Supports(Feature.MedianOperator); var collection = Fixture.Collection; @@ -1656,10 +1656,52 @@ public void Translate_should_return_expected_result_for_Median_with_nullable_Int } } + [Fact] + public void Translate_should_return_expected_result_for_Median_with_nullable_Single() + { + RequireServer.Check().Supports(Feature.MedianOperator); + + var collection = Fixture.Collection; + + var aggregate = collection.Aggregate() + .SetWindowFields(output: p => new { Result = p.Median(x => x.NullableSingleField, null) }); + + var stages = Translate(collection, aggregate); + var expectedStages = new[] { "{ $setWindowFields : { output : { Result : { $median : { input : '$NullableSingleField', method : 'approximate' } } } } }" }; + AssertStages(stages, expectedStages); + + var results = aggregate.ToList(); + foreach (var result in results) + { + result["Result"].AsDouble.Should().Be(1.0); + } + } + + [Fact] + public void Translate_should_return_expected_result_for_Median_with_Single() + { + RequireServer.Check().Supports(Feature.MedianOperator); + + var collection = Fixture.Collection; + + var aggregate = collection.Aggregate() + .SetWindowFields(output: p => new { Result = p.Median(x => x.SingleField, null) }); + + var stages = Translate(collection, aggregate); + var expectedStages = new[] { "{ $setWindowFields : { output : { Result : { $median : { input : '$SingleField', method : 'approximate' } } } } }" }; + AssertStages(stages, expectedStages); + + var results = aggregate.ToList(); + foreach (var result in results) + { + result["Result"].AsDouble.Should().Be(2.0); + } + } + [Fact] public void Translate_should_return_expected_result_for_Median_with_window() { - RequireServer.Check().Supports(Feature.PercentileOperator); + RequireServer.Check().Supports(Feature.MedianOperator); var collection = Fixture.Collection; @@ -1720,7 +1762,7 @@ public void Translate_should_return_expected_result_for_Percentile_with_Decimal( var results = aggregate.ToList(); foreach (var result in results) { - result["Result"].AsBsonArray[0].AsDouble.Should().Be(2.0); + result["Result"].AsBsonArray[0].ToDecimal().Should().Be(2.0M); } } @@ -1804,7 +1846,7 @@ public void Translate_should_return_expected_result_for_Percentile_with_nullable var results = aggregate.ToList(); foreach (var result in results) { - result["Result"].AsBsonArray[0].AsDouble.Should().Be(1.0); + result["Result"].AsBsonArray[0].ToDecimal().Should().Be(1.0M); } } @@ -1871,6 +1913,27 @@ public void Translate_should_return_expected_result_for_Percentile_with_nullable } } + [Fact] + public void Translate_should_return_expected_result_for_Percentile_with_nullable_Single() + { + RequireServer.Check().Supports(Feature.PercentileOperator); + + var collection = Fixture.Collection; + + var aggregate = collection.Aggregate() + .SetWindowFields(output: p => new { Result = p.Percentile(x => x.NullableSingleField, new[] { 0.5 }, null) }); + + var stages = Translate(collection, aggregate); + var expectedStages = new[] { "{ $setWindowFields : { output : { Result : { $percentile : { input : '$NullableSingleField', p : [0.5], method : 'approximate' } } } } }" }; + AssertStages(stages, expectedStages); + + var results = aggregate.ToList(); + foreach (var result in results) + { + result["Result"].AsBsonArray[0].AsDouble.Should().Be(1.0); + } + } + [Fact] public void Translate_should_return_expected_result_for_Percentile_with_multiple_percentiles() { @@ -1894,6 +1957,27 @@ public void Translate_should_return_expected_result_for_Percentile_with_multiple } } + [Fact] + public void Translate_should_return_expected_result_for_Percentile_with_Single() + { + RequireServer.Check().Supports(Feature.PercentileOperator); + + var collection = Fixture.Collection; + + var aggregate = collection.Aggregate() + .SetWindowFields(output: p => new { Result = p.Percentile(x => x.SingleField, new[] { 0.5 }, null) }); + + var stages = Translate(collection, aggregate); + var expectedStages = new[] { "{ $setWindowFields : { output : { Result : { $percentile : { input : '$SingleField', p : [0.5], method : 'approximate' } } } } }" }; + AssertStages(stages, expectedStages); + + var results = aggregate.ToList(); + foreach (var result in results) + { + result["Result"].AsBsonArray[0].AsDouble.Should().Be(2.0); + } + } + [Fact] public void Translate_should_return_expected_result_for_Percentile_with_window() { From 32c473f17ff154e355c5ee9f7f31895bd514d58d Mon Sep 17 00:00:00 2001 From: adelinowona Date: Fri, 8 Aug 2025 15:52:06 -0400 Subject: [PATCH 15/15] address pr comments --- .../Ast/Expressions/AstExpression.cs | 10 +- .../AstMedianAccumulatorExpression.cs | 1 - .../Ast/Expressions/AstMedianExpression.cs | 1 - .../AstGroupingPipelineOptimizer.cs | 30 +++-- .../Reflection/EnumerableMethod.cs | 120 ------------------ .../Reflection/MongoEnumerableMethod.cs | 120 ++++++++++++++++++ ...MethodToAggregationExpressionTranslator.cs | 65 +++++----- ...MethodToAggregationExpressionTranslator.cs | 66 +++++----- 8 files changed, 207 insertions(+), 206 deletions(-) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpression.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpression.cs index 9e464cb2f8e..bf729df32a9 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpression.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpression.cs @@ -602,12 +602,12 @@ public static AstExpression Median(AstExpression input) return new AstMedianExpression(input); } - public static AstMedianAccumulatorExpression MedianAccumulator(AstExpression input) + public static AstAccumulatorExpression MedianAccumulator(AstExpression input) { return new AstMedianAccumulatorExpression(input); } - public static AstMedianWindowExpression MedianWindowExpression(AstExpression input, AstWindow window) + public static AstWindowExpression MedianWindowExpression(AstExpression input, AstWindow window) { return new AstMedianWindowExpression(input, window); } @@ -668,17 +668,17 @@ public static AstExpression Or(params AstExpression[] args) return new AstNaryExpression(AstNaryOperator.Or, flattenedArgs); } - public static AstPercentileExpression Percentile(AstExpression input, AstExpression percentiles) + public static AstExpression Percentile(AstExpression input, AstExpression percentiles) { return new AstPercentileExpression(input, percentiles); } - public static AstPercentileAccumulatorExpression PercentileAccumulator(AstExpression input, AstExpression percentiles) + public static AstAccumulatorExpression PercentileAccumulator(AstExpression input, AstExpression percentiles) { return new AstPercentileAccumulatorExpression(input, percentiles); } - public static AstPercentileWindowExpression PercentileWindowExpression(AstExpression input, AstExpression percentiles, AstWindow window) + public static AstWindowExpression PercentileWindowExpression(AstExpression input, AstExpression percentiles, AstWindow window) { return new AstPercentileWindowExpression(input, percentiles, window); } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstMedianAccumulatorExpression.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstMedianAccumulatorExpression.cs index 754b1c175f6..d7cc2bb085a 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstMedianAccumulatorExpression.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstMedianAccumulatorExpression.cs @@ -59,6 +59,5 @@ public AstMedianAccumulatorExpression Update(AstExpression input) } return new AstMedianAccumulatorExpression(input); } - } } \ No newline at end of file diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstMedianExpression.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstMedianExpression.cs index 2bbbf0112ec..a171bb9fa02 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstMedianExpression.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstMedianExpression.cs @@ -59,6 +59,5 @@ public AstMedianExpression Update(AstExpression input) } return new AstMedianExpression(input); } - } } \ No newline at end of file diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupingPipelineOptimizer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupingPipelineOptimizer.cs index 66763cabde2..37de0673850 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupingPipelineOptimizer.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupingPipelineOptimizer.cs @@ -415,19 +415,20 @@ public override AstNode VisitMapExpression(AstMapExpression node) public override AstNode VisitMedianExpression(AstMedianExpression node) { - // { $median : { input: { $getField : { input : "$$ROOT", field : "_elements" } }, method: "approximate" } } => { __agg0 : { $median : { input: element, method: "approximate" } } } + "$__agg0" + // { $median : { input: { $getField : { input : "$$ROOT", field : "_elements" } }, method: "approximate" } } + // => { __agg0 : { $median : { input: element, method: "approximate" } } } + "$__agg0" if (IsElementsField(node.Input)) { - var accumulator = AstExpression.MedianAccumulator(_element); - return CreateGetAccumulatorFieldExpression(accumulator); + var accumulatorExpression = AstExpression.MedianAccumulator(_element); + return CreateGetAccumulatorFieldExpression(accumulatorExpression); } // { $median : { input: { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } }, method: "approximate" } } // => { __agg0 : { $median : { input: f(x => element), method: "approximate" } } } + "$__agg0" if (IsMappedElementsField(node.Input, out var rewrittenArg)) { - var accumulator = AstExpression.MedianAccumulator(rewrittenArg); - return CreateGetAccumulatorFieldExpression(accumulator); + var accumulatorExpression = AstExpression.MedianAccumulator(rewrittenArg); + return CreateGetAccumulatorFieldExpression(accumulatorExpression); } return base.VisitMedianExpression(node); @@ -439,16 +440,16 @@ public override AstNode VisitPercentileExpression(AstPercentileExpression node) // => { __agg0 : { $percentile : { input: element, p: [...], method: "approximate" } } } + "$__agg0" if (IsElementsField(node.Input)) { - var accumulator = AstExpression.PercentileAccumulator(_element, node.Percentiles); - return CreateGetAccumulatorFieldExpression(accumulator); + var accumulatorExpression = AstExpression.PercentileAccumulator(_element, node.Percentiles); + return CreateGetAccumulatorFieldExpression(accumulatorExpression); } // { $percentile : { input: { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } }, p: [...], method: "approximate" } } // => { __agg0 : { $percentile : { input: f(x => element), p: [...], method: "approximate" } } } + "$__agg0" if (IsMappedElementsField(node.Input, out var rewrittenArg)) { - var accumulator = AstExpression.PercentileAccumulator(rewrittenArg, node.Percentiles); - return CreateGetAccumulatorFieldExpression(accumulator); + var accumulatorExpression = AstExpression.PercentileAccumulator(rewrittenArg, node.Percentiles); + return CreateGetAccumulatorFieldExpression(accumulatorExpression); } return base.VisitPercentileExpression(node); @@ -489,7 +490,8 @@ public override AstNode VisitUnaryExpression(AstUnaryExpression node) return CreateGetAccumulatorFieldExpression(accumulatorExpression); } - // { $accumulator : { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } } => { __agg0 : { $accumulator : f(x => element) } } + "$__agg0" + // { $accumulator : { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } } + // => { __agg0 : { $accumulator : f(x => element) } } + "$__agg0" if (node.Operator.IsAccumulator(out accumulatorOperator) && IsMappedElementsField(node.Arg, out var rewrittenArg)) { @@ -510,9 +512,9 @@ expression is AstGetFieldExpression getFieldExpression && private bool IsMappedElementsField(AstExpression expression, out AstExpression rewrittenArg) { - if (expression is AstMapExpression map && IsElementsField(map.Input)) + if (expression is AstMapExpression mapExpression && IsElementsField(mapExpression.Input)) { - rewrittenArg = (AstExpression)AstNodeReplacer.Replace(map.In, (map.As, _element)); + rewrittenArg = (AstExpression)AstNodeReplacer.Replace(mapExpression.In, (mapExpression.As, _element)); return true; } @@ -520,9 +522,9 @@ private bool IsMappedElementsField(AstExpression expression, out AstExpression r return false; } - private AstExpression CreateGetAccumulatorFieldExpression(AstAccumulatorExpression accumulator) + private AstExpression CreateGetAccumulatorFieldExpression(AstAccumulatorExpression accumulatorExpression) { - var fieldName = _accumulators.AddAccumulatorExpression(accumulator); + var fieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression); return AstExpression.GetField(AstExpression.RootVar, fieldName); } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs index c3327365d26..0ae3e99ca4a 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs @@ -112,26 +112,6 @@ internal static class EnumerableMethod private static readonly MethodInfo __maxSingle; private static readonly MethodInfo __maxSingleWithSelector; private static readonly MethodInfo __maxWithSelector; - private static readonly MethodInfo __medianDecimal; - private static readonly MethodInfo __medianDecimalWithSelector; - private static readonly MethodInfo __medianDouble; - private static readonly MethodInfo __medianDoubleWithSelector; - private static readonly MethodInfo __medianInt32; - private static readonly MethodInfo __medianInt32WithSelector; - private static readonly MethodInfo __medianInt64; - private static readonly MethodInfo __medianInt64WithSelector; - private static readonly MethodInfo __medianNullableDecimal; - private static readonly MethodInfo __medianNullableDecimalWithSelector; - private static readonly MethodInfo __medianNullableDouble; - private static readonly MethodInfo __medianNullableDoubleWithSelector; - private static readonly MethodInfo __medianNullableInt32; - private static readonly MethodInfo __medianNullableInt32WithSelector; - private static readonly MethodInfo __medianNullableInt64; - private static readonly MethodInfo __medianNullableInt64WithSelector; - private static readonly MethodInfo __medianNullableSingle; - private static readonly MethodInfo __medianNullableSingleWithSelector; - private static readonly MethodInfo __medianSingle; - private static readonly MethodInfo __medianSingleWithSelector; private static readonly MethodInfo __min; private static readonly MethodInfo __minDecimal; private static readonly MethodInfo __minDecimalWithSelector; @@ -159,26 +139,6 @@ internal static class EnumerableMethod private static readonly MethodInfo __ofType; private static readonly MethodInfo __orderBy; private static readonly MethodInfo __orderByDescending; - private static readonly MethodInfo __percentileDecimal; - private static readonly MethodInfo __percentileDecimalWithSelector; - private static readonly MethodInfo __percentileDouble; - private static readonly MethodInfo __percentileDoubleWithSelector; - private static readonly MethodInfo __percentileInt32; - private static readonly MethodInfo __percentileInt32WithSelector; - private static readonly MethodInfo __percentileInt64; - private static readonly MethodInfo __percentileInt64WithSelector; - private static readonly MethodInfo __percentileNullableDecimal; - private static readonly MethodInfo __percentileNullableDecimalWithSelector; - private static readonly MethodInfo __percentileNullableDouble; - private static readonly MethodInfo __percentileNullableDoubleWithSelector; - private static readonly MethodInfo __percentileNullableInt32; - private static readonly MethodInfo __percentileNullableInt32WithSelector; - private static readonly MethodInfo __percentileNullableInt64; - private static readonly MethodInfo __percentileNullableInt64WithSelector; - private static readonly MethodInfo __percentileNullableSingle; - private static readonly MethodInfo __percentileNullableSingleWithSelector; - private static readonly MethodInfo __percentileSingle; - private static readonly MethodInfo __percentileSingleWithSelector; private static readonly MethodInfo __prepend; private static readonly MethodInfo __range; private static readonly MethodInfo __repeat; @@ -319,26 +279,6 @@ static EnumerableMethod() __maxSingle = ReflectionInfo.Method((IEnumerable source) => source.Max()); __maxSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Max(selector)); __maxWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Max(selector)); - __medianDecimal = ReflectionInfo.Method((IEnumerable source) => source.Median()); - __medianDecimalWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); - __medianDouble = ReflectionInfo.Method((IEnumerable source) => source.Median()); - __medianDoubleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); - __medianInt32 = ReflectionInfo.Method((IEnumerable source) => source.Median()); - __medianInt32WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); - __medianInt64 = ReflectionInfo.Method((IEnumerable source) => source.Median()); - __medianInt64WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); - __medianNullableDecimal = ReflectionInfo.Method((IEnumerable source) => source.Median()); - __medianNullableDecimalWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); - __medianNullableDouble = ReflectionInfo.Method((IEnumerable source) => source.Median()); - __medianNullableDoubleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); - __medianNullableInt32 = ReflectionInfo.Method((IEnumerable source) => source.Median()); - __medianNullableInt32WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); - __medianNullableInt64 = ReflectionInfo.Method((IEnumerable source) => source.Median()); - __medianNullableInt64WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); - __medianNullableSingle = ReflectionInfo.Method((IEnumerable source) => source.Median()); - __medianNullableSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); - __medianSingle = ReflectionInfo.Method((IEnumerable source) => source.Median()); - __medianSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); __min = ReflectionInfo.Method((IEnumerable source) => source.Min()); __minDecimal = ReflectionInfo.Method((IEnumerable source) => source.Min()); __minDecimalWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Min(selector)); @@ -366,26 +306,6 @@ static EnumerableMethod() __ofType = ReflectionInfo.Method((IEnumerable source) => source.OfType()); __orderBy = ReflectionInfo.Method((IEnumerable source, Func keySelector) => source.OrderBy(keySelector)); __orderByDescending = ReflectionInfo.Method((IEnumerable source, Func keySelector) => source.OrderByDescending(keySelector)); - __percentileDecimal = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); - __percentileDecimalWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); - __percentileDouble = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); - __percentileDoubleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); - __percentileInt32 = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); - __percentileInt32WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); - __percentileInt64 = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); - __percentileInt64WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); - __percentileNullableDecimal = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); - __percentileNullableDecimalWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); - __percentileNullableDouble = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); - __percentileNullableDoubleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); - __percentileNullableInt32 = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); - __percentileNullableInt32WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); - __percentileNullableInt64 = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); - __percentileNullableInt64WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); - __percentileNullableSingle = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); - __percentileNullableSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); - __percentileSingle = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); - __percentileSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); __prepend = ReflectionInfo.Method((IEnumerable source, object element) => source.Prepend(element)); __range = ReflectionInfo.Method((int start, int count) => Enumerable.Range(start, count)); __repeat = ReflectionInfo.Method((object element, int count) => Enumerable.Repeat(element, count)); @@ -525,26 +445,6 @@ static EnumerableMethod() public static MethodInfo MaxSingle => __maxSingle; public static MethodInfo MaxSingleWithSelector => __maxSingleWithSelector; public static MethodInfo MaxWithSelector => __maxWithSelector; - public static MethodInfo MedianDecimal => __medianDecimal; - public static MethodInfo MedianDecimalWithSelector => __medianDecimalWithSelector; - public static MethodInfo MedianDouble => __medianDouble; - public static MethodInfo MedianDoubleWithSelector => __medianDoubleWithSelector; - public static MethodInfo MedianInt32 => __medianInt32; - public static MethodInfo MedianInt32WithSelector => __medianInt32WithSelector; - public static MethodInfo MedianInt64 => __medianInt64; - public static MethodInfo MedianInt64WithSelector => __medianInt64WithSelector; - public static MethodInfo MedianNullableDecimal => __medianNullableDecimal; - public static MethodInfo MedianNullableDecimalWithSelector => __medianNullableDecimalWithSelector; - public static MethodInfo MedianNullableDouble => __medianNullableDouble; - public static MethodInfo MedianNullableDoubleWithSelector => __medianNullableDoubleWithSelector; - public static MethodInfo MedianNullableInt32 => __medianNullableInt32; - public static MethodInfo MedianNullableInt32WithSelector => __medianNullableInt32WithSelector; - public static MethodInfo MedianNullableInt64 => __medianNullableInt64; - public static MethodInfo MedianNullableInt64WithSelector => __medianNullableInt64WithSelector; - public static MethodInfo MedianNullableSingle => __medianNullableSingle; - public static MethodInfo MedianNullableSingleWithSelector => __medianNullableSingleWithSelector; - public static MethodInfo MedianSingle => __medianSingle; - public static MethodInfo MedianSingleWithSelector => __medianSingleWithSelector; public static MethodInfo Min => __min; public static MethodInfo MinDecimal => __minDecimal; public static MethodInfo MinDecimalWithSelector => __minDecimalWithSelector; @@ -572,26 +472,6 @@ static EnumerableMethod() public static MethodInfo OfType => __ofType; public static MethodInfo OrderBy => __orderBy; public static MethodInfo OrderByDescending => __orderByDescending; - public static MethodInfo PercentileDecimal => __percentileDecimal; - public static MethodInfo PercentileDecimalWithSelector => __percentileDecimalWithSelector; - public static MethodInfo PercentileDouble => __percentileDouble; - public static MethodInfo PercentileDoubleWithSelector => __percentileDoubleWithSelector; - public static MethodInfo PercentileInt32 => __percentileInt32; - public static MethodInfo PercentileInt32WithSelector => __percentileInt32WithSelector; - public static MethodInfo PercentileInt64 => __percentileInt64; - public static MethodInfo PercentileInt64WithSelector => __percentileInt64WithSelector; - public static MethodInfo PercentileNullableDecimal => __percentileNullableDecimal; - public static MethodInfo PercentileNullableDecimalWithSelector => __percentileNullableDecimalWithSelector; - public static MethodInfo PercentileNullableDouble => __percentileNullableDouble; - public static MethodInfo PercentileNullableDoubleWithSelector => __percentileNullableDoubleWithSelector; - public static MethodInfo PercentileNullableInt32 => __percentileNullableInt32; - public static MethodInfo PercentileNullableInt32WithSelector => __percentileNullableInt32WithSelector; - public static MethodInfo PercentileNullableInt64 => __percentileNullableInt64; - public static MethodInfo PercentileNullableInt64WithSelector => __percentileNullableInt64WithSelector; - public static MethodInfo PercentileNullableSingle => __percentileNullableSingle; - public static MethodInfo PercentileNullableSingleWithSelector => __percentileNullableSingleWithSelector; - public static MethodInfo PercentileSingle => __percentileSingle; - public static MethodInfo PercentileSingleWithSelector => __percentileSingleWithSelector; public static MethodInfo Prepend => __prepend; public static MethodInfo Range => __range; public static MethodInfo Repeat => __repeat; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MongoEnumerableMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MongoEnumerableMethod.cs index 3773f1f92bf..c10550024c3 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MongoEnumerableMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MongoEnumerableMethod.cs @@ -25,6 +25,46 @@ internal static class MongoEnumerableMethod private static readonly MethodInfo __allElements; private static readonly MethodInfo __allMatchingElements; private static readonly MethodInfo __firstMatchingElement; + private static readonly MethodInfo __medianDecimal; + private static readonly MethodInfo __medianDecimalWithSelector; + private static readonly MethodInfo __medianDouble; + private static readonly MethodInfo __medianDoubleWithSelector; + private static readonly MethodInfo __medianInt32; + private static readonly MethodInfo __medianInt32WithSelector; + private static readonly MethodInfo __medianInt64; + private static readonly MethodInfo __medianInt64WithSelector; + private static readonly MethodInfo __medianNullableDecimal; + private static readonly MethodInfo __medianNullableDecimalWithSelector; + private static readonly MethodInfo __medianNullableDouble; + private static readonly MethodInfo __medianNullableDoubleWithSelector; + private static readonly MethodInfo __medianNullableInt32; + private static readonly MethodInfo __medianNullableInt32WithSelector; + private static readonly MethodInfo __medianNullableInt64; + private static readonly MethodInfo __medianNullableInt64WithSelector; + private static readonly MethodInfo __medianNullableSingle; + private static readonly MethodInfo __medianNullableSingleWithSelector; + private static readonly MethodInfo __medianSingle; + private static readonly MethodInfo __medianSingleWithSelector; + private static readonly MethodInfo __percentileDecimal; + private static readonly MethodInfo __percentileDecimalWithSelector; + private static readonly MethodInfo __percentileDouble; + private static readonly MethodInfo __percentileDoubleWithSelector; + private static readonly MethodInfo __percentileInt32; + private static readonly MethodInfo __percentileInt32WithSelector; + private static readonly MethodInfo __percentileInt64; + private static readonly MethodInfo __percentileInt64WithSelector; + private static readonly MethodInfo __percentileNullableDecimal; + private static readonly MethodInfo __percentileNullableDecimalWithSelector; + private static readonly MethodInfo __percentileNullableDouble; + private static readonly MethodInfo __percentileNullableDoubleWithSelector; + private static readonly MethodInfo __percentileNullableInt32; + private static readonly MethodInfo __percentileNullableInt32WithSelector; + private static readonly MethodInfo __percentileNullableInt64; + private static readonly MethodInfo __percentileNullableInt64WithSelector; + private static readonly MethodInfo __percentileNullableSingle; + private static readonly MethodInfo __percentileNullableSingleWithSelector; + private static readonly MethodInfo __percentileSingle; + private static readonly MethodInfo __percentileSingleWithSelector; private static readonly MethodInfo __whereWithLimit; // static constructor @@ -33,6 +73,46 @@ static MongoEnumerableMethod() __allElements = ReflectionInfo.Method((IEnumerable source) => source.AllElements()); __allMatchingElements = ReflectionInfo.Method((IEnumerable source, string identifier) => source.AllMatchingElements(identifier)); __firstMatchingElement = ReflectionInfo.Method((IEnumerable source) => source.FirstMatchingElement()); + __medianDecimal = ReflectionInfo.Method((IEnumerable source) => source.Median()); + __medianDecimalWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); + __medianDouble = ReflectionInfo.Method((IEnumerable source) => source.Median()); + __medianDoubleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); + __medianInt32 = ReflectionInfo.Method((IEnumerable source) => source.Median()); + __medianInt32WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); + __medianInt64 = ReflectionInfo.Method((IEnumerable source) => source.Median()); + __medianInt64WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); + __medianNullableDecimal = ReflectionInfo.Method((IEnumerable source) => source.Median()); + __medianNullableDecimalWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); + __medianNullableDouble = ReflectionInfo.Method((IEnumerable source) => source.Median()); + __medianNullableDoubleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); + __medianNullableInt32 = ReflectionInfo.Method((IEnumerable source) => source.Median()); + __medianNullableInt32WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); + __medianNullableInt64 = ReflectionInfo.Method((IEnumerable source) => source.Median()); + __medianNullableInt64WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); + __medianNullableSingle = ReflectionInfo.Method((IEnumerable source) => source.Median()); + __medianNullableSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); + __medianSingle = ReflectionInfo.Method((IEnumerable source) => source.Median()); + __medianSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Median(selector)); + __percentileDecimal = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); + __percentileDecimalWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); + __percentileDouble = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); + __percentileDoubleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); + __percentileInt32 = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); + __percentileInt32WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); + __percentileInt64 = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); + __percentileInt64WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); + __percentileNullableDecimal = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); + __percentileNullableDecimalWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); + __percentileNullableDouble = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); + __percentileNullableDoubleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); + __percentileNullableInt32 = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); + __percentileNullableInt32WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); + __percentileNullableInt64 = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); + __percentileNullableInt64WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); + __percentileNullableSingle = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); + __percentileNullableSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); + __percentileSingle = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); + __percentileSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); __whereWithLimit = ReflectionInfo.Method((IEnumerable source, Func predicate, int limit) => source.Where(predicate, limit)); } @@ -40,6 +120,46 @@ static MongoEnumerableMethod() public static MethodInfo AllElements => __allElements; public static MethodInfo AllMatchingElements => __allMatchingElements; public static MethodInfo FirstMatchingElement => __firstMatchingElement; + public static MethodInfo MedianDecimal => __medianDecimal; + public static MethodInfo MedianDecimalWithSelector => __medianDecimalWithSelector; + public static MethodInfo MedianDouble => __medianDouble; + public static MethodInfo MedianDoubleWithSelector => __medianDoubleWithSelector; + public static MethodInfo MedianInt32 => __medianInt32; + public static MethodInfo MedianInt32WithSelector => __medianInt32WithSelector; + public static MethodInfo MedianInt64 => __medianInt64; + public static MethodInfo MedianInt64WithSelector => __medianInt64WithSelector; + public static MethodInfo MedianNullableDecimal => __medianNullableDecimal; + public static MethodInfo MedianNullableDecimalWithSelector => __medianNullableDecimalWithSelector; + public static MethodInfo MedianNullableDouble => __medianNullableDouble; + public static MethodInfo MedianNullableDoubleWithSelector => __medianNullableDoubleWithSelector; + public static MethodInfo MedianNullableInt32 => __medianNullableInt32; + public static MethodInfo MedianNullableInt32WithSelector => __medianNullableInt32WithSelector; + public static MethodInfo MedianNullableInt64 => __medianNullableInt64; + public static MethodInfo MedianNullableInt64WithSelector => __medianNullableInt64WithSelector; + public static MethodInfo MedianNullableSingle => __medianNullableSingle; + public static MethodInfo MedianNullableSingleWithSelector => __medianNullableSingleWithSelector; + public static MethodInfo MedianSingle => __medianSingle; + public static MethodInfo MedianSingleWithSelector => __medianSingleWithSelector; + public static MethodInfo PercentileDecimal => __percentileDecimal; + public static MethodInfo PercentileDecimalWithSelector => __percentileDecimalWithSelector; + public static MethodInfo PercentileDouble => __percentileDouble; + public static MethodInfo PercentileDoubleWithSelector => __percentileDoubleWithSelector; + public static MethodInfo PercentileInt32 => __percentileInt32; + public static MethodInfo PercentileInt32WithSelector => __percentileInt32WithSelector; + public static MethodInfo PercentileInt64 => __percentileInt64; + public static MethodInfo PercentileInt64WithSelector => __percentileInt64WithSelector; + public static MethodInfo PercentileNullableDecimal => __percentileNullableDecimal; + public static MethodInfo PercentileNullableDecimalWithSelector => __percentileNullableDecimalWithSelector; + public static MethodInfo PercentileNullableDouble => __percentileNullableDouble; + public static MethodInfo PercentileNullableDoubleWithSelector => __percentileNullableDoubleWithSelector; + public static MethodInfo PercentileNullableInt32 => __percentileNullableInt32; + public static MethodInfo PercentileNullableInt32WithSelector => __percentileNullableInt32WithSelector; + public static MethodInfo PercentileNullableInt64 => __percentileNullableInt64; + public static MethodInfo PercentileNullableInt64WithSelector => __percentileNullableInt64WithSelector; + public static MethodInfo PercentileNullableSingle => __percentileNullableSingle; + public static MethodInfo PercentileNullableSingleWithSelector => __percentileNullableSingleWithSelector; + public static MethodInfo PercentileSingle => __percentileSingle; + public static MethodInfo PercentileSingleWithSelector => __percentileSingleWithSelector; public static MethodInfo WhereWithLimit => __whereWithLimit; } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslator.cs index b8029504b32..0baa8709c1d 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslator.cs @@ -26,40 +26,40 @@ internal class MedianMethodToAggregationExpressionTranslator { private static readonly MethodInfo[] __medianMethods = [ - EnumerableMethod.MedianDecimal, - EnumerableMethod.MedianDecimalWithSelector, - EnumerableMethod.MedianDouble, - EnumerableMethod.MedianDoubleWithSelector, - EnumerableMethod.MedianInt32, - EnumerableMethod.MedianInt32WithSelector, - EnumerableMethod.MedianInt64, - EnumerableMethod.MedianInt64WithSelector, - EnumerableMethod.MedianNullableDecimal, - EnumerableMethod.MedianNullableDecimalWithSelector, - EnumerableMethod.MedianNullableDouble, - EnumerableMethod.MedianNullableDoubleWithSelector, - EnumerableMethod.MedianNullableInt32, - EnumerableMethod.MedianNullableInt32WithSelector, - EnumerableMethod.MedianNullableInt64, - EnumerableMethod.MedianNullableInt64WithSelector, - EnumerableMethod.MedianNullableSingle, - EnumerableMethod.MedianNullableSingleWithSelector, - EnumerableMethod.MedianSingle, - EnumerableMethod.MedianSingleWithSelector + MongoEnumerableMethod.MedianDecimal, + MongoEnumerableMethod.MedianDecimalWithSelector, + MongoEnumerableMethod.MedianDouble, + MongoEnumerableMethod.MedianDoubleWithSelector, + MongoEnumerableMethod.MedianInt32, + MongoEnumerableMethod.MedianInt32WithSelector, + MongoEnumerableMethod.MedianInt64, + MongoEnumerableMethod.MedianInt64WithSelector, + MongoEnumerableMethod.MedianNullableDecimal, + MongoEnumerableMethod.MedianNullableDecimalWithSelector, + MongoEnumerableMethod.MedianNullableDouble, + MongoEnumerableMethod.MedianNullableDoubleWithSelector, + MongoEnumerableMethod.MedianNullableInt32, + MongoEnumerableMethod.MedianNullableInt32WithSelector, + MongoEnumerableMethod.MedianNullableInt64, + MongoEnumerableMethod.MedianNullableInt64WithSelector, + MongoEnumerableMethod.MedianNullableSingle, + MongoEnumerableMethod.MedianNullableSingleWithSelector, + MongoEnumerableMethod.MedianSingle, + MongoEnumerableMethod.MedianSingleWithSelector ]; private static readonly MethodInfo[] __medianWithSelectorMethods = [ - EnumerableMethod.MedianDecimalWithSelector, - EnumerableMethod.MedianDoubleWithSelector, - EnumerableMethod.MedianInt32WithSelector, - EnumerableMethod.MedianInt64WithSelector, - EnumerableMethod.MedianNullableDecimalWithSelector, - EnumerableMethod.MedianNullableDoubleWithSelector, - EnumerableMethod.MedianNullableInt32WithSelector, - EnumerableMethod.MedianNullableInt64WithSelector, - EnumerableMethod.MedianNullableSingleWithSelector, - EnumerableMethod.MedianSingleWithSelector + MongoEnumerableMethod.MedianDecimalWithSelector, + MongoEnumerableMethod.MedianDoubleWithSelector, + MongoEnumerableMethod.MedianInt32WithSelector, + MongoEnumerableMethod.MedianInt64WithSelector, + MongoEnumerableMethod.MedianNullableDecimalWithSelector, + MongoEnumerableMethod.MedianNullableDoubleWithSelector, + MongoEnumerableMethod.MedianNullableInt32WithSelector, + MongoEnumerableMethod.MedianNullableInt64WithSelector, + MongoEnumerableMethod.MedianNullableSingleWithSelector, + MongoEnumerableMethod.MedianSingleWithSelector ]; public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) @@ -77,10 +77,11 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC if (method.IsOneOf(__medianWithSelectorMethods)) { + var sourceItemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); + var selectorLambda = (LambdaExpression)arguments[1]; var selectorParameter = selectorLambda.Parameters[0]; - var selectorParameterSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); - var selectorParameterSymbol = context.CreateSymbol(selectorParameter, selectorParameterSerializer); + var selectorParameterSymbol = context.CreateSymbol(selectorParameter, sourceItemSerializer); var selectorContext = context.WithSymbol(selectorParameterSymbol); var selectorTranslation = ExpressionToAggregationExpressionTranslator.Translate(selectorContext, selectorLambda.Body); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslator.cs index cdbbedb3cc6..216d89f1c49 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslator.cs @@ -15,7 +15,6 @@ using System.Linq.Expressions; using System.Reflection; -using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; using MongoDB.Driver.Linq.Linq3Implementation.Misc; using MongoDB.Driver.Linq.Linq3Implementation.Reflection; @@ -27,40 +26,40 @@ internal class PercentileMethodToAggregationExpressionTranslator { private static readonly MethodInfo[] __percentileMethods = [ - EnumerableMethod.PercentileDecimal, - EnumerableMethod.PercentileDecimalWithSelector, - EnumerableMethod.PercentileDouble, - EnumerableMethod.PercentileDoubleWithSelector, - EnumerableMethod.PercentileInt32, - EnumerableMethod.PercentileInt32WithSelector, - EnumerableMethod.PercentileInt64, - EnumerableMethod.PercentileInt64WithSelector, - EnumerableMethod.PercentileNullableDecimal, - EnumerableMethod.PercentileNullableDecimalWithSelector, - EnumerableMethod.PercentileNullableDouble, - EnumerableMethod.PercentileNullableDoubleWithSelector, - EnumerableMethod.PercentileNullableInt32, - EnumerableMethod.PercentileNullableInt32WithSelector, - EnumerableMethod.PercentileNullableInt64, - EnumerableMethod.PercentileNullableInt64WithSelector, - EnumerableMethod.PercentileNullableSingle, - EnumerableMethod.PercentileNullableSingleWithSelector, - EnumerableMethod.PercentileSingle, - EnumerableMethod.PercentileSingleWithSelector + MongoEnumerableMethod.PercentileDecimal, + MongoEnumerableMethod.PercentileDecimalWithSelector, + MongoEnumerableMethod.PercentileDouble, + MongoEnumerableMethod.PercentileDoubleWithSelector, + MongoEnumerableMethod.PercentileInt32, + MongoEnumerableMethod.PercentileInt32WithSelector, + MongoEnumerableMethod.PercentileInt64, + MongoEnumerableMethod.PercentileInt64WithSelector, + MongoEnumerableMethod.PercentileNullableDecimal, + MongoEnumerableMethod.PercentileNullableDecimalWithSelector, + MongoEnumerableMethod.PercentileNullableDouble, + MongoEnumerableMethod.PercentileNullableDoubleWithSelector, + MongoEnumerableMethod.PercentileNullableInt32, + MongoEnumerableMethod.PercentileNullableInt32WithSelector, + MongoEnumerableMethod.PercentileNullableInt64, + MongoEnumerableMethod.PercentileNullableInt64WithSelector, + MongoEnumerableMethod.PercentileNullableSingle, + MongoEnumerableMethod.PercentileNullableSingleWithSelector, + MongoEnumerableMethod.PercentileSingle, + MongoEnumerableMethod.PercentileSingleWithSelector ]; private static readonly MethodInfo[] __percentileWithSelectorMethods = [ - EnumerableMethod.PercentileDecimalWithSelector, - EnumerableMethod.PercentileDoubleWithSelector, - EnumerableMethod.PercentileInt32WithSelector, - EnumerableMethod.PercentileInt64WithSelector, - EnumerableMethod.PercentileNullableDecimalWithSelector, - EnumerableMethod.PercentileNullableDoubleWithSelector, - EnumerableMethod.PercentileNullableInt32WithSelector, - EnumerableMethod.PercentileNullableInt64WithSelector, - EnumerableMethod.PercentileNullableSingleWithSelector, - EnumerableMethod.PercentileSingleWithSelector + MongoEnumerableMethod.PercentileDecimalWithSelector, + MongoEnumerableMethod.PercentileDoubleWithSelector, + MongoEnumerableMethod.PercentileInt32WithSelector, + MongoEnumerableMethod.PercentileInt64WithSelector, + MongoEnumerableMethod.PercentileNullableDecimalWithSelector, + MongoEnumerableMethod.PercentileNullableDoubleWithSelector, + MongoEnumerableMethod.PercentileNullableInt32WithSelector, + MongoEnumerableMethod.PercentileNullableInt64WithSelector, + MongoEnumerableMethod.PercentileNullableSingleWithSelector, + MongoEnumerableMethod.PercentileSingleWithSelector ]; public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) @@ -78,10 +77,11 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC if (method.IsOneOf(__percentileWithSelectorMethods)) { + var sourceItemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); + var selectorLambda = (LambdaExpression)arguments[1]; var selectorParameter = selectorLambda.Parameters[0]; - var selectorParameterSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); - var selectorParameterSymbol = context.CreateSymbol(selectorParameter, selectorParameterSerializer); + var selectorParameterSymbol = context.CreateSymbol(selectorParameter, sourceItemSerializer); var selectorContext = context.WithSymbol(selectorParameterSymbol); var selectorTranslation = ExpressionToAggregationExpressionTranslator.Translate(selectorContext, selectorLambda.Body);