From 39e7dbd73fc0c9bccb3264f097607ea0d2aab50b Mon Sep 17 00:00:00 2001 From: Victor Frank Date: Mon, 30 Jun 2025 21:14:41 -0500 Subject: [PATCH] #144 add interval in Intersect --- interpreter/operator_dispatcher.go | 28 + interpreter/operator_interval.go | 634 ++++++++++++++++++++ parser/operators.go | 34 +- tests/enginetests/operator_interval_test.go | 260 ++++++++ tests/spectests/exclusions/exclusions.go | 6 +- 5 files changed, 953 insertions(+), 9 deletions(-) diff --git a/interpreter/operator_dispatcher.go b/interpreter/operator_dispatcher.go index af782da..5d88cba 100644 --- a/interpreter/operator_dispatcher.go +++ b/interpreter/operator_dispatcher.go @@ -1224,6 +1224,34 @@ func (i *interpreter) binaryOverloads(m model.IBinaryExpression) ([]convert.Over Operands: []types.IType{&types.List{ElementType: types.Any}, &types.List{ElementType: types.Any}}, Result: evalIntersect, }, + { + Operands: []types.IType{&types.Interval{PointType: types.Integer}, &types.Interval{PointType: types.Integer}}, + Result: i.evalIntersectIntervalInteger, + }, + { + Operands: []types.IType{&types.Interval{PointType: types.Long}, &types.Interval{PointType: types.Long}}, + Result: i.evalIntersectIntervalLong, + }, + { + Operands: []types.IType{&types.Interval{PointType: types.Decimal}, &types.Interval{PointType: types.Decimal}}, + Result: i.evalIntersectIntervalDecimal, + }, + { + Operands: []types.IType{&types.Interval{PointType: types.Quantity}, &types.Interval{PointType: types.Quantity}}, + Result: i.evalIntersectIntervalQuantity, + }, + { + Operands: []types.IType{&types.Interval{PointType: types.Date}, &types.Interval{PointType: types.Date}}, + Result: i.evalIntersectInterval, + }, + { + Operands: []types.IType{&types.Interval{PointType: types.DateTime}, &types.Interval{PointType: types.DateTime}}, + Result: i.evalIntersectInterval, + }, + { + Operands: []types.IType{&types.Interval{PointType: types.Time}, &types.Interval{PointType: types.Time}}, + Result: i.evalIntersectInterval, + }, }, nil case *model.ProperlyIncludes: return []convert.Overload[evalBinarySignature]{ diff --git a/interpreter/operator_interval.go b/interpreter/operator_interval.go index e0461dc..d4a7cc6 100644 --- a/interpreter/operator_interval.go +++ b/interpreter/operator_interval.go @@ -641,3 +641,637 @@ func evalWidthInterval(m model.IUnaryExpression, intervalObj result.Value) (resu } return result.Value{}, fmt.Errorf("internal error - unsupported point type in evalWidthInterval: %v", start.RuntimeType()) } + +// intersect(left Interval, right Interval) Interval +// https://cql.hl7.org/09-b-cqlreference.html#intersect +// This function is used only for Date, DateTime, and Time intervals +func (i *interpreter) evalIntersectInterval(m model.IBinaryExpression, lObj, rObj result.Value) (result.Value, error) { + // Handle null inputs + if result.IsNull(lObj) || result.IsNull(rObj) { + return result.New(nil) + } + + // Get start and end bounds for both intervals + lStart, lEnd, err := startAndEnd(lObj, &i.evaluationTimestamp) + if err != nil { + return result.Value{}, err + } + rStart, rEnd, err := startAndEnd(rObj, &i.evaluationTimestamp) + if err != nil { + return result.Value{}, err + } + + // If any bound is null, return null + if result.IsNull(lStart) || result.IsNull(lEnd) || result.IsNull(rStart) || result.IsNull(rEnd) { + return result.New(nil) + } + + // This function only handles Date, DateTime, and Time intervals + return i.evalIntersectIntervalDateTime(m, lObj, rObj) +} + +// calculateNumeralIntersectionInt32 calculates the intersection of two int32 intervals +func (i *interpreter) calculateNumeralIntersectionInt32(lStart, lEnd, rStart, rEnd int32, lInterval, rInterval result.Interval) (result.Value, error) { + // Calculate intersection bounds + intersectionStart := maxInt32(lStart, rStart) + intersectionEnd := minInt32(lEnd, rEnd) + + // Check if there's no overlap + if compareNumeral(intersectionStart, intersectionEnd) == leftAfterRight { + return result.New(nil) + } + + // For intersection, the result bounds are always inclusive + // This is because we're using the effective start/end values from startAndEnd() + // which already account for the original inclusivity + startInclusive := true + endInclusive := true + + // Create result values + startVal, err := result.New(intersectionStart) + if err != nil { + return result.Value{}, err + } + endVal, err := result.New(intersectionEnd) + if err != nil { + return result.Value{}, err + } + + // Create the intersection interval + intersectionInterval := result.Interval{ + Low: startVal, + High: endVal, + LowInclusive: startInclusive, + HighInclusive: endInclusive, + StaticType: lInterval.StaticType, + } + + return result.New(intersectionInterval) +} + +// calculateNumeralIntersectionInt64 calculates the intersection of two int64 intervals +func (i *interpreter) calculateNumeralIntersectionInt64(lStart, lEnd, rStart, rEnd int64, lInterval, rInterval result.Interval) (result.Value, error) { + // Calculate intersection bounds + intersectionStart := maxInt64(lStart, rStart) + intersectionEnd := minInt64(lEnd, rEnd) + + // Check if there's no overlap + if compareNumeral(intersectionStart, intersectionEnd) == leftAfterRight { + return result.New(nil) + } + + // Calculate inclusivity for intersection bounds + // Start is inclusive if it matches an inclusive bound from either interval + startInclusive := true + if compareNumeral(intersectionStart, lStart) == leftEqualRight { + startInclusive = lInterval.LowInclusive + } + if compareNumeral(intersectionStart, rStart) == leftEqualRight { + startInclusive = startInclusive && rInterval.LowInclusive + } + + // End is inclusive if it matches an inclusive bound from either interval + endInclusive := true + if compareNumeral(intersectionEnd, lEnd) == leftEqualRight { + endInclusive = lInterval.HighInclusive + } + if compareNumeral(intersectionEnd, rEnd) == leftEqualRight { + endInclusive = endInclusive && rInterval.HighInclusive + } + + // Create result values + startVal, err := result.New(intersectionStart) + if err != nil { + return result.Value{}, err + } + endVal, err := result.New(intersectionEnd) + if err != nil { + return result.Value{}, err + } + + // Create the intersection interval + intersectionInterval := result.Interval{ + Low: startVal, + High: endVal, + LowInclusive: startInclusive, + HighInclusive: endInclusive, + StaticType: lInterval.StaticType, + } + + return result.New(intersectionInterval) +} + +// calculateNumeralIntersectionFloat64 calculates the intersection of two float64 intervals +func (i *interpreter) calculateNumeralIntersectionFloat64(lStart, lEnd, rStart, rEnd float64, lInterval, rInterval result.Interval) (result.Value, error) { + // Calculate intersection bounds + intersectionStart := maxFloat64(lStart, rStart) + intersectionEnd := minFloat64(lEnd, rEnd) + + // Check if there's no overlap + if compareNumeral(intersectionStart, intersectionEnd) == leftAfterRight { + return result.New(nil) + } + + // Create intersection interval with appropriate inclusivity + startInclusive := (compareNumeral(intersectionStart, lStart) == leftEqualRight && lInterval.LowInclusive) || + (compareNumeral(intersectionStart, rStart) == leftEqualRight && rInterval.LowInclusive) || + (compareNumeral(intersectionStart, lStart) == leftAfterRight && compareNumeral(intersectionStart, lEnd) == leftBeforeRight) || + (compareNumeral(intersectionStart, rStart) == leftAfterRight && compareNumeral(intersectionStart, rEnd) == leftBeforeRight) + + endInclusive := (compareNumeral(intersectionEnd, lEnd) == leftEqualRight && lInterval.HighInclusive) || + (compareNumeral(intersectionEnd, rEnd) == leftEqualRight && rInterval.HighInclusive) || + (compareNumeral(intersectionEnd, lStart) == leftAfterRight && compareNumeral(intersectionEnd, lEnd) == leftBeforeRight) || + (compareNumeral(intersectionEnd, rStart) == leftAfterRight && compareNumeral(intersectionEnd, rEnd) == leftBeforeRight) + + // Create result values + startVal, err := result.New(intersectionStart) + if err != nil { + return result.Value{}, err + } + endVal, err := result.New(intersectionEnd) + if err != nil { + return result.Value{}, err + } + + // Create the intersection interval + intersectionInterval := result.Interval{ + Low: startVal, + High: endVal, + LowInclusive: startInclusive, + HighInclusive: endInclusive, + StaticType: lInterval.StaticType, + } + + return result.New(intersectionInterval) +} + +// calculateNumeralIntersectionQuantity calculates the intersection of two Quantity intervals +func (i *interpreter) calculateNumeralIntersectionQuantity(lStart, lEnd, rStart, rEnd result.Quantity, lInterval, rInterval result.Interval) (result.Value, error) { + // Calculate intersection bounds + intersectionStart := maxFloat64(lStart.Value, rStart.Value) + intersectionEnd := minFloat64(lEnd.Value, rEnd.Value) + + // Check if there's no overlap + if compareNumeral(intersectionStart, intersectionEnd) == leftAfterRight { + return result.New(nil) + } + + // Calculate inclusivity for intersection bounds + // Start is inclusive if it matches an inclusive bound from either interval + startInclusive := true + if compareNumeral(intersectionStart, lStart.Value) == leftEqualRight { + startInclusive = lInterval.LowInclusive + } + if compareNumeral(intersectionStart, rStart.Value) == leftEqualRight { + startInclusive = startInclusive && rInterval.LowInclusive + } + + // End is inclusive if it matches an inclusive bound from either interval + endInclusive := true + if compareNumeral(intersectionEnd, lEnd.Value) == leftEqualRight { + endInclusive = lInterval.HighInclusive + } + if compareNumeral(intersectionEnd, rEnd.Value) == leftEqualRight { + endInclusive = endInclusive && rInterval.HighInclusive + } + + // Create result values with Quantity type + startVal, err := result.New(result.Quantity{Value: intersectionStart, Unit: lStart.Unit}) + if err != nil { + return result.Value{}, err + } + endVal, err := result.New(result.Quantity{Value: intersectionEnd, Unit: lStart.Unit}) + if err != nil { + return result.Value{}, err + } + + // Create the intersection interval + intersectionInterval := result.Interval{ + Low: startVal, + High: endVal, + LowInclusive: startInclusive, + HighInclusive: endInclusive, + StaticType: lInterval.StaticType, + } + + return result.New(intersectionInterval) +} + +// evalIntersectIntervalDateTime handles intersection for date/time interval types +func (i *interpreter) evalIntersectIntervalDateTime(m model.IBinaryExpression, lObj, rObj result.Value) (result.Value, error) { + // Get start and end bounds for both intervals + lStart, lEnd, err := startAndEnd(lObj, &i.evaluationTimestamp) + if err != nil { + return result.Value{}, err + } + rStart, rEnd, err := startAndEnd(rObj, &i.evaluationTimestamp) + if err != nil { + return result.Value{}, err + } + + // If any bound is null, return null + if result.IsNull(lStart) || result.IsNull(lEnd) || result.IsNull(rStart) || result.IsNull(rEnd) { + return result.New(nil) + } + + // Get interval metadata for result construction + lInterval, _ := result.ToInterval(lObj) + + // Handle different date/time types + switch lInterval.StaticType.PointType { + case types.Date: + return i.evalIntersectIntervalDate(lStart, lEnd, rStart, rEnd, lInterval) + case types.DateTime: + return i.evalIntersectIntervalDateTimeType(lStart, lEnd, rStart, rEnd, lInterval) + case types.Time: + return i.evalIntersectIntervalTime(lStart, lEnd, rStart, rEnd, lInterval) + default: + return result.Value{}, fmt.Errorf("internal error - unsupported date/time type in evalIntersectIntervalDateTime: %v", lInterval.StaticType.PointType) + } +} + +// evalIntersectIntervalDate handles intersection for Date intervals +func (i *interpreter) evalIntersectIntervalDate(lStart, lEnd, rStart, rEnd result.Value, lInterval result.Interval) (result.Value, error) { + // Convert to DateTime for comparison but preserve Date type in result + lStartDT, lEndDT, err := applyToValues(lStart, lEnd, result.ToDateTime) + if err != nil { + return result.Value{}, err + } + rStartDT, rEndDT, err := applyToValues(rStart, rEnd, result.ToDateTime) + if err != nil { + return result.Value{}, err + } + + // Calculate intersection bounds using time.Time comparison + var intersectionStartDT, intersectionEndDT result.DateTime + var intersectionStart, intersectionEnd result.Value + + if lStartDT.Date.After(rStartDT.Date) { + intersectionStartDT = lStartDT + intersectionStart = lStart + } else { + intersectionStartDT = rStartDT + intersectionStart = rStart + } + + if lEndDT.Date.Before(rEndDT.Date) { + intersectionEndDT = lEndDT + intersectionEnd = lEnd + } else { + intersectionEndDT = rEndDT + intersectionEnd = rEnd + } + + // Check if there's no overlap + if intersectionStartDT.Date.After(intersectionEndDT.Date) { + return result.New(nil) + } + + // Determine inclusivity for intersection bounds + startInclusive := (intersectionStartDT.Date.Equal(lStartDT.Date) && lInterval.LowInclusive) || + (intersectionStartDT.Date.Equal(rStartDT.Date) && lInterval.LowInclusive) || + (intersectionStartDT.Date.After(lStartDT.Date) && intersectionStartDT.Date.Before(lEndDT.Date)) || + (intersectionStartDT.Date.After(rStartDT.Date) && intersectionStartDT.Date.Before(rEndDT.Date)) + + endInclusive := (intersectionEndDT.Date.Equal(lEndDT.Date) && lInterval.HighInclusive) || + (intersectionEndDT.Date.Equal(rEndDT.Date) && lInterval.HighInclusive) || + (intersectionEndDT.Date.After(lStartDT.Date) && intersectionEndDT.Date.Before(lEndDT.Date)) || + (intersectionEndDT.Date.After(rStartDT.Date) && intersectionEndDT.Date.Before(rEndDT.Date)) + + // Create the intersection interval with Date values + intersectionInterval := result.Interval{ + Low: intersectionStart, + High: intersectionEnd, + LowInclusive: startInclusive, + HighInclusive: endInclusive, + StaticType: lInterval.StaticType, + } + + return result.New(intersectionInterval) +} + +// evalIntersectIntervalDateTimeType handles intersection for DateTime intervals +func (i *interpreter) evalIntersectIntervalDateTimeType(lStart, lEnd, rStart, rEnd result.Value, lInterval result.Interval) (result.Value, error) { + // Convert to DateTime for comparison + lStartDT, lEndDT, err := applyToValues(lStart, lEnd, result.ToDateTime) + if err != nil { + return result.Value{}, err + } + rStartDT, rEndDT, err := applyToValues(rStart, rEnd, result.ToDateTime) + if err != nil { + return result.Value{}, err + } + + // Calculate intersection bounds using time.Time comparison + var intersectionStart, intersectionEnd result.DateTime + if lStartDT.Date.After(rStartDT.Date) { + intersectionStart = lStartDT + } else { + intersectionStart = rStartDT + } + + if lEndDT.Date.Before(rEndDT.Date) { + intersectionEnd = lEndDT + } else { + intersectionEnd = rEndDT + } + + // Check if there's no overlap + if intersectionStart.Date.After(intersectionEnd.Date) { + return result.New(nil) + } + + // Determine inclusivity for intersection bounds + startInclusive := (intersectionStart.Date.Equal(lStartDT.Date) && lInterval.LowInclusive) || + (intersectionStart.Date.Equal(rStartDT.Date) && lInterval.LowInclusive) || + (intersectionStart.Date.After(lStartDT.Date) && intersectionStart.Date.Before(lEndDT.Date)) || + (intersectionStart.Date.After(rStartDT.Date) && intersectionStart.Date.Before(rEndDT.Date)) + + endInclusive := (intersectionEnd.Date.Equal(lEndDT.Date) && lInterval.HighInclusive) || + (intersectionEnd.Date.Equal(rEndDT.Date) && lInterval.HighInclusive) || + (intersectionEnd.Date.After(lStartDT.Date) && intersectionEnd.Date.Before(lEndDT.Date)) || + (intersectionEnd.Date.After(rStartDT.Date) && intersectionEnd.Date.Before(rEndDT.Date)) + + // Create result values + startVal, err := result.New(intersectionStart) + if err != nil { + return result.Value{}, err + } + endVal, err := result.New(intersectionEnd) + if err != nil { + return result.Value{}, err + } + + // Create the intersection interval + intersectionInterval := result.Interval{ + Low: startVal, + High: endVal, + LowInclusive: startInclusive, + HighInclusive: endInclusive, + StaticType: lInterval.StaticType, + } + + return result.New(intersectionInterval) +} + +// evalIntersectIntervalTime handles intersection for Time intervals +func (i *interpreter) evalIntersectIntervalTime(lStart, lEnd, rStart, rEnd result.Value, lInterval result.Interval) (result.Value, error) { + // Convert to DateTime for comparison but preserve Time type in result + lStartDT, lEndDT, err := applyToValues(lStart, lEnd, result.ToDateTime) + if err != nil { + return result.Value{}, err + } + rStartDT, rEndDT, err := applyToValues(rStart, rEnd, result.ToDateTime) + if err != nil { + return result.Value{}, err + } + + // Calculate intersection bounds using time.Time comparison + var intersectionStartDT, intersectionEndDT result.DateTime + var intersectionStart, intersectionEnd result.Value + + if lStartDT.Date.After(rStartDT.Date) { + intersectionStartDT = lStartDT + intersectionStart = lStart + } else { + intersectionStartDT = rStartDT + intersectionStart = rStart + } + + if lEndDT.Date.Before(rEndDT.Date) { + intersectionEndDT = lEndDT + intersectionEnd = lEnd + } else { + intersectionEndDT = rEndDT + intersectionEnd = rEnd + } + + // Check if there's no overlap + if intersectionStartDT.Date.After(intersectionEndDT.Date) { + return result.New(nil) + } + + // Determine inclusivity for intersection bounds + startInclusive := (intersectionStartDT.Date.Equal(lStartDT.Date) && lInterval.LowInclusive) || + (intersectionStartDT.Date.Equal(rStartDT.Date) && lInterval.LowInclusive) || + (intersectionStartDT.Date.After(lStartDT.Date) && intersectionStartDT.Date.Before(lEndDT.Date)) || + (intersectionStartDT.Date.After(rStartDT.Date) && intersectionStartDT.Date.Before(rEndDT.Date)) + + endInclusive := (intersectionEndDT.Date.Equal(lEndDT.Date) && lInterval.HighInclusive) || + (intersectionEndDT.Date.Equal(rEndDT.Date) && lInterval.HighInclusive) || + (intersectionEndDT.Date.After(lStartDT.Date) && intersectionEndDT.Date.Before(lEndDT.Date)) || + (intersectionEndDT.Date.After(rStartDT.Date) && intersectionEndDT.Date.Before(rEndDT.Date)) + + // Create the intersection interval with Time values + intersectionInterval := result.Interval{ + Low: intersectionStart, + High: intersectionEnd, + LowInclusive: startInclusive, + HighInclusive: endInclusive, + StaticType: lInterval.StaticType, + } + + return result.New(intersectionInterval) +} + + +// Helper functions for min/max calculations +func maxInt32(a, b int32) int32 { + if a > b { + return a + } + return b +} + +func minInt32(a, b int32) int32 { + if a < b { + return a + } + return b +} + +func maxInt64(a, b int64) int64 { + if a > b { + return a + } + return b +} + +func minInt64(a, b int64) int64 { + if a < b { + return a + } + return b +} + +func maxFloat64(a, b float64) float64 { + if a > b { + return a + } + return b +} + +func minFloat64(a, b float64) float64 { + if a < b { + return a + } + return b +} + +// Type-specific intersect functions for dispatcher +func (i *interpreter) evalIntersectIntervalInteger(m model.IBinaryExpression, lObj, rObj result.Value) (result.Value, error) { + // Handle null inputs + if result.IsNull(lObj) || result.IsNull(rObj) { + return result.New(nil) + } + + // Get start and end bounds for both intervals + lStart, lEnd, err := startAndEnd(lObj, &i.evaluationTimestamp) + if err != nil { + return result.Value{}, err + } + rStart, rEnd, err := startAndEnd(rObj, &i.evaluationTimestamp) + if err != nil { + return result.Value{}, err + } + + // If any bound is null, return null + if result.IsNull(lStart) || result.IsNull(lEnd) || result.IsNull(rStart) || result.IsNull(rEnd) { + return result.New(nil) + } + + // Get interval metadata for result construction + lInterval, _ := result.ToInterval(lObj) + rInterval, _ := result.ToInterval(rObj) + + // Convert to int32 values + lStartVal, lEndVal, err := applyToValues(lStart, lEnd, result.ToInt32) + if err != nil { + return result.Value{}, err + } + rStartVal, rEndVal, err := applyToValues(rStart, rEnd, result.ToInt32) + if err != nil { + return result.Value{}, err + } + + return i.calculateNumeralIntersectionInt32(lStartVal, lEndVal, rStartVal, rEndVal, lInterval, rInterval) +} + +func (i *interpreter) evalIntersectIntervalLong(m model.IBinaryExpression, lObj, rObj result.Value) (result.Value, error) { + // Handle null inputs + if result.IsNull(lObj) || result.IsNull(rObj) { + return result.New(nil) + } + + // Get start and end bounds for both intervals + lStart, lEnd, err := startAndEnd(lObj, &i.evaluationTimestamp) + if err != nil { + return result.Value{}, err + } + rStart, rEnd, err := startAndEnd(rObj, &i.evaluationTimestamp) + if err != nil { + return result.Value{}, err + } + + // If any bound is null, return null + if result.IsNull(lStart) || result.IsNull(lEnd) || result.IsNull(rStart) || result.IsNull(rEnd) { + return result.New(nil) + } + + // Get interval metadata for result construction + lInterval, _ := result.ToInterval(lObj) + rInterval, _ := result.ToInterval(rObj) + + // Convert to int64 values + lStartVal, lEndVal, err := applyToValues(lStart, lEnd, result.ToInt64) + if err != nil { + return result.Value{}, err + } + rStartVal, rEndVal, err := applyToValues(rStart, rEnd, result.ToInt64) + if err != nil { + return result.Value{}, err + } + + return i.calculateNumeralIntersectionInt64(lStartVal, lEndVal, rStartVal, rEndVal, lInterval, rInterval) +} + +func (i *interpreter) evalIntersectIntervalDecimal(m model.IBinaryExpression, lObj, rObj result.Value) (result.Value, error) { + // Handle null inputs + if result.IsNull(lObj) || result.IsNull(rObj) { + return result.New(nil) + } + + // Get start and end bounds for both intervals + lStart, lEnd, err := startAndEnd(lObj, &i.evaluationTimestamp) + if err != nil { + return result.Value{}, err + } + rStart, rEnd, err := startAndEnd(rObj, &i.evaluationTimestamp) + if err != nil { + return result.Value{}, err + } + + // If any bound is null, return null + if result.IsNull(lStart) || result.IsNull(lEnd) || result.IsNull(rStart) || result.IsNull(rEnd) { + return result.New(nil) + } + + // Get interval metadata for result construction + lInterval, _ := result.ToInterval(lObj) + rInterval, _ := result.ToInterval(rObj) + + // Convert to float64 values + lStartVal, lEndVal, err := applyToValues(lStart, lEnd, result.ToFloat64) + if err != nil { + return result.Value{}, err + } + rStartVal, rEndVal, err := applyToValues(rStart, rEnd, result.ToFloat64) + if err != nil { + return result.Value{}, err + } + + return i.calculateNumeralIntersectionFloat64(lStartVal, lEndVal, rStartVal, rEndVal, lInterval, rInterval) +} + +func (i *interpreter) evalIntersectIntervalQuantity(m model.IBinaryExpression, lObj, rObj result.Value) (result.Value, error) { + // Handle null inputs + if result.IsNull(lObj) || result.IsNull(rObj) { + return result.New(nil) + } + + // Get start and end bounds for both intervals + lStart, lEnd, err := startAndEnd(lObj, &i.evaluationTimestamp) + if err != nil { + return result.Value{}, err + } + rStart, rEnd, err := startAndEnd(rObj, &i.evaluationTimestamp) + if err != nil { + return result.Value{}, err + } + + // If any bound is null, return null + if result.IsNull(lStart) || result.IsNull(lEnd) || result.IsNull(rStart) || result.IsNull(rEnd) { + return result.New(nil) + } + + // Get interval metadata for result construction + lInterval, _ := result.ToInterval(lObj) + rInterval, _ := result.ToInterval(rObj) + + // Convert to Quantity values + lStartVal, lEndVal, err := applyToValues(lStart, lEnd, result.ToQuantity) + if err != nil { + return result.Value{}, err + } + rStartVal, rEndVal, err := applyToValues(rStart, rEnd, result.ToQuantity) + if err != nil { + return result.Value{}, err + } + + // Check unit compatibility + if lStartVal.Unit != rStartVal.Unit { + return result.Value{}, fmt.Errorf("intersect operator received Quantities with differing unit values, unit conversion is not currently supported, got: %v, %v", lStartVal.Unit, rStartVal.Unit) + } + + return i.calculateNumeralIntersectionQuantity(lStartVal, lEndVal, rStartVal, rEndVal, lInterval, rInterval) +} diff --git a/parser/operators.go b/parser/operators.go index 1f616c2..b72fe18 100644 --- a/parser/operators.go +++ b/parser/operators.go @@ -81,13 +81,22 @@ func (v *visitor) resolveFunction(libraryName, funcName string, operands []model // For Except the left side is the result type. t.Expression = model.ResultType(resolved.WrappedOperands[0].GetResultType()) case *model.Intersect: - listTypeLeft := resolved.WrappedOperands[0].GetResultType().(*types.List) - listTypeRight := resolved.WrappedOperands[1].GetResultType().(*types.List) - listElemType, err := convert.Intersect(listTypeLeft.ElementType, listTypeRight.ElementType) - if err != nil { - return nil, err + leftType := resolved.WrappedOperands[0].GetResultType() + rightType := resolved.WrappedOperands[1].GetResultType() + + // Check if this is a list intersect or interval intersect + if listTypeLeft, ok := leftType.(*types.List); ok { + // List intersect + listTypeRight := rightType.(*types.List) + listElemType, err := convert.Intersect(listTypeLeft.ElementType, listTypeRight.ElementType) + if err != nil { + return nil, err + } + t.Expression = model.ResultType(&types.List{ElementType: listElemType}) + } else if intervalTypeLeft, ok := leftType.(*types.Interval); ok { + // Interval intersect - result type is the same interval type + t.Expression = model.ResultType(intervalTypeLeft) } - t.Expression = model.ResultType(&types.List{ElementType: listElemType}) case *model.Avg: listType := resolved.WrappedOperands[0].GetResultType().(*types.List) t.Expression = model.ResultType(listType.ElementType) @@ -1831,8 +1840,17 @@ func (p *Parser) loadSystemOperators() error { }, }, { - name: "Intersect", - operands: [][]types.IType{{&types.List{ElementType: types.Any}, &types.List{ElementType: types.Any}}}, + name: "Intersect", + operands: [][]types.IType{ + {&types.List{ElementType: types.Any}, &types.List{ElementType: types.Any}}, + {&types.Interval{PointType: types.Integer}, &types.Interval{PointType: types.Integer}}, + {&types.Interval{PointType: types.Long}, &types.Interval{PointType: types.Long}}, + {&types.Interval{PointType: types.Decimal}, &types.Interval{PointType: types.Decimal}}, + {&types.Interval{PointType: types.Quantity}, &types.Interval{PointType: types.Quantity}}, + {&types.Interval{PointType: types.Date}, &types.Interval{PointType: types.Date}}, + {&types.Interval{PointType: types.DateTime}, &types.Interval{PointType: types.DateTime}}, + {&types.Interval{PointType: types.Time}, &types.Interval{PointType: types.Time}}, + }, model: func() model.IExpression { return &model.Intersect{ BinaryExpression: &model.BinaryExpression{}, diff --git a/tests/enginetests/operator_interval_test.go b/tests/enginetests/operator_interval_test.go index 595a902..f80b305 100644 --- a/tests/enginetests/operator_interval_test.go +++ b/tests/enginetests/operator_interval_test.go @@ -125,6 +125,266 @@ func TestEnd(t *testing.T) { } } +func TestIntersectInterval(t *testing.T) { + tests := []struct { + name string + cql string + wantModel model.IExpression + wantResult result.Value + }{ + // Basic integer interval tests + { + name: "Integer intervals with overlap", + cql: "Interval[1, 5] intersect Interval[3, 7]", + wantModel: &model.Intersect{ + BinaryExpression: &model.BinaryExpression{ + Expression: model.ResultType(&types.Interval{PointType: types.Integer}), + Operands: []model.IExpression{ + &model.Interval{ + Low: model.NewLiteral("1", types.Integer), + High: model.NewLiteral("5", types.Integer), + LowInclusive: true, + HighInclusive: true, + Expression: model.ResultType(&types.Interval{PointType: types.Integer}), + }, + &model.Interval{ + Low: model.NewLiteral("3", types.Integer), + High: model.NewLiteral("7", types.Integer), + LowInclusive: true, + HighInclusive: true, + Expression: model.ResultType(&types.Interval{PointType: types.Integer}), + }, + }, + }, + }, + wantResult: newOrFatal(t, result.Interval{ + Low: newOrFatal(t, int32(3)), + High: newOrFatal(t, int32(5)), + LowInclusive: true, + HighInclusive: true, + StaticType: &types.Interval{PointType: types.Integer}, + }), + }, + { + name: "Integer intervals with no overlap", + cql: "Interval[1, 3] intersect Interval[5, 7]", + wantResult: newOrFatal(t, nil), + }, + { + name: "Integer intervals touching at boundary", + cql: "Interval[1, 3] intersect Interval[3, 7]", + wantResult: newOrFatal(t, result.Interval{ + Low: newOrFatal(t, int32(3)), + High: newOrFatal(t, int32(3)), + LowInclusive: true, + HighInclusive: true, + StaticType: &types.Interval{PointType: types.Integer}, + }), + }, + { + name: "Integer intervals with exclusive bounds", + cql: "Interval[1, 3) intersect Interval(2, 7]", + wantResult: newOrFatal(t, nil), + }, + { + name: "Identical integer intervals", + cql: "Interval[1, 5] intersect Interval[1, 5]", + wantResult: newOrFatal(t, result.Interval{ + Low: newOrFatal(t, int32(1)), + High: newOrFatal(t, int32(5)), + LowInclusive: true, + HighInclusive: true, + StaticType: &types.Interval{PointType: types.Integer}, + }), + }, + // Null handling tests + { + name: "Left interval is null", + cql: "null as Interval intersect Interval[3, 7]", + wantResult: newOrFatal(t, nil), + }, + { + name: "Right interval is null", + cql: "Interval[1, 5] intersect null as Interval", + wantResult: newOrFatal(t, nil), + }, + { + name: "Both intervals are null", + cql: "null as Interval intersect null as Interval", + wantResult: newOrFatal(t, nil), + }, + // Decimal interval tests + { + name: "Decimal intervals with overlap", + cql: "Interval[1.5, 5.5] intersect Interval[3.0, 7.0]", + wantResult: newOrFatal(t, result.Interval{ + Low: newOrFatal(t, 3.0), + High: newOrFatal(t, 5.5), + LowInclusive: true, + HighInclusive: true, + StaticType: &types.Interval{PointType: types.Decimal}, + }), + }, + { + name: "Decimal intervals with no overlap", + cql: "Interval[1.0, 2.5] intersect Interval[3.0, 4.5]", + wantResult: newOrFatal(t, nil), + }, + // Long interval tests + { + name: "Long intervals with overlap", + cql: "Interval[1L, 5L] intersect Interval[3L, 7L]", + wantResult: newOrFatal(t, result.Interval{ + Low: newOrFatal(t, int64(3)), + High: newOrFatal(t, int64(5)), + LowInclusive: true, + HighInclusive: true, + StaticType: &types.Interval{PointType: types.Long}, + }), + }, + // Quantity interval tests + { + name: "Quantity intervals with overlap", + cql: "Interval[1'cm', 5'cm'] intersect Interval[3'cm', 7'cm']", + wantResult: newOrFatal(t, result.Interval{ + Low: newOrFatal(t, result.Quantity{Value: 3, Unit: "cm"}), + High: newOrFatal(t, result.Quantity{Value: 5, Unit: "cm"}), + LowInclusive: true, + HighInclusive: true, + StaticType: &types.Interval{PointType: types.Quantity}, + }), + }, + // Date interval tests + { + name: "Date intervals with overlap", + cql: "Interval[@2020-01-01, @2020-06-01] intersect Interval[@2020-03-01, @2020-09-01]", + wantResult: newOrFatal(t, result.Interval{ + Low: newOrFatal(t, result.Date{Date: time.Date(2020, time.March, 1, 0, 0, 0, 0, defaultEvalTimestamp.Location()), Precision: model.DAY}), + High: newOrFatal(t, result.Date{Date: time.Date(2020, time.June, 1, 0, 0, 0, 0, defaultEvalTimestamp.Location()), Precision: model.DAY}), + LowInclusive: true, + HighInclusive: true, + StaticType: &types.Interval{PointType: types.Date}, + }), + }, + { + name: "Date intervals with no overlap", + cql: "Interval[@2020-01-01, @2020-02-01] intersect Interval[@2020-03-01, @2020-04-01]", + wantResult: newOrFatal(t, nil), + }, + // DateTime interval tests + { + name: "DateTime intervals with overlap", + cql: "Interval[@2020-01-01T10:00:00, @2020-01-01T15:00:00] intersect Interval[@2020-01-01T12:00:00, @2020-01-01T18:00:00]", + wantResult: newOrFatal(t, result.Interval{ + Low: newOrFatal(t, result.DateTime{Date: time.Date(2020, time.January, 1, 12, 0, 0, 0, defaultEvalTimestamp.Location()), Precision: model.SECOND}), + High: newOrFatal(t, result.DateTime{Date: time.Date(2020, time.January, 1, 15, 0, 0, 0, defaultEvalTimestamp.Location()), Precision: model.SECOND}), + LowInclusive: true, + HighInclusive: true, + StaticType: &types.Interval{PointType: types.DateTime}, + }), + }, + // Edge cases + { + name: "First interval completely contains second", + cql: "Interval[1, 10] intersect Interval[3, 7]", + wantResult: newOrFatal(t, result.Interval{ + Low: newOrFatal(t, int32(3)), + High: newOrFatal(t, int32(7)), + LowInclusive: true, + HighInclusive: true, + StaticType: &types.Interval{PointType: types.Integer}, + }), + }, + { + name: "Second interval completely contains first", + cql: "Interval[3, 7] intersect Interval[1, 10]", + wantResult: newOrFatal(t, result.Interval{ + Low: newOrFatal(t, int32(3)), + High: newOrFatal(t, int32(7)), + LowInclusive: true, + HighInclusive: true, + StaticType: &types.Interval{PointType: types.Integer}, + }), + }, + { + name: "Intervals with mixed inclusivity", + cql: "Interval[1, 5) intersect Interval(3, 7]", + wantResult: newOrFatal(t, result.Interval{ + Low: newOrFatal(t, int32(4)), + High: newOrFatal(t, int32(4)), + LowInclusive: true, + HighInclusive: true, + StaticType: &types.Interval{PointType: types.Integer}, + }), + }, + // Functional syntax + { + name: "Functional syntax", + cql: "Intersect(Interval[1, 5], Interval[3, 7])", + wantResult: newOrFatal(t, result.Interval{ + Low: newOrFatal(t, int32(3)), + High: newOrFatal(t, int32(5)), + LowInclusive: true, + HighInclusive: true, + StaticType: &types.Interval{PointType: types.Integer}, + }), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + p := newFHIRParser(t) + parsedLibs, err := p.Libraries(context.Background(), wrapInLib(t, tc.cql), parser.Config{}) + if err != nil { + t.Fatalf("Parse returned unexpected error: %v", err) + } + if diff := cmp.Diff(tc.wantModel, getTESTRESULTModel(t, parsedLibs)); tc.wantModel != nil && diff != "" { + t.Errorf("Parse diff (-want +got):\n%s", diff) + } + + results, err := interpreter.Eval(context.Background(), parsedLibs, defaultInterpreterConfig(t, p)) + if err != nil { + t.Fatalf("Eval returned unexpected error: %v", err) + } + if diff := cmp.Diff(tc.wantResult, getTESTRESULT(t, results), protocmp.Transform()); diff != "" { + t.Errorf("Eval diff (-want +got)\n%v", diff) + } + }) + } +} + +func TestIntersectInterval_Error(t *testing.T) { + tests := []struct { + name string + cql string + wantEvalErrContains string + }{ + { + name: "Quantity intervals with different units", + cql: "Interval[1'cm', 5'cm'] intersect Interval[3'm', 7'm']", + wantEvalErrContains: "intersect operator received Quantities with differing unit values", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + p := newFHIRParser(t) + parsedLibs, err := p.Libraries(context.Background(), wrapInLib(t, tc.cql), parser.Config{}) + if err != nil { + t.Fatalf("Parse returned unexpected error: %v", err) + } + + _, err = interpreter.Eval(context.Background(), parsedLibs, defaultInterpreterConfig(t, p)) + if err == nil { + t.Fatal("Eval succeeded, wanted error") + } + if !strings.Contains(err.Error(), tc.wantEvalErrContains) { + t.Errorf("Unexpected evaluation error contents. got (%v), want contains (%v)", err.Error(), tc.wantEvalErrContains) + } + }) + } +} + func TestStart(t *testing.T) { tests := []struct { name string diff --git a/tests/spectests/exclusions/exclusions.go b/tests/spectests/exclusions/exclusions.go index 50d255e..34e2cb6 100644 --- a/tests/spectests/exclusions/exclusions.go +++ b/tests/spectests/exclusions/exclusions.go @@ -227,7 +227,6 @@ func XMLTestFileExclusionDefinitions() map[string]XMLTestFileExclusions { "Ends", "Except", "Includes", - "Intersect", "Meets", "MeetsBefore", "MeetsAfter", @@ -311,6 +310,11 @@ func XMLTestFileExclusionDefinitions() map[string]XMLTestFileExclusions { // TODO: b/342064453 - Ambiguous match. "TestEqualNull", "TestInNullBoundaries", + "TestIntersectNull", + "TestIntersectNull1", + "TestIntersectNull2", + "TestIntersectNull3", + "TestIntersectNull4", }, }, "CqlListOperatorsTest.xml": {