diff --git a/interpreter/operator_aggregate.go b/interpreter/operator_aggregate.go index 14695ca..e907f04 100644 --- a/interpreter/operator_aggregate.go +++ b/interpreter/operator_aggregate.go @@ -1038,32 +1038,10 @@ func (i *interpreter) evalPopulationStdDevDecimal(m model.IUnaryExpression, oper return result.Value{}, err } - countValue, err := i.evalCount(m, operand) - if err != nil { - return result.Value{}, err - } - if result.IsNull(countValue) { - return result.New(nil) - } - count, err := result.ToInt32(countValue) - if err != nil { - return result.Value{}, err - } - if count == 0 { - return result.New(nil) - } - meanValue, err := i.evalAvg(m, operand) - if err != nil { - return result.Value{}, err - } - if result.IsNull(meanValue) { - return result.New(nil) - } - mean, err := result.ToFloat64(meanValue) - if err != nil { - return result.Value{}, err - } - var sum float64 + var count float64 + var mean float64 + var m2 float64 + for _, elem := range l { if result.IsNull(elem) { continue @@ -1072,10 +1050,20 @@ func (i *interpreter) evalPopulationStdDevDecimal(m model.IUnaryExpression, oper if err != nil { return result.Value{}, err } - sum += (v - mean) * (v - mean) + + count++ + delta := v - mean + mean += delta / count + delta2 := v - mean + m2 += delta * delta2 + } + + if count == 0 { + return result.New(nil) } + // Round to 8 decimal places to match CQL expected precision - stdDev := math.Sqrt(sum / float64(count)) + stdDev := math.Sqrt(m2 / count) roundedStdDev := math.Round(stdDev*100000000) / 100000000 return result.New(roundedStdDev) } diff --git a/tests/enginetests/benchmark_stddev_test.go b/tests/enginetests/benchmark_stddev_test.go new file mode 100644 index 0000000..6d78bef --- /dev/null +++ b/tests/enginetests/benchmark_stddev_test.go @@ -0,0 +1,63 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package enginetests + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/google/cql/interpreter" + "github.com/google/cql/parser" + "github.com/google/cql/result" +) + +func BenchmarkPopulationStdDev(b *testing.B) { + // Create a large list of decimals + count := 10000 + values := make([]string, count) + for i := 0; i < count; i++ { + values[i] = fmt.Sprintf("%d.0", i%100) + } + listCQL := "{" + strings.Join(values, ", ") + "}" + cql := "PopulationStdDev(" + listCQL + ")" + + p := newFHIRParser(b) + parsedLibs, err := p.Libraries(context.Background(), wrapInLib(b, cql), parser.Config{}) + if err != nil { + b.Fatalf("Parse Libraries returned unexpected error: %v", err) + } + + config := interpreter.Config{ + DataModels: p.DataModel(), + Retriever: BuildRetriever(b), + Terminology: buildTerminologyProvider(b), + EvaluationTimestamp: defaultEvalTimestamp, + ReturnPrivateDefs: true, + } + + b.ResetTimer() + b.Run("LargeList", func(b *testing.B) { + var force result.Libraries + for n := 0; n < b.N; n++ { + force, err = interpreter.Eval(context.Background(), parsedLibs, config) + if err != nil { + b.Fatalf("Eval returned unexpected error: %v", err) + } + } + forceBenchResult = force + }) +} diff --git a/tests/enginetests/operator_aggregate_test.go b/tests/enginetests/operator_aggregate_test.go index 6123dff..e45eace 100644 --- a/tests/enginetests/operator_aggregate_test.go +++ b/tests/enginetests/operator_aggregate_test.go @@ -1387,3 +1387,85 @@ func TestMode(t *testing.T) { }) } } + +func TestPopulationStdDev(t *testing.T) { + tests := []struct { + name string + cql string + wantModel model.IExpression + wantResult result.Value + }{ + { + name: "PopulationStdDev({1.0, 2.0, 3.0, 4.0, 5.0})", + cql: "PopulationStdDev({1.0, 2.0, 3.0, 4.0, 5.0})", + wantModel: &model.PopulationStdDev{ + UnaryExpression: &model.UnaryExpression{ + Operand: model.NewList([]string{"1.0", "2.0", "3.0", "4.0", "5.0"}, types.Decimal), + Expression: model.ResultType(types.Decimal), + }, + }, + wantResult: newOrFatal(t, 1.41421356), + }, + { + name: "PopulationStdDev with unordered decimal list", + cql: "PopulationStdDev({5.0, 2.0, 1.0, 4.0, 3.0})", + wantResult: newOrFatal(t, 1.41421356), + }, + { + name: "PopulationStdDev with all identical values", + cql: "PopulationStdDev({3.0, 3.0, 3.0, 3.0})", + wantResult: newOrFatal(t, 0.0), + }, + { + name: "PopulationStdDev with null input", + cql: "PopulationStdDev(null as List)", + wantResult: newOrFatal(t, nil), + }, + { + name: "PopulationStdDev with empty list", + cql: "PopulationStdDev({} as List)", + wantResult: newOrFatal(t, nil), + }, + { + name: "PopulationStdDev with single value", + cql: "PopulationStdDev({5.0})", + wantResult: newOrFatal(t, 0.0), + }, + { + name: "PopulationStdDev with null values in list", + cql: "PopulationStdDev({1.0, null, 3.0, null, 5.0})", + wantResult: newOrFatal(t, 1.63299316), + }, + { + name: "PopulationStdDev with all null values", + cql: "PopulationStdDev({null, null, null} as List)", + wantResult: newOrFatal(t, nil), + }, + { + name: "PopulationStdDev with quantities", + cql: "PopulationStdDev({1.0 'g', 2.0 'g', 3.0 'g', 4.0 'g', 5.0 'g'})", + wantResult: newOrFatal(t, result.Quantity{Value: 1.4142135623730951, Unit: "g"}), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + p := newFHIRParser(t) + parsedLibs, err := p.Libraries(context.Background(), wrapInLib(t, tc.cql), parser.Config{}) + if err != nil { + t.Fatalf("Parse returned unexpected error: %v", err) + } + if diff := cmp.Diff(tc.wantModel, getTESTRESULTModel(t, parsedLibs)); tc.wantModel != nil && diff != "" { + t.Errorf("Parse diff (-want +got):\n%s", diff) + } + + results, err := interpreter.Eval(context.Background(), parsedLibs, defaultInterpreterConfig(t, p)) + if err != nil { + t.Fatalf("Eval returned unexpected error: %v", err) + } + if diff := cmp.Diff(tc.wantResult, getTESTRESULT(t, results), protocmp.Transform()); diff != "" { + t.Errorf("Eval diff (-want +got)\n%v", diff) + } + }) + } +}