From 88e13468f6e7f12cd569bb036301a014f600cb00 Mon Sep 17 00:00:00 2001 From: rubywtl Date: Wed, 3 Sep 2025 14:44:49 -0700 Subject: [PATCH] allow querier to execute logical plan fragments following child-root-execution model Signed-off-by: rubywtl --- integration/distributed_exec_test.go | 177 ++++++++++ pkg/api/api.go | 6 + pkg/api/handlers.go | 7 +- pkg/api/handlers_test.go | 2 +- pkg/api/queryapi/query_api.go | 133 ++++++-- pkg/api/queryapi/query_api_test.go | 32 +- pkg/cortex/modules.go | 29 ++ pkg/distributed_execution/id.go | 32 ++ pkg/distributed_execution/id_test.go | 97 ++++++ .../querier_service_client.go | 129 +++++++ .../querier_service_client_test.go | 317 ++++++++++++++++++ .../querier_service_server.go | 311 +++++++++++++++++ .../querier_service_server_test.go | 206 ++++++++++++ pkg/distributed_execution/query_tracker.go | 127 +++++++ pkg/distributed_execution/remote_node.go | 261 ++++++++++++++ pkg/querier/worker/scheduler_processor.go | 11 +- 16 files changed, 1840 insertions(+), 37 deletions(-) create mode 100644 integration/distributed_exec_test.go create mode 100644 pkg/distributed_execution/id.go create mode 100644 pkg/distributed_execution/id_test.go create mode 100644 pkg/distributed_execution/querier_service_client.go create mode 100644 pkg/distributed_execution/querier_service_client_test.go create mode 100644 pkg/distributed_execution/querier_service_server.go create mode 100644 pkg/distributed_execution/querier_service_server_test.go create mode 100644 pkg/distributed_execution/query_tracker.go diff --git a/integration/distributed_exec_test.go b/integration/distributed_exec_test.go new file mode 100644 index 00000000000..a7ab5dadf44 --- /dev/null +++ b/integration/distributed_exec_test.go @@ -0,0 +1,177 @@ +//go:build integration_query_fuzz +// +build integration_query_fuzz + +package integration + +import ( + "context" + "math/rand" + "path" + "strconv" + "strings" + "testing" + "time" + + "github.com/cortexproject/promqlsmith" + "github.com/prometheus/prometheus/model/labels" + "github.com/prometheus/prometheus/prompb" + "github.com/stretchr/testify/require" + + "github.com/cortexproject/cortex/integration/e2e" + e2edb "github.com/cortexproject/cortex/integration/e2e/db" + "github.com/cortexproject/cortex/integration/e2ecortex" + "github.com/cortexproject/cortex/pkg/storage/tsdb" +) + +func TestDistributedExecutionFuzz(t *testing.T) { + s, err := e2e.NewScenario(networkName) + require.NoError(t, err) + defer s.Close() + + // start dependencies. + consul1 := e2edb.NewConsulWithName("consul1") + consul2 := e2edb.NewConsulWithName("consul2") + require.NoError(t, s.StartAndWaitReady(consul1, consul2)) + + flags := mergeFlags( + AlertmanagerLocalFlags(), + map[string]string{ + "-store.engine": blocksStorageEngine, + "-blocks-storage.backend": "filesystem", + "-blocks-storage.tsdb.head-compaction-interval": "4m", + "-blocks-storage.tsdb.block-ranges-period": "2h", + "-blocks-storage.tsdb.ship-interval": "1h", + "-blocks-storage.bucket-store.sync-interval": "15m", + "-blocks-storage.tsdb.retention-period": "2h", + "-blocks-storage.bucket-store.index-cache.backend": tsdb.IndexCacheBackendInMemory, + "-querier.query-store-for-labels-enabled": "true", + // Ingester. + "-ring.store": "consul", + "-consul.hostname": consul1.NetworkHTTPEndpoint(), + // Distributor. + "-distributor.replication-factor": "1", + // Store-gateway. 
+ "-store-gateway.sharding-enabled": "false", + // alert manager + "-alertmanager.web.external-url": "http://localhost/alertmanager", + }, + ) + // make alert manager config dir + require.NoError(t, writeFileToSharedDir(s, "alertmanager_configs", []byte{})) + + path1 := path.Join(s.SharedDir(), "cortex-1") + path2 := path.Join(s.SharedDir(), "cortex-2") + + flags1 := mergeFlags(flags, map[string]string{"-blocks-storage.filesystem.dir": path1}) + + // Start first Cortex replicas + distributor := e2ecortex.NewDistributor("distributor", e2ecortex.RingStoreConsul, consul1.NetworkHTTPEndpoint(), flags1, "") + ingester := e2ecortex.NewIngester("ingester", e2ecortex.RingStoreConsul, consul1.NetworkHTTPEndpoint(), flags1, "") + queryScheduler := e2ecortex.NewQueryScheduler("query-scheduler", flags1, "") + storeGateway := e2ecortex.NewStoreGateway("store-gateway", e2ecortex.RingStoreConsul, consul1.NetworkHTTPEndpoint(), flags1, "") + require.NoError(t, s.StartAndWaitReady(queryScheduler, distributor, ingester, storeGateway)) + flags1 = mergeFlags(flags1, map[string]string{ + "-querier.store-gateway-addresses": strings.Join([]string{storeGateway.NetworkGRPCEndpoint()}, ","), + }) + queryFrontend := e2ecortex.NewQueryFrontend("query-frontend", mergeFlags(flags1, map[string]string{ + "-frontend.scheduler-address": queryScheduler.NetworkGRPCEndpoint(), + }), "") + require.NoError(t, s.Start(queryFrontend)) + querier := e2ecortex.NewQuerier("querier", e2ecortex.RingStoreConsul, consul1.NetworkHTTPEndpoint(), mergeFlags(flags1, map[string]string{ + "-querier.scheduler-address": queryScheduler.NetworkGRPCEndpoint(), + }), "") + require.NoError(t, s.StartAndWaitReady(querier)) + require.NoError(t, distributor.WaitSumMetrics(e2e.Equals(512), "cortex_ring_tokens_total")) + require.NoError(t, querier.WaitSumMetrics(e2e.Equals(512), "cortex_ring_tokens_total")) + c1, err := e2ecortex.NewClient(distributor.HTTPEndpoint(), queryFrontend.HTTPEndpoint(), "", "", "user-1") + require.NoError(t, err) + + // Enable distributed execution for the second Cortex instance. 
+ flags2 := mergeFlags(flags, map[string]string{ + "-frontend.query-vertical-shard-size": "2", + "-blocks-storage.filesystem.dir": path2, + "-consul.hostname": consul2.NetworkHTTPEndpoint(), + "-querier.thanos-engine": "true", + "-querier.distributed-exec-enabled": "true", + "-api.querier-default-codec": "protobuf", + }) + + distributor2 := e2ecortex.NewDistributor("distributor2", e2ecortex.RingStoreConsul, consul2.NetworkHTTPEndpoint(), flags2, "") + ingester2 := e2ecortex.NewIngester("ingester2", e2ecortex.RingStoreConsul, consul2.NetworkHTTPEndpoint(), flags2, "") + queryScheduler2 := e2ecortex.NewQueryScheduler("query-scheduler2", flags2, "") + storeGateway2 := e2ecortex.NewStoreGateway("store-gateway2", e2ecortex.RingStoreConsul, consul2.NetworkHTTPEndpoint(), flags2, "") + require.NoError(t, s.StartAndWaitReady(queryScheduler2, distributor2, ingester2, storeGateway2)) + flags2 = mergeFlags(flags1, map[string]string{ + "-querier.store-gateway-addresses": strings.Join([]string{storeGateway2.NetworkGRPCEndpoint()}, ","), + }) + queryFrontend2 := e2ecortex.NewQueryFrontend("query-frontend2", mergeFlags(flags2, map[string]string{ + "-frontend.scheduler-address": queryScheduler2.NetworkGRPCEndpoint(), + }), "") + require.NoError(t, s.Start(queryFrontend2)) + querier2 := e2ecortex.NewQuerier("querier2", e2ecortex.RingStoreConsul, consul2.NetworkHTTPEndpoint(), mergeFlags(flags2, map[string]string{ + "-querier.scheduler-address": queryScheduler2.NetworkGRPCEndpoint(), + }), "") + require.NoError(t, s.StartAndWaitReady(querier2)) + require.NoError(t, distributor2.WaitSumMetrics(e2e.Equals(512), "cortex_ring_tokens_total")) + require.NoError(t, querier2.WaitSumMetrics(e2e.Equals(512), "cortex_ring_tokens_total")) + c2, err := e2ecortex.NewClient(distributor2.HTTPEndpoint(), queryFrontend2.HTTPEndpoint(), "", "", "user-1") + require.NoError(t, err) + + now := time.Now() + // Push some series to Cortex. + start := now.Add(-time.Minute * 10) + end := now.Add(-time.Minute * 1) + numSeries := 3 + numSamples := 20 + lbls := make([]labels.Labels, numSeries*2) + serieses := make([]prompb.TimeSeries, numSeries*2) + scrapeInterval := 30 * time.Second + for i := 0; i < numSeries; i++ { + series := e2e.GenerateSeriesWithSamples("test_series_a", start, scrapeInterval, i*numSamples, numSamples, prompb.Label{Name: "job", Value: "test"}, prompb.Label{Name: "series", Value: strconv.Itoa(i)}) + serieses[i] = series + builder := labels.NewBuilder(labels.EmptyLabels()) + for _, lbl := range series.Labels { + builder.Set(lbl.Name, lbl.Value) + } + lbls[i] = builder.Labels() + } + + // Generate another set of series for testing binary expression and vector matching. + for i := numSeries; i < 2*numSeries; i++ { + prompbLabels := []prompb.Label{{Name: "job", Value: "test"}, {Name: "series", Value: strconv.Itoa(i)}} + switch i % 3 { + case 0: + prompbLabels = append(prompbLabels, prompb.Label{Name: "status_code", Value: "200"}) + case 1: + prompbLabels = append(prompbLabels, prompb.Label{Name: "status_code", Value: "400"}) + default: + prompbLabels = append(prompbLabels, prompb.Label{Name: "status_code", Value: "500"}) + } + series := e2e.GenerateSeriesWithSamples("test_series_b", start, scrapeInterval, i*numSamples, numSamples, prompbLabels...) 
+ serieses[i] = series + builder := labels.NewBuilder(labels.EmptyLabels()) + for _, lbl := range series.Labels { + builder.Set(lbl.Name, lbl.Value) + } + lbls[i] = builder.Labels() + } + res, err := c1.Push(serieses) + require.NoError(t, err) + require.Equal(t, 200, res.StatusCode) + res, err = c2.Push(serieses) + require.NoError(t, err) + require.Equal(t, 200, res.StatusCode) + + waitUntilReady(t, context.Background(), c1, c2, `{job="test"}`, start, end) + + rnd := rand.New(rand.NewSource(now.Unix())) + opts := []promqlsmith.Option{ + promqlsmith.WithEnableOffset(true), + promqlsmith.WithEnableAtModifier(true), + promqlsmith.WithEnabledFunctions(enabledFunctions), + promqlsmith.WithEnabledAggrs(enabledAggrs), + } + ps := promqlsmith.New(rnd, lbls, opts...) + + runQueryFuzzTestCases(t, ps, c1, c2, end, start, end, scrapeInterval, 1000, false) +} diff --git a/pkg/api/api.go b/pkg/api/api.go index ebe64440f9c..20f8f781bd3 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -23,6 +23,8 @@ import ( "github.com/cortexproject/cortex/pkg/alertmanager/alertmanagerpb" "github.com/cortexproject/cortex/pkg/compactor" "github.com/cortexproject/cortex/pkg/cortexpb" + "github.com/cortexproject/cortex/pkg/distributed_execution" + "github.com/cortexproject/cortex/pkg/distributed_execution/querierpb" "github.com/cortexproject/cortex/pkg/distributor" "github.com/cortexproject/cortex/pkg/distributor/distributorpb" frontendv1 "github.com/cortexproject/cortex/pkg/frontend/v1" @@ -482,6 +484,10 @@ func (a *API) RegisterQueryScheduler(f *scheduler.Scheduler) { schedulerpb.RegisterSchedulerForQuerierServer(a.server.GRPC, f) } +func (a *API) RegisterQuerierServer(f *distributed_execution.QuerierServer) { + querierpb.RegisterQuerierServer(a.server.GRPC, f) +} + // RegisterServiceMapHandler registers the Cortex structs service handler // TODO: Refactor this code to be accomplished using the services.ServiceManager // or a future module manager #2291 diff --git a/pkg/api/handlers.go b/pkg/api/handlers.go index 54a55318542..0e831724e4a 100644 --- a/pkg/api/handlers.go +++ b/pkg/api/handlers.go @@ -25,10 +25,12 @@ import ( "github.com/weaveworks/common/middleware" "github.com/cortexproject/cortex/pkg/api/queryapi" + "github.com/cortexproject/cortex/pkg/distributed_execution" "github.com/cortexproject/cortex/pkg/engine" "github.com/cortexproject/cortex/pkg/querier" "github.com/cortexproject/cortex/pkg/querier/codec" "github.com/cortexproject/cortex/pkg/querier/stats" + "github.com/cortexproject/cortex/pkg/ring/client" "github.com/cortexproject/cortex/pkg/util" util_log "github.com/cortexproject/cortex/pkg/util/log" ) @@ -168,6 +170,9 @@ func NewQuerierHandler( metadataQuerier querier.MetadataQuerier, reg prometheus.Registerer, logger log.Logger, + queryTracker *distributed_execution.QueryTracker, + distributedExecEnabled bool, + querierClientPool *client.Pool, ) http.Handler { // Prometheus histograms for requests to the querier. 
querierRequestDuration := promauto.With(reg).NewHistogramVec(prometheus.HistogramOpts{ @@ -284,7 +289,7 @@ func NewQuerierHandler( legacyPromRouter := route.New().WithPrefix(path.Join(legacyPrefix, "/api/v1")) api.Register(legacyPromRouter) - queryAPI := queryapi.NewQueryAPI(engine, translateSampleAndChunkQueryable, statsRenderer, logger, codecs, corsOrigin) + queryAPI := queryapi.NewQueryAPI(engine, translateSampleAndChunkQueryable, statsRenderer, logger, codecs, corsOrigin, queryTracker, distributedExecEnabled, querierClientPool) // TODO(gotjosh): This custom handler is temporary until we're able to vendor the changes in: // https://github.com/prometheus/prometheus/pull/7125/files diff --git a/pkg/api/handlers_test.go b/pkg/api/handlers_test.go index 9b8b7930683..1717689152a 100644 --- a/pkg/api/handlers_test.go +++ b/pkg/api/handlers_test.go @@ -235,7 +235,7 @@ func TestBuildInfoAPI(t *testing.T) { version.Version = tc.version version.Branch = tc.branch version.Revision = tc.revision - handler := NewQuerierHandler(cfg, querierConfig, nil, nil, nil, nil, nil, &FakeLogger{}) + handler := NewQuerierHandler(cfg, querierConfig, nil, nil, nil, nil, nil, &FakeLogger{}, nil, false, nil) writer := httptest.NewRecorder() req := httptest.NewRequest("GET", "/api/v1/status/buildinfo", nil) req = req.WithContext(user.InjectOrgID(req.Context(), "test")) diff --git a/pkg/api/queryapi/query_api.go b/pkg/api/queryapi/query_api.go index ef9ef4e2801..4bd79e54767 100644 --- a/pkg/api/queryapi/query_api.go +++ b/pkg/api/queryapi/query_api.go @@ -16,23 +16,28 @@ import ( "github.com/prometheus/prometheus/util/annotations" "github.com/prometheus/prometheus/util/httputil" v1 "github.com/prometheus/prometheus/web/api/v1" + "github.com/thanos-io/promql-engine/logicalplan" "github.com/weaveworks/common/httpgrpc" "github.com/cortexproject/cortex/pkg/distributed_execution" "github.com/cortexproject/cortex/pkg/engine" "github.com/cortexproject/cortex/pkg/querier" + "github.com/cortexproject/cortex/pkg/ring/client" "github.com/cortexproject/cortex/pkg/util" "github.com/cortexproject/cortex/pkg/util/api" ) type QueryAPI struct { - queryable storage.SampleAndChunkQueryable - queryEngine engine.QueryEngine - now func() time.Time - statsRenderer v1.StatsRenderer - logger log.Logger - codecs []v1.Codec - CORSOrigin *regexp.Regexp + queryable storage.SampleAndChunkQueryable + queryEngine engine.QueryEngine + now func() time.Time + statsRenderer v1.StatsRenderer + logger log.Logger + codecs []v1.Codec + CORSOrigin *regexp.Regexp + queryTracker *distributed_execution.QueryTracker + distributedExecEnabled bool + querierClientPool *client.Pool } func NewQueryAPI( @@ -42,15 +47,21 @@ func NewQueryAPI( logger log.Logger, codecs []v1.Codec, CORSOrigin *regexp.Regexp, + queryTracker *distributed_execution.QueryTracker, + distributedExecEnabled bool, + querierClientPool *client.Pool, ) *QueryAPI { return &QueryAPI{ - queryEngine: qe, - queryable: q, - statsRenderer: statsRenderer, - logger: logger, - codecs: codecs, - CORSOrigin: CORSOrigin, - now: time.Now, + queryEngine: qe, + queryable: q, + statsRenderer: statsRenderer, + logger: logger, + codecs: codecs, + CORSOrigin: CORSOrigin, + now: time.Now, + queryTracker: queryTracker, + distributedExecEnabled: distributedExecEnabled, + querierClientPool: querierClientPool, } } @@ -108,14 +119,42 @@ func (q *QueryAPI) RangeQueryHandler(r *http.Request) (result apiFuncResult) { endTime := convertMsToTime(end) stepDuration := convertMsToDuration(step) + var isRoot bool + var queryID, 
fragmentID uint64 + if q.distributedExecEnabled { + isRoot, queryID, fragmentID, _, _ = distributed_execution.ExtractFragmentMetaData(ctx) + if !isRoot { + key := distributed_execution.MakeFragmentKey(queryID, fragmentID) + q.queryTracker.InitWriting(key) + } + } + byteLP := []byte(r.PostFormValue("plan")) - if len(byteLP) != 0 { - logicalPlan, err := distributed_execution.Unmarshal(byteLP) + if len(byteLP) != 0 && q.distributedExecEnabled { + logicalPlanNode, err := distributed_execution.Unmarshal(byteLP) if err != nil { + if !isRoot { + key := distributed_execution.MakeFragmentKey(queryID, fragmentID) + q.queryTracker.SetError(key) + } return apiFuncResult{nil, &apiError{errorInternal, fmt.Errorf("invalid logical plan: %v", err)}, nil, nil} } - qry, err = q.queryEngine.MakeRangeQueryFromPlan(ctx, q.queryable, opts, logicalPlan, startTime, endTime, stepDuration, r.FormValue("query")) + logicalplan.TraverseBottomUp(nil, &logicalPlanNode, func(parent, current *logicalplan.Node) bool { + if (*current).Type() == distributed_execution.RemoteNode { + remote, ok := (*current).(*distributed_execution.Remote) + if ok { + remote.InsertClientPool(q.querierClientPool) + } + } + return false + }) + + qry, err = q.queryEngine.MakeRangeQueryFromPlan(ctx, q.queryable, opts, logicalPlanNode, startTime, endTime, stepDuration, r.FormValue("query")) if err != nil { + if !isRoot { + key := distributed_execution.MakeFragmentKey(queryID, fragmentID) + q.queryTracker.SetError(key) + } return apiFuncResult{nil, &apiError{errorInternal, fmt.Errorf("failed to create range query from logical plan: %v", err)}, nil, nil} } } else { // if there is logical plan field is empty, fall back @@ -136,6 +175,14 @@ func (q *QueryAPI) RangeQueryHandler(r *http.Request) (result apiFuncResult) { ctx = httputil.ContextFromRequest(ctx, r) + if q.distributedExecEnabled { + isRoot, queryID, fragmentID, _, _ := distributed_execution.ExtractFragmentMetaData(ctx) + if !isRoot { + key := distributed_execution.MakeFragmentKey(queryID, fragmentID) + q.queryTracker.InitWriting(key) + } + } + res := qry.Exec(ctx) if res.Err != nil { return apiFuncResult{nil, returnAPIError(res.Err), res.Warnings, qry.Close} @@ -181,14 +228,45 @@ func (q *QueryAPI) InstantQueryHandler(r *http.Request) (result apiFuncResult) { var qry promql.Query tsTime := convertMsToTime(ts) - byteLP := []byte(r.PostFormValue("plan")) - if len(byteLP) != 0 { - logicalPlan, err := distributed_execution.Unmarshal(byteLP) + var isRoot bool + var queryID, fragmentID uint64 + if q.distributedExecEnabled { + isRoot, queryID, fragmentID, _, _ = distributed_execution.ExtractFragmentMetaData(ctx) + if !isRoot { + key := distributed_execution.MakeFragmentKey(queryID, fragmentID) + q.queryTracker.InitWriting(key) + } + } + + byteLogicalPlan := []byte(r.PostFormValue("plan")) + if len(byteLogicalPlan) != 0 && q.distributedExecEnabled { + logicalPlanNode, err := distributed_execution.Unmarshal(byteLogicalPlan) if err != nil { + if q.distributedExecEnabled { + if !isRoot { + key := distributed_execution.MakeFragmentKey(queryID, fragmentID) + q.queryTracker.SetError(key) + } + } return apiFuncResult{nil, &apiError{errorInternal, fmt.Errorf("invalid logical plan: %v", err)}, nil, nil} } - qry, err = q.queryEngine.MakeInstantQueryFromPlan(ctx, q.queryable, opts, logicalPlan, tsTime, r.FormValue("query")) + + logicalplan.TraverseBottomUp(nil, &logicalPlanNode, func(parent, current *logicalplan.Node) bool { + if (*current).Type() == distributed_execution.RemoteNode { + remote, ok := 
(*current).(*distributed_execution.Remote) + if ok { + remote.InsertClientPool(q.querierClientPool) + } + } + return false + }) + + qry, err = q.queryEngine.MakeInstantQueryFromPlan(ctx, q.queryable, opts, logicalPlanNode, tsTime, r.FormValue("query")) if err != nil { + if !isRoot { + key := distributed_execution.MakeFragmentKey(queryID, fragmentID) + q.queryTracker.SetError(key) + } return apiFuncResult{nil, &apiError{errorInternal, fmt.Errorf("failed to create instant query from logical plan: %v", err)}, nil, nil} } } else { // if there is logical plan field is empty, fall back @@ -239,6 +317,19 @@ func (q *QueryAPI) Wrap(f apiFunc) http.HandlerFunc { } if result.data != nil { + ctx := httputil.ContextFromRequest(r.Context(), r) + + if q.distributedExecEnabled { + isRoot, queryID, fragmentID, _, _ := distributed_execution.ExtractFragmentMetaData(ctx) + key := distributed_execution.MakeFragmentKey(queryID, fragmentID) + + q.queryTracker.SetComplete(key, result.data) + + if isRoot { + q.respond(w, r, result.data, result.warnings, r.FormValue("query")) + } + return + } q.respond(w, r, result.data, result.warnings, r.FormValue("query")) return } diff --git a/pkg/api/queryapi/query_api_test.go b/pkg/api/queryapi/query_api_test.go index 2a0ce0cbc99..eaebda0b465 100644 --- a/pkg/api/queryapi/query_api_test.go +++ b/pkg/api/queryapi/query_api_test.go @@ -28,6 +28,7 @@ import ( "github.com/thanos-io/promql-engine/query" "github.com/weaveworks/common/user" + "github.com/cortexproject/cortex/pkg/distributed_execution" engine2 "github.com/cortexproject/cortex/pkg/engine" "github.com/cortexproject/cortex/pkg/querier" "github.com/cortexproject/cortex/pkg/querier/series" @@ -183,7 +184,7 @@ func Test_CustomAPI(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - c := NewQueryAPI(engine, mockQueryable, querier.StatsRenderer, log.NewNopLogger(), []v1.Codec{v1.JSONCodec{}}, regexp.MustCompile(".*")) + c := NewQueryAPI(engine, mockQueryable, querier.StatsRenderer, log.NewNopLogger(), []v1.Codec{v1.JSONCodec{}}, regexp.MustCompile(".*"), nil, false, nil) router := mux.NewRouter() router.Path("/api/v1/query").Methods("POST").Handler(c.Wrap(c.InstantQueryHandler)) @@ -244,7 +245,7 @@ func Test_InvalidCodec(t *testing.T) { }, } - queryAPI := NewQueryAPI(engine, mockQueryable, querier.StatsRenderer, log.NewNopLogger(), []v1.Codec{&mockCodec{}}, regexp.MustCompile(".*")) + queryAPI := NewQueryAPI(engine, mockQueryable, querier.StatsRenderer, log.NewNopLogger(), []v1.Codec{&mockCodec{}}, regexp.MustCompile(".*"), nil, false, nil) router := mux.NewRouter() router.Path("/api/v1/query").Methods("POST").Handler(queryAPI.Wrap(queryAPI.InstantQueryHandler)) @@ -285,7 +286,7 @@ func Test_CustomAPI_StatsRenderer(t *testing.T) { }, } - queryAPI := NewQueryAPI(engine, mockQueryable, querier.StatsRenderer, log.NewNopLogger(), []v1.Codec{v1.JSONCodec{}}, regexp.MustCompile(".*")) + queryAPI := NewQueryAPI(engine, mockQueryable, querier.StatsRenderer, log.NewNopLogger(), []v1.Codec{v1.JSONCodec{}}, regexp.MustCompile(".*"), nil, false, nil) router := mux.NewRouter() router.Path("/api/v1/query_range").Methods("POST").Handler(queryAPI.Wrap(queryAPI.RangeQueryHandler)) @@ -305,7 +306,10 @@ func Test_CustomAPI_StatsRenderer(t *testing.T) { require.Equal(t, uint64(4), queryStats.LoadScannedSamples()) } -func Test_Logicalplan_Requests(t *testing.T) { +// Test_Logicalplan_SimpleRequests verifies basic logical plan execution without fragmentation. 
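+// The API under test is built with distributed execution enabled and a live
+// QueryTracker, but without a querier client pool, so plans that contain no
+// Remote nodes still execute locally.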
+// Uses simple queries to test error handling and execution flow, avoiding distributed +// execution scenarios to focus on core functionality. +func Test_Logicalplan_SimpleRequests(t *testing.T) { engine := engine2.New( promql.EngineOpts{ MaxSamples: 100, @@ -342,7 +346,7 @@ func Test_Logicalplan_Requests(t *testing.T) { expectedBody string }{ { - name: "[Range Query] with valid logical plan and empty query string", + name: "[Range Query] with valid simple logical plan and empty query string", path: "/api/v1/query_range?end=1536673680&query=&start=1536673665&step=5", start: 1536673665, end: 1536673680, @@ -360,10 +364,10 @@ func Test_Logicalplan_Requests(t *testing.T) { end: 1536673680, stepDuration: 5, requestBody: func(t *testing.T) []byte { - return append(createTestLogicalPlan(t, 1536673665, 1536673680, 5), []byte("random data")...) + return []byte("random_data") }, expectedCode: http.StatusInternalServerError, - expectedBody: `{"status":"error","errorType":"server_error","error":"invalid logical plan: invalid character 'r' after top-level value"}`, + expectedBody: `{"status":"error","errorType":"server_error","error":"invalid logical plan: invalid character 'r' looking for beginning of value"}`, }, { name: "[Range Query] with empty body and non-empty query string", // fall back to promql query execution @@ -408,10 +412,10 @@ func Test_Logicalplan_Requests(t *testing.T) { end: 1536673670, stepDuration: 0, requestBody: func(t *testing.T) []byte { - return append(createTestLogicalPlan(t, 1536673670, 1536673670, 0), []byte("random data")...) + return []byte("random_data") }, expectedCode: http.StatusInternalServerError, - expectedBody: `{"status":"error","errorType":"server_error","error":"invalid logical plan: invalid character 'r' after top-level value"}`, + expectedBody: `{"status":"error","errorType":"server_error","error":"invalid logical plan: invalid character 'r' looking for beginning of value"}`, }, { name: "[Instant Query] with empty body and non-empty query string", @@ -441,12 +445,16 @@ func Test_Logicalplan_Requests(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := NewQueryAPI(engine, mockQueryable, querier.StatsRenderer, log.NewNopLogger(), []v1.Codec{v1.JSONCodec{}}, regexp.MustCompile(".*")) + tracker := distributed_execution.NewQueryTracker() + c := NewQueryAPI(engine, mockQueryable, querier.StatsRenderer, log.NewNopLogger(), []v1.Codec{v1.JSONCodec{}}, regexp.MustCompile(".*"), tracker, true, nil) router := mux.NewRouter() router.Path("/api/v1/query").Methods("POST").Handler(c.Wrap(c.InstantQueryHandler)) router.Path("/api/v1/query_range").Methods("POST").Handler(c.Wrap(c.RangeQueryHandler)) req := createTestRequest(tt.path, tt.requestBody(t)) + newctx := distributed_execution.InjectFragmentMetaData(context.Background(), uint64(1), uint64(1), true, nil) + req = req.WithContext(newctx) + rec := httptest.NewRecorder() router.ServeHTTP(rec, req) @@ -498,8 +506,8 @@ func createTestLogicalPlan(t *testing.T, start, end int64, stepDuration int64) [ logicalPlan, err := logicalplan.NewFromAST(expr, &qOpts, planOpts) require.NoError(t, err) - byteval, err := logicalplan.Marshal(logicalPlan.Root()) + bodyBytes, err := logicalplan.Marshal(logicalPlan.Root()) require.NoError(t, err) - return byteval + return bodyBytes } diff --git a/pkg/cortex/modules.go b/pkg/cortex/modules.go index 013dbb90834..c357a707e80 100644 --- a/pkg/cortex/modules.go +++ b/pkg/cortex/modules.go @@ -8,6 +8,7 @@ import ( "net/http" "runtime" "runtime/debug" + "time" 
"github.com/go-kit/log/level" "github.com/opentracing-contrib/go-stdlib/nethttp" @@ -29,6 +30,7 @@ import ( "github.com/cortexproject/cortex/pkg/compactor" configAPI "github.com/cortexproject/cortex/pkg/configs/api" "github.com/cortexproject/cortex/pkg/configs/db" + "github.com/cortexproject/cortex/pkg/distributed_execution" "github.com/cortexproject/cortex/pkg/distributor" "github.com/cortexproject/cortex/pkg/engine" "github.com/cortexproject/cortex/pkg/flusher" @@ -45,6 +47,7 @@ import ( querier_worker "github.com/cortexproject/cortex/pkg/querier/worker" cortexquerysharding "github.com/cortexproject/cortex/pkg/querysharding" "github.com/cortexproject/cortex/pkg/ring" + "github.com/cortexproject/cortex/pkg/ring/client" "github.com/cortexproject/cortex/pkg/ring/kv/codec" "github.com/cortexproject/cortex/pkg/ring/kv/memberlist" "github.com/cortexproject/cortex/pkg/ruler" @@ -360,6 +363,29 @@ func (t *Cortex) initTenantFederation() (serv services.Service, err error) { // │ │ // └──────────────────┘ func (t *Cortex) initQuerier() (serv services.Service, err error) { + + // Create new map for caching partial results during distributed execution + var queryTracker *distributed_execution.QueryTracker + var querierServer *distributed_execution.QuerierServer + var querierClientPool *client.Pool + + if t.Cfg.Querier.DistributedExecEnabled { + // set up querier server service and register it + queryTracker = distributed_execution.NewQueryTracker() + querierServer = distributed_execution.NewQuerierServer(queryTracker) + querierClientPool = distributed_execution.NewQuerierPool(t.Cfg.Worker.GRPCClientConfig, prometheus.DefaultRegisterer, util_log.Logger) + + // automatically clean query tracker every time interval + go func() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + for range ticker.C { + queryTracker.CleanExpired() + } + }() + t.API.RegisterQuerierServer(querierServer) + } + // Create a internal HTTP handler that is configured with the Prometheus API routes and points // to a Prometheus API struct instantiated with the Cortex Queryable. 
internalQuerierRouter := api.NewQuerierHandler( @@ -371,6 +397,9 @@ func (t *Cortex) initQuerier() (serv services.Service, err error) { t.MetadataQuerier, prometheus.DefaultRegisterer, util_log.Logger, + queryTracker, + t.Cfg.Querier.DistributedExecEnabled, + querierClientPool, ) // If the querier is running standalone without the query-frontend or query-scheduler, we must register it's internal diff --git a/pkg/distributed_execution/id.go b/pkg/distributed_execution/id.go new file mode 100644 index 00000000000..e7fad1b4773 --- /dev/null +++ b/pkg/distributed_execution/id.go @@ -0,0 +1,32 @@ +package distributed_execution + +import ( + "context" +) + +type fragmentMetadataKey struct{} + +type fragmentMetadata struct { + queryID uint64 + fragmentID uint64 + childIDToAddr map[uint64]string + isRoot bool +} + +func InjectFragmentMetaData(ctx context.Context, fragmentID uint64, queryID uint64, isRoot bool, childIDToAddr map[uint64]string) context.Context { + + return context.WithValue(ctx, fragmentMetadataKey{}, fragmentMetadata{ + queryID: queryID, + fragmentID: fragmentID, + childIDToAddr: childIDToAddr, + isRoot: isRoot, + }) +} + +func ExtractFragmentMetaData(ctx context.Context) (isRoot bool, queryID uint64, fragmentID uint64, childAddrs map[uint64]string, ok bool) { + metadata, ok := ctx.Value(fragmentMetadataKey{}).(fragmentMetadata) + if !ok { + return false, 0, 0, nil, false + } + return metadata.isRoot, metadata.queryID, metadata.fragmentID, metadata.childIDToAddr, true +} diff --git a/pkg/distributed_execution/id_test.go b/pkg/distributed_execution/id_test.go new file mode 100644 index 00000000000..855a70be5dd --- /dev/null +++ b/pkg/distributed_execution/id_test.go @@ -0,0 +1,97 @@ +package distributed_execution + +import ( + "context" + "reflect" + "testing" +) + +func TestFragmentMetadata(t *testing.T) { + tests := []struct { + name string + queryID uint64 + fragID uint64 + isRoot bool + childIDs []uint64 + childAddr []string + }{ + { + name: "basic test", + queryID: 123, + fragID: 456, + isRoot: true, + childIDs: []uint64{1, 2, 3}, + childAddr: []string{"addr1", "addr2", "addr3"}, + }, + { + name: "empty children", + queryID: 789, + fragID: 101, + isRoot: false, + childIDs: []uint64{}, + childAddr: []string{}, + }, + { + name: "single child", + queryID: 999, + fragID: 888, + isRoot: true, + childIDs: []uint64{42}, + childAddr: []string{"[IP_ADDRESS]:8080"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // injection + ctx := context.Background() + + childIDToAddr := make(map[uint64]string) + for i, childID := range tt.childIDs { + childIDToAddr[childID] = tt.childAddr[i] + } + newCtx := InjectFragmentMetaData(ctx, tt.fragID, tt.queryID, tt.isRoot, childIDToAddr) + + // extraction + isRoot, queryID, fragmentID, childAddrs, ok := ExtractFragmentMetaData(newCtx) + + // verify results + if !ok { + t.Error("ExtractFragmentMetaData failed, ok = false") + } + + if isRoot != tt.isRoot { + t.Errorf("isRoot = %v, want %v", isRoot, tt.isRoot) + } + + if queryID != tt.queryID { + t.Errorf("queryID = %v, want %v", queryID, tt.queryID) + } + + if fragmentID != tt.fragID { + t.Errorf("fragmentID = %v, want %v", fragmentID, tt.fragID) + } + + // create expected childIDToAddr map + expectedChildAddrs := make(map[uint64]string) + for i, childID := range tt.childIDs { + expectedChildAddrs[childID] = tt.childAddr[i] + } + + if !reflect.DeepEqual(childAddrs, expectedChildAddrs) { + t.Errorf("childAddrs = %v, want %v", childAddrs, expectedChildAddrs) + } + }) + } +} + 
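+// A minimal usage sketch of the helpers exercised above (illustrative only; the
+// IDs and the child address are made up): the scheduling side injects fragment
+// metadata before dispatching, and the querier extracts it inside its handlers.
+//
+//	ctx := InjectFragmentMetaData(context.Background(), 7, 3, false,
+//		map[uint64]string{8: "querier-2:9095"})
+//	isRoot, queryID, fragmentID, childAddrs, ok := ExtractFragmentMetaData(ctx)
+//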
+func TestExtractFragmentMetaDataWithEmptyContext(t *testing.T) { + ctx := context.Background() + isRoot, queryID, fragmentID, childAddrs, ok := ExtractFragmentMetaData(ctx) + if ok { + t.Error("ExtractFragmentMetaData should return ok=false for empty context") + } + if isRoot || queryID != 0 || fragmentID != 0 || childAddrs != nil { + t.Error("ExtractFragmentMetaData should return zero values for empty context") + } +} diff --git a/pkg/distributed_execution/querier_service_client.go b/pkg/distributed_execution/querier_service_client.go new file mode 100644 index 00000000000..d622fff45d5 --- /dev/null +++ b/pkg/distributed_execution/querier_service_client.go @@ -0,0 +1,129 @@ +package distributed_execution + +import ( + "time" + + "github.com/go-kit/log" + otgrpc "github.com/opentracing-contrib/go-grpc" + "github.com/opentracing/opentracing-go" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/prometheus/prometheus/model/histogram" + "github.com/weaveworks/common/middleware" + "google.golang.org/grpc" + "google.golang.org/grpc/health/grpc_health_v1" + + "github.com/cortexproject/cortex/pkg/cortexpb" + "github.com/cortexproject/cortex/pkg/distributed_execution/querierpb" + "github.com/cortexproject/cortex/pkg/ring/client" + "github.com/cortexproject/cortex/pkg/util/grpcclient" + cortexmiddleware "github.com/cortexproject/cortex/pkg/util/middleware" +) + +type querierClient struct { + querierpb.QuerierClient + grpc_health_v1.HealthClient + conn *grpc.ClientConn +} + +func (qc *querierClient) Close() error { + return qc.conn.Close() +} + +func NewQuerierPool(cfg grpcclient.Config, reg prometheus.Registerer, log log.Logger) *client.Pool { + requestDuration := promauto.With(reg).NewHistogramVec(prometheus.HistogramOpts{ + Name: "cortex_querier_query_request_duration_seconds", + Help: "Time spent doing requests to querier.", + Buckets: prometheus.ExponentialBuckets(0.001, 4, 6), + }, []string{"operation", "status_code"}) + + clientsGauge := promauto.With(reg).NewGauge(prometheus.GaugeOpts{ + Namespace: "cortex", + Name: "cortex_querier_query_clients", + Help: "TThe current number of clients connected to querier.", + ConstLabels: map[string]string{"client": "querier"}, + }) + + poolConfig := client.PoolConfig{ + CheckInterval: time.Minute, + HealthCheckEnabled: true, + HealthCheckTimeout: 10 * time.Second, + } + + q := &querierPool{ + grpcConfig: cfg, + requestDuration: requestDuration, + } + + return client.NewPool("querier", poolConfig, nil, q.createQuerierClient, clientsGauge, log) +} + +type querierPool struct { + grpcConfig grpcclient.Config + requestDuration *prometheus.HistogramVec +} + +func (q *querierPool) createQuerierClient(addr string) (client.PoolClient, error) { + + opts, err := q.grpcConfig.DialOption([]grpc.UnaryClientInterceptor{ + otgrpc.OpenTracingClientInterceptor(opentracing.GlobalTracer()), + middleware.ClientUserHeaderInterceptor, + cortexmiddleware.PrometheusGRPCUnaryInstrumentation(q.requestDuration), + }, []grpc.StreamClientInterceptor{ + otgrpc.OpenTracingStreamClientInterceptor(opentracing.GlobalTracer()), + middleware.StreamClientUserHeaderInterceptor, + cortexmiddleware.PrometheusGRPCStreamInstrumentation(q.requestDuration), + }) + + if err != nil { + return nil, err + } + + conn, err := grpc.NewClient(addr, opts...) 
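+	// grpc.NewClient only validates the target and dial options here; the underlying
+	// connection is established lazily, so connectivity problems typically surface on
+	// the first RPC or through the pool's health checks rather than at this call.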
+	if err != nil {
+		return nil, err
+	}
+
+	return &querierClient{
+		QuerierClient: querierpb.NewQuerierClient(conn),
+		HealthClient:  grpc_health_v1.NewHealthClient(conn),
+		conn:          conn,
+	}, nil
+}
+
+func floatHistogramProtoToFloatHistograms(hps []cortexpb.Histogram) []*histogram.FloatHistogram {
+	floatHistograms := make([]*histogram.FloatHistogram, 0, len(hps))
+	for _, hp := range hps {
+		newHist := floatHistogramProtoToFloatHistogram(hp)
+		floatHistograms = append(floatHistograms, newHist)
+	}
+	return floatHistograms
+}
+
+func floatHistogramProtoToFloatHistogram(hp cortexpb.Histogram) *histogram.FloatHistogram {
+	_, IsFloatHist := hp.GetCount().(*cortexpb.Histogram_CountFloat)
+	if !IsFloatHist {
+		panic("FloatHistogramProtoToFloatHistogram called with an integer histogram")
+	}
+	return &histogram.FloatHistogram{
+		CounterResetHint: histogram.CounterResetHint(hp.ResetHint),
+		Schema:           hp.Schema,
+		ZeroThreshold:    hp.ZeroThreshold,
+		ZeroCount:        hp.GetZeroCountFloat(),
+		Count:            hp.GetCountFloat(),
+		Sum:              hp.Sum,
+		PositiveSpans:    spansProtoToSpans(hp.GetPositiveSpans()),
+		PositiveBuckets:  hp.GetPositiveCounts(),
+		NegativeSpans:    spansProtoToSpans(hp.GetNegativeSpans()),
+		NegativeBuckets:  hp.GetNegativeCounts(),
+	}
+}
+
+func spansProtoToSpans(s []cortexpb.BucketSpan) []histogram.Span {
+	spans := make([]histogram.Span, len(s))
+	for i := 0; i < len(s); i++ {
+		spans[i] = histogram.Span{Offset: s[i].Offset, Length: s[i].Length}
+	}
+
+	return spans
+}
diff --git a/pkg/distributed_execution/querier_service_client_test.go b/pkg/distributed_execution/querier_service_client_test.go
new file mode 100644
index 00000000000..9796af38196
--- /dev/null
+++ b/pkg/distributed_execution/querier_service_client_test.go
@@ -0,0 +1,317 @@
+package distributed_execution
+
+import (
+	"context"
+	"io"
+	"testing"
+
+	"github.com/go-kit/log"
+	"github.com/prometheus/client_golang/prometheus"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"github.com/thanos-io/promql-engine/execution/model"
+	"github.com/weaveworks/common/server"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/health/grpc_health_v1"
+
+	"github.com/cortexproject/cortex/pkg/distributed_execution/querierpb"
+	"github.com/cortexproject/cortex/pkg/ring/client"
+	"github.com/cortexproject/cortex/pkg/util/grpcclient"
+)
+
+// TestQuerierPool verifies that the querier service client pool correctly manages
+// client connections by testing address addition and client retrieval functionality.
+func TestQuerierPool(t *testing.T) {
+	tests := []struct {
+		name      string
+		poolSetup func() (*client.Pool, *mockServer)
+		test      func(*testing.T, *client.Pool, *mockServer)
+	}{
+		{
+			name: "pool creates and manages clients",
+			poolSetup: func() (*client.Pool, *mockServer) {
+
+				mockServer := newMockServer(t)
+
+				cfg := grpcclient.Config{
+					MaxRecvMsgSize: 1024,
+					MaxSendMsgSize: 1024,
+				}
+
+				reg := prometheus.NewRegistry()
+				logger := log.NewNopLogger()
+
+				pool := NewQuerierPool(cfg, reg, logger)
+
+				return pool, mockServer
+			},
+			test: func(t *testing.T, pool *client.Pool, mockServer *mockServer) {
+				// test getting client
+				client, err := pool.GetClientFor(":8005")
+				assert.NoError(t, err)
+				assert.NotNil(t, client)
+
+				// test client is reused
+				client2, err := pool.GetClientFor(":8005")
+				assert.NoError(t, err)
+				assert.Equal(t, client, client2)
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			pool, mockServer := tt.poolSetup()
+			defer mockServer.Stop()
+			tt.test(t, pool,
mockServer) + }) + } +} + +type mockQuerierServer struct { + querierpb.UnimplementedQuerierServer +} + +func (m *mockQuerierServer) Next(req *querierpb.NextRequest, stream querierpb.Querier_NextServer) error { + return nil +} + +func (m *mockQuerierServer) Series(req *querierpb.SeriesRequest, stream querierpb.Querier_SeriesServer) error { + return nil +} + +type mockHealthServer struct { + grpc_health_v1.UnimplementedHealthServer +} + +func (m *mockHealthServer) Check(ctx context.Context, req *grpc_health_v1.HealthCheckRequest) (*grpc_health_v1.HealthCheckResponse, error) { + return &grpc_health_v1.HealthCheckResponse{ + Status: grpc_health_v1.HealthCheckResponse_SERVING, + }, nil +} + +type mockServer struct { + server *grpc.Server + addr int +} + +func newMockServer(t *testing.T) *mockServer { + serverCfg := server.Config{ + HTTPListenNetwork: server.DefaultNetwork, + LogSourceIPs: true, + MetricsNamespace: "with_source_ip_extractor", + } + server, err := server.New(serverCfg) + require.NoError(t, err) + + mockQuerier := &mockQuerierServer{} + querierpb.RegisterQuerierServer(server.GRPC, mockQuerier) + grpc_health_v1.RegisterHealthServer(server.GRPC, &mockHealthServer{}) + + return &mockServer{ + server: server.GRPC, + addr: serverCfg.GRPCListenPort, + } +} + +func (m *mockServer) Stop() { + if m.server != nil { + m.server.Stop() + } +} + +// TestClientBuffer verifies that the streaming buffer matches the configured batch size and maintains correct data ordering +func TestClientBuffer(t *testing.T) { + tests := []struct { + name string + bufferData []model.StepVector + batchSize int64 + numSteps int + expectedCalls [][]model.StepVector + wantErr bool + }{ + { + name: "buffer with multiple batches", + numSteps: 1, + bufferData: []model.StepVector{ + { + T: 1000, + SampleIDs: []uint64{1}, + Samples: []float64{10.0}, + }, + { + T: 2000, + SampleIDs: []uint64{1}, + Samples: []float64{20.0}, + }, + }, + batchSize: 1, + expectedCalls: [][]model.StepVector{ + { + { + T: 1000, + SampleIDs: []uint64{1}, + Samples: []float64{10.0}, + }, + }, + { + { + T: 2000, + SampleIDs: []uint64{1}, + Samples: []float64{20.0}, + }, + }, + }, + wantErr: false, + }, + { + name: "single batch full buffer", + numSteps: 1, + bufferData: []model.StepVector{ + { + T: 1000, + SampleIDs: []uint64{1, 2}, + Samples: []float64{10.0, 20.0}, + }, + }, + batchSize: 2, + expectedCalls: [][]model.StepVector{ + { + { + T: 1000, + SampleIDs: []uint64{1, 2}, + Samples: []float64{10.0, 20.0}, + }, + }, + }, + wantErr: false, + }, + { + name: "buffer with multiple batches", + numSteps: 2, + bufferData: []model.StepVector{ + { + T: 1000, + SampleIDs: []uint64{1}, + Samples: []float64{10.0}, + }, + { + T: 2000, + SampleIDs: []uint64{1}, + Samples: []float64{20.0}, + }, + { + T: 3000, + SampleIDs: []uint64{1}, + Samples: []float64{30.0}, + }, + { + T: 4000, + SampleIDs: []uint64{1}, + Samples: []float64{40.0}, + }, + { + T: 5000, + SampleIDs: []uint64{1}, + Samples: []float64{50.0}, + }, + }, + batchSize: 2, + expectedCalls: [][]model.StepVector{ + { + { + T: 1000, + SampleIDs: []uint64{1}, + Samples: []float64{10.0}, + }, + { + T: 2000, + SampleIDs: []uint64{1}, + Samples: []float64{20.0}, + }, + }, + { + { + T: 3000, + SampleIDs: []uint64{1}, + Samples: []float64{30.0}, + }, + { + T: 4000, + SampleIDs: []uint64{1}, + Samples: []float64{40.0}, + }, + }, + { + { + T: 5000, + SampleIDs: []uint64{1}, + Samples: []float64{50.0}, + }, + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { 
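+			// Build the operator with a pre-filled buffer and initialized=true so the
+			// batching logic in Next() is exercised without opening a real gRPC stream;
+			// the mock client below only needs to terminate the stream with io.EOF.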
+ exec := &DistributedRemoteExecution{ + mint: 0, + maxt: 5000, + step: 1000, + currentStep: 0, + numSteps: tt.numSteps, + batchSize: tt.batchSize, + buffer: tt.bufferData, + bufferIndex: 0, + initialized: true, + } + + mockClient := &mockQuerierClient{ + nextShouldBeCalled: false, + } + exec.client = mockClient + + ctx := context.Background() + + for i, expectedCall := range tt.expectedCalls { // next() server call + result, err := exec.Next(ctx) + if tt.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.Equal(t, expectedCall, result, "call %d", i) + } + }) + } +} + +type mockQuerierClient struct { + nextShouldBeCalled bool +} + +func (m *mockQuerierClient) Series(ctx context.Context, req *querierpb.SeriesRequest, opts ...grpc.CallOption) (querierpb.Querier_SeriesClient, error) { + return &mockSeriesStream{}, nil +} + +func (m *mockQuerierClient) Next(ctx context.Context, req *querierpb.NextRequest, opts ...grpc.CallOption) (querierpb.Querier_NextClient, error) { + return &mockNextStream{}, nil +} + +type mockSeriesStream struct { + querierpb.Querier_SeriesClient +} + +func (m *mockSeriesStream) Recv() (*querierpb.SeriesBatch, error) { + return nil, io.EOF +} + +type mockNextStream struct { + querierpb.Querier_NextClient +} + +func (m *mockNextStream) Recv() (*querierpb.StepVectorBatch, error) { + return nil, io.EOF +} diff --git a/pkg/distributed_execution/querier_service_server.go b/pkg/distributed_execution/querier_service_server.go new file mode 100644 index 00000000000..59c9209ca93 --- /dev/null +++ b/pkg/distributed_execution/querier_service_server.go @@ -0,0 +1,311 @@ +package distributed_execution + +import ( + "fmt" + "time" + + "github.com/prometheus/prometheus/model/histogram" + "github.com/prometheus/prometheus/promql" + "github.com/prometheus/prometheus/promql/parser" + v1 "github.com/prometheus/prometheus/web/api/v1" + + "github.com/cortexproject/cortex/pkg/cortexpb" + "github.com/cortexproject/cortex/pkg/distributed_execution/querierpb" +) + +const ( + BATCHSIZE = 1000 + WritingTimeout = 100 * time.Millisecond + MaxRetries = 3 + RetryDelay = 100 * time.Millisecond +) + +// QuerierServer handles streaming of partial query results from a querier to clients +// during distributed query execution. It maintains query state through a QueryTracker. +type QuerierServer struct { + queryTracker *QueryTracker +} + +// NewQuerierServer creates a new QuerierServer instance with the provided QueryTracker +// to manage query fragment results. +func NewQuerierServer(cache *QueryTracker) *QuerierServer { + return &QuerierServer{ + queryTracker: cache, + } +} + +// Series streams series metadata to the client, allowing discovery of the data shape +// before receiving actual values. This should be called before Next(). 
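+//
+// A rough client-side call order (a sketch; qc, qid and fid are placeholder
+// names, the message types are the ones defined in querierpb):
+//
+//	sc, _ := qc.Series(ctx, &querierpb.SeriesRequest{QueryID: qid, FragmentID: fid})
+//	seriesBatch, _ := sc.Recv() // one SeriesBatch describing every output series
+//	nc, _ := qc.Next(ctx, &querierpb.NextRequest{QueryID: qid, FragmentID: fid, Batchsize: BATCHSIZE})
+//	// drain nc.Recv() until io.EOF to receive the step vectors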
+func (s *QuerierServer) Series(req *querierpb.SeriesRequest, srv querierpb.Querier_SeriesServer) error {
+	key := MakeFragmentKey(req.QueryID, req.FragmentID)
+
+	for {
+		var result FragmentResult
+		var ok bool
+		for attempt := 0; attempt < MaxRetries; attempt++ {
+			result, ok = s.queryTracker.Get(key)
+			if ok {
+				break
+			}
+			time.Sleep(RetryDelay)
+		}
+		if !ok {
+			return fmt.Errorf("fragment not found after %d attempts: %v", MaxRetries, key)
+		}
+
+		switch result.Status {
+		case StatusDone:
+			v1ResultData := result.Data.(*v1.QueryData)
+
+			switch v1ResultData.ResultType {
+			case parser.ValueTypeMatrix:
+				series := v1ResultData.Result.(promql.Matrix)
+
+				seriesBatch := []*querierpb.OneSeries{}
+				for _, s := range series {
+					oneSeries := &querierpb.OneSeries{
+						Labels: make([]*querierpb.Label, s.Metric.Len()),
+					}
+
+					j := 0
+					for name, val := range s.Metric.Map() {
+						oneSeries.Labels[j] = &querierpb.Label{
+							Name:  name,
+							Value: val,
+						}
+						j++
+					}
+					seriesBatch = append(seriesBatch, oneSeries)
+				}
+				if err := srv.Send(&querierpb.SeriesBatch{
+					OneSeries: seriesBatch}); err != nil {
+					return err
+				}
+
+				return nil
+
+			case parser.ValueTypeVector:
+				samples := v1ResultData.Result.(promql.Vector)
+
+				seriesBatch := []*querierpb.OneSeries{}
+				for _, s := range samples {
+					oneSeries := &querierpb.OneSeries{
+						Labels: make([]*querierpb.Label, s.Metric.Len()),
+					}
+
+					j := 0
+					for name, val := range s.Metric.Map() {
+						oneSeries.Labels[j] = &querierpb.Label{
+							Name:  name,
+							Value: val,
+						}
+						j++
+					}
+					seriesBatch = append(seriesBatch, oneSeries)
+				}
+				if err := srv.Send(&querierpb.SeriesBatch{
+					OneSeries: seriesBatch,
+				}); err != nil {
+					return err
+				}
+				return nil
+
+			default:
+				return fmt.Errorf("unsupported result type: %v", v1ResultData.ResultType)
+			}
+
+		case StatusError:
+			return fmt.Errorf("fragment processing failed")
+
+		case StatusWriting:
+			time.Sleep(WritingTimeout)
+			continue
+		}
+	}
+}
+
+// Next streams query result data to the client in batches. It should be called
+// after Series() to receive the actual data values.
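+// Matrix results are emitted as StepVector batches walked per time step, while
+// vector results are chunked into groups of at most Batchsize samples; a
+// non-positive Batchsize falls back to the server-side BATCHSIZE default.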
+func (s *QuerierServer) Next(req *querierpb.NextRequest, srv querierpb.Querier_NextServer) error {
+	key := MakeFragmentKey(req.QueryID, req.FragmentID)
+
+	batchSize := int(req.Batchsize)
+	if batchSize <= 0 {
+		batchSize = BATCHSIZE
+	}
+
+	for {
+		var result FragmentResult
+		var ok bool
+		for attempt := 0; attempt < MaxRetries; attempt++ {
+			result, ok = s.queryTracker.Get(key)
+			if ok {
+				break
+			}
+			time.Sleep(RetryDelay)
+		}
+		if !ok {
+			return fmt.Errorf("fragment not found after %d attempts: %v", MaxRetries, key)
+		}
+
+		switch result.Status {
+		case StatusDone:
+			v1ResultData := result.Data.(*v1.QueryData)
+
+			switch v1ResultData.ResultType {
+			case parser.ValueTypeMatrix:
+				matrix := v1ResultData.Result.(promql.Matrix)
+
+				numTimeSteps := matrix.TotalSamples()
+
+				for timeStep := 0; timeStep < numTimeSteps; timeStep += batchSize {
+					batch := &querierpb.StepVectorBatch{
+						StepVectors: make([]*querierpb.StepVector, 0, len(matrix)),
+					}
+					for t := 0; t < batchSize; t++ {
+						for i, series := range matrix {
+							vector, err := s.createVectorForTimestep(&series, timeStep+t, uint64(i))
+							if err != nil {
+								return err
+							}
+							batch.StepVectors = append(batch.StepVectors, vector)
+						}
+					}
+					if err := srv.Send(batch); err != nil {
+						return fmt.Errorf("error sending batch: %w", err)
+					}
+				}
+				return nil
+
+			case parser.ValueTypeVector:
+				vector := v1ResultData.Result.(promql.Vector)
+
+				for i := 0; i < len(vector); i += batchSize {
+					end := i + batchSize
+					if end > len(vector) {
+						end = len(vector)
+					}
+
+					batch := &querierpb.StepVectorBatch{
+						StepVectors: []*querierpb.StepVector{},
+					}
+
+					var timestamp int64
+					sampleIDs := make([]uint64, 0, batchSize)
+					samples := make([]float64, 0, batchSize)
+					histogramIDs := make([]uint64, 0, batchSize)
+					histograms := make([]*histogram.FloatHistogram, 0, batchSize)
+
+					for j, sample := range (vector)[i:end] {
+						if sample.H == nil {
+							sampleIDs = append(sampleIDs, uint64(j))
+							samples = append(samples, sample.F)
+						} else {
+							histogramIDs = append(histogramIDs, uint64(j))
+							histograms = append(histograms, sample.H)
+						}
+					}
+					vec := &querierpb.StepVector{
+						T:             timestamp,
+						Sample_IDs:    sampleIDs,
+						Samples:       samples,
+						Histogram_IDs: histogramIDs,
+						Histograms:    floatHistogramsToFloatHistogramProto(histograms),
+					}
+					batch.StepVectors = append(batch.StepVectors, vec)
+					if err := srv.Send(batch); err != nil {
+						return err
+					}
+				}
+				return nil
+
+			default:
+				return fmt.Errorf("unsupported result type: %v", v1ResultData.ResultType)
+			}
+		case StatusError:
+			return fmt.Errorf("fragment processing failed")
+		case StatusWriting:
+			time.Sleep(WritingTimeout)
+			continue
+		}
+	}
+}
+
+func (s *QuerierServer) createVectorForTimestep(series *promql.Series, timeStep int, sampleID uint64) (*querierpb.StepVector, error) {
+	var samples []float64
+	var sampleIDs []uint64
+	var histograms []*histogram.FloatHistogram
+	var histogramIDs []uint64
+	var timestamp int64
+
+	if timeStep < len(series.Floats) {
+		point := series.Floats[timeStep]
+		timestamp = point.T
+		samples = append(samples, point.F)
+		sampleIDs = append(sampleIDs, sampleID)
+	}
+
+	if timeStep < len(series.Histograms) {
+		point := series.Histograms[timeStep]
+		timestamp = point.T
+		histograms = append(histograms, point.H)
+		histogramIDs = append(histogramIDs, uint64(timeStep))
+	}
+
+	return &querierpb.StepVector{
+		T:             timestamp,
+		Sample_IDs:    sampleIDs,
+		Samples:       samples,
+		Histogram_IDs: histogramIDs,
+		Histograms:    floatHistogramsToFloatHistogramProto(histograms),
+	}, nil
+}
+
+func 
floatHistogramsToFloatHistogramProto(histograms []*histogram.FloatHistogram) []cortexpb.Histogram { + if histograms == nil { + return []cortexpb.Histogram{} + } + + protoHistograms := make([]cortexpb.Histogram, 0, len(histograms)) + for _, h := range histograms { + if h != nil { + protoHist := floatHistogramToFloatHistogramProto(h) + protoHistograms = append(protoHistograms, *protoHist) + } + } + return protoHistograms +} + +func floatHistogramToFloatHistogramProto(h *histogram.FloatHistogram) *cortexpb.Histogram { + if h == nil { + return nil + } + + return &cortexpb.Histogram{ + ResetHint: cortexpb.Histogram_ResetHint(h.CounterResetHint), + Schema: h.Schema, + ZeroThreshold: h.ZeroThreshold, + Count: &cortexpb.Histogram_CountFloat{ + CountFloat: h.Count, + }, + ZeroCount: &cortexpb.Histogram_ZeroCountFloat{ + ZeroCountFloat: h.ZeroCount, + }, + Sum: h.Sum, + PositiveSpans: spansToSpansProto(h.PositiveSpans), + PositiveCounts: h.PositiveBuckets, + NegativeSpans: spansToSpansProto(h.NegativeSpans), + NegativeCounts: h.NegativeBuckets, + } +} + +func spansToSpansProto(spans []histogram.Span) []cortexpb.BucketSpan { + if spans == nil { + return nil + } + protoSpans := make([]cortexpb.BucketSpan, len(spans)) + for i, span := range spans { + protoSpans[i] = cortexpb.BucketSpan{ + Offset: span.Offset, + Length: span.Length, + } + } + return protoSpans +} diff --git a/pkg/distributed_execution/querier_service_server_test.go b/pkg/distributed_execution/querier_service_server_test.go new file mode 100644 index 00000000000..8dcb557d0a0 --- /dev/null +++ b/pkg/distributed_execution/querier_service_server_test.go @@ -0,0 +1,206 @@ +package distributed_execution + +import ( + "context" + "testing" + + "github.com/prometheus/prometheus/model/histogram" + "github.com/prometheus/prometheus/model/labels" + "github.com/prometheus/prometheus/promql" + "github.com/prometheus/prometheus/promql/parser" + v1 "github.com/prometheus/prometheus/web/api/v1" + "github.com/stretchr/testify/assert" + + "github.com/cortexproject/cortex/pkg/distributed_execution/querierpb" +) + +// TestQuerierServer_Series tests series streaming +func TestQuerierServer_Series(t *testing.T) { + tests := []struct { + name string + setupCache func() *QueryTracker + request *querierpb.SeriesRequest + wantErr bool + errMessage string + }{ + { + name: "matrix data type success", + setupCache: func() *QueryTracker { + cache := NewQueryTracker() + matrix := promql.Matrix{ + promql.Series{ + Metric: labels.FromStrings("__name__", "foo"), + Floats: []promql.FPoint{{F: 1, T: 1000}}, + }, + promql.Series{ + Metric: labels.FromStrings("__name__", "bar"), + Floats: []promql.FPoint{{F: 2, T: 2000}}, + }, + } + cache.SetComplete(MakeFragmentKey(1, 1), &v1.QueryData{ + ResultType: parser.ValueTypeMatrix, + Result: matrix, + }) + return cache + }, + request: &querierpb.SeriesRequest{ + QueryID: 1, + FragmentID: 1, + }, + wantErr: false, + }, + { + name: "vector data type success", + setupCache: func() *QueryTracker { + cache := NewQueryTracker() + vector := promql.Vector{promql.Sample{}} + cache.SetComplete(MakeFragmentKey(1, 1), + &v1.QueryData{ + ResultType: parser.ValueTypeVector, + Result: vector, + }) + return cache + }, + request: &querierpb.SeriesRequest{ + QueryID: 1, + FragmentID: 1, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cache := tt.setupCache() + server := NewQuerierServer(cache) + + mockStream := &mockSeriesServer{} + err := server.Series(tt.request, mockStream) + + if tt.wantErr 
{ + assert.Error(t, err) + if tt.errMessage != "" { + assert.Contains(t, err.Error(), tt.errMessage) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestQuerierServer_Next tests next streaming +func TestQuerierServer_Next(t *testing.T) { + tests := []struct { + name string + setupCache func() *QueryTracker + request *querierpb.NextRequest + wantErr bool + errMessage string + }{ + { + name: "matrix data type success", + setupCache: func() *QueryTracker { + cache := NewQueryTracker() + matrix := promql.Matrix{ + promql.Series{ + Metric: labels.FromStrings("__name__", "foo"), + Floats: []promql.FPoint{{F: 1, T: 1000}}, + }, + promql.Series{ + Metric: labels.FromStrings("__name__", "bar"), + Floats: []promql.FPoint{{F: 2, T: 2000}}, + }, + } + cache.SetComplete(MakeFragmentKey(1, 1), &v1.QueryData{ + ResultType: parser.ValueTypeMatrix, + Result: matrix, + }) + return cache + }, + request: &querierpb.NextRequest{ + QueryID: 1, + FragmentID: 1, + Batchsize: BATCHSIZE, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cache := tt.setupCache() + server := NewQuerierServer(cache) + + mockStream := &mockNextServer{} + err := server.Next(tt.request, mockStream) + + if tt.wantErr { + assert.Error(t, err) + if tt.errMessage != "" { + assert.Contains(t, err.Error(), tt.errMessage) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestFloatHistorgramConversion function is doing conversion correctly +func TestFloatHistogramConversion(t *testing.T) { + original := &histogram.FloatHistogram{ + Schema: 1, + ZeroThreshold: 0.001, + ZeroCount: 2.0, + Count: 10.0, + Sum: 100.0, + PositiveSpans: []histogram.Span{{Offset: 1, Length: 2}}, + PositiveBuckets: []float64{1.0, 2.0}, + NegativeSpans: []histogram.Span{{Offset: -2, Length: 1}}, + NegativeBuckets: []float64{-1.0}, + } + + proto := floatHistogramToFloatHistogramProto(original) + + result := floatHistogramProtoToFloatHistogram(*proto) + + assert.Equal(t, original.Schema, result.Schema) + assert.Equal(t, original.ZeroThreshold, result.ZeroThreshold) + assert.Equal(t, original.ZeroCount, result.ZeroCount) + assert.Equal(t, original.Count, result.Count) + assert.Equal(t, original.Sum, result.Sum) + assert.Equal(t, original.PositiveSpans, result.PositiveSpans) + assert.Equal(t, original.PositiveBuckets, result.PositiveBuckets) + assert.Equal(t, original.NegativeSpans, result.NegativeSpans) + assert.Equal(t, original.NegativeBuckets, result.NegativeBuckets) +} + +// mock implementations for testing +type mockSeriesServer struct { + querierpb.Querier_SeriesServer + sent []*querierpb.SeriesBatch +} + +func (m *mockSeriesServer) Send(batch *querierpb.SeriesBatch) error { + m.sent = append(m.sent, batch) + return nil +} + +func (m *mockSeriesServer) Context() context.Context { + return context.Background() +} + +type mockNextServer struct { + querierpb.Querier_NextServer + sent []*querierpb.StepVectorBatch +} + +func (m *mockNextServer) Send(batch *querierpb.StepVectorBatch) error { + m.sent = append(m.sent, batch) + return nil +} + +func (m *mockNextServer) Context() context.Context { + return context.Background() +} diff --git a/pkg/distributed_execution/query_tracker.go b/pkg/distributed_execution/query_tracker.go new file mode 100644 index 00000000000..c52fff72162 --- /dev/null +++ b/pkg/distributed_execution/query_tracker.go @@ -0,0 +1,127 @@ +package distributed_execution + +import ( + "sync" + "time" +) + +const ( + DefaultTTL = 1 * time.Minute +) + +// QueryTracker manages the 
lifecycle and state of query fragments during distributed execution. +// It provides thread-safe access to fragment results and their status. +type QueryTracker struct { + sync.RWMutex + cache map[FragmentKey]FragmentResult +} + +// FragmentStatus represents the current state of a query fragment. +type FragmentStatus string + +const ( + StatusWriting FragmentStatus = "writing" + StatusDone FragmentStatus = "done" + StatusError FragmentStatus = "error" +) + +// FragmentResult holds the result data and metadata for a query fragment. +type FragmentResult struct { + Data interface{} + Status FragmentStatus + Expiration time.Time +} + +// NewQueryTracker creates a new QueryTracker instance with an initialized cache. +func NewQueryTracker() *QueryTracker { + return &QueryTracker{ + cache: make(map[FragmentKey]FragmentResult), + } +} + +func (qt *QueryTracker) Size() int { + return len(qt.cache) +} + +// InitWriting initializes a new fragment entry with writing status. +func (qt *QueryTracker) InitWriting(key FragmentKey) { + qt.Lock() + defer qt.Unlock() + qt.cache[key] = FragmentResult{ + Status: StatusWriting, + Expiration: time.Now().Add(DefaultTTL), + } +} + +// SetComplete marks a fragment as complete with its result data. +func (qt *QueryTracker) SetComplete(key FragmentKey, data interface{}) { + qt.Lock() + defer qt.Unlock() + qt.cache[key] = FragmentResult{ + Data: data, + Status: StatusDone, + Expiration: time.Now().Add(DefaultTTL), + } +} + +// SetError marks a fragment as failed. +func (qt *QueryTracker) SetError(key FragmentKey) { + qt.Lock() + defer qt.Unlock() + qt.cache[key] = FragmentResult{ + Status: StatusError, + Expiration: time.Now().Add(DefaultTTL), + } +} + +// IsReady checks if a fragment has completed processing successfully. +func (qt *QueryTracker) IsReady(key FragmentKey) bool { + qt.RLock() + defer qt.RUnlock() + if result, ok := qt.cache[key]; ok { + return result.Status == StatusDone + } + return false +} + +// Get retrieves the fragment result and existence status for a given key. +func (qt *QueryTracker) Get(key FragmentKey) (FragmentResult, bool) { + qt.RLock() + defer qt.RUnlock() + result, ok := qt.cache[key] + return result, ok +} + +// GetFragmentStatus returns the current status of a fragment. +func (qt *QueryTracker) GetFragmentStatus(key FragmentKey) FragmentStatus { + qt.RLock() + defer qt.RUnlock() + result, ok := qt.cache[key] + if !ok { + return FragmentStatus("") + } + return result.Status +} + +// CleanExpired removes all expired fragment entries from the cache. +func (qt *QueryTracker) CleanExpired() { + qt.Lock() + defer qt.Unlock() + now := time.Now() + for key, result := range qt.cache { + if now.After(result.Expiration) { + delete(qt.cache, key) + } + } +} + +// ClearQuery removes all fragments associated with the specified query ID. 
+func (qt *QueryTracker) ClearQuery(queryID uint64) { + qt.Lock() + defer qt.Unlock() + for key := range qt.cache { + if key.queryID == queryID { + delete(qt.cache, key) + } + } +} diff --git a/pkg/distributed_execution/remote_node.go b/pkg/distributed_execution/remote_node.go index 04a146570c5..7998fcbdeda 100644 --- a/pkg/distributed_execution/remote_node.go +++ b/pkg/distributed_execution/remote_node.go @@ -1,11 +1,21 @@ package distributed_execution import ( + "context" "encoding/json" "fmt" + "io" + "github.com/prometheus/prometheus/model/labels" "github.com/prometheus/prometheus/promql/parser" + "github.com/prometheus/prometheus/storage" + "github.com/thanos-io/promql-engine/execution/exchange" + "github.com/thanos-io/promql-engine/execution/model" "github.com/thanos-io/promql-engine/logicalplan" + "github.com/thanos-io/promql-engine/query" + + "github.com/cortexproject/cortex/pkg/distributed_execution/querierpb" + "github.com/cortexproject/cortex/pkg/ring/client" ) const ( @@ -22,6 +32,8 @@ type Remote struct { FragmentKey FragmentKey FragmentAddr string + + clientPool *client.Pool } func NewRemoteNode(Expr logicalplan.Node) logicalplan.Node { @@ -45,6 +57,10 @@ func (r *Remote) ReturnType() parser.ValueType { } func (r *Remote) Type() logicalplan.NodeType { return RemoteNode } +func (r *Remote) InsertClientPool(clientPool *client.Pool) { + r.clientPool = clientPool +} + type remote struct { QueryID uint64 FragmentID uint64 @@ -69,3 +85,248 @@ func (r *Remote) UnmarshalJSON(data []byte) error { r.FragmentAddr = re.FragmentAddr return nil } + +// MakeExecutionOperator creates a distributed execution operator from a Remote node. +// This implements the logicalplan.UserDefinedExpr interface, allowing Remote nodes +// to be transformed into custom distributed execution operators during query processing. +func (r *Remote) MakeExecutionOperator( + ctx context.Context, + vectors *model.VectorPool, + opts *query.Options, + hints storage.SelectHints, +) (model.VectorOperator, error) { + pool := r.clientPool + + remoteExec, err := newDistributedRemoteExecution(ctx, pool, r.FragmentKey, opts) + if err != nil { + return nil, err + } + remoteExec.vectors = vectors + + return exchange.NewConcurrent(remoteExec, 2, opts), nil +} + +type DistributedRemoteExecution struct { + client querierpb.QuerierClient + + vectors *model.VectorPool + + mint int64 + maxt int64 + step int64 + currentStep int64 + numSteps int + + stream querierpb.Querier_NextClient + buffer []model.StepVector + bufferIndex int + + batchSize int64 + series []labels.Labels + fragmentKey FragmentKey + addr string + initialized bool // track if stream is initialized +} + +type QuerierAddrKey struct{} + +// newDistributedRemoteExecution creates a DistributedRemoteExecution operator that executes +// queries across distributed queriers. It implements Thanos engine's logical plan execution by: +// 1. Streaming series metadata to discover the data shape +// 2. Fetching actual data values via subsequent Next calls +// +// Unlike local execution, this operator retrieves data from remote querier processes, +// enabling distributed query processing across multiple nodes. 
+func newDistributedRemoteExecution(ctx context.Context, pool *client.Pool, fragmentKey FragmentKey, queryOpts *query.Options) (*DistributedRemoteExecution, error) { + + _, _, _, childIDToAddr, _ := ExtractFragmentMetaData(ctx) + + poolClient, err := pool.GetClientFor(childIDToAddr[fragmentKey.fragmentID]) + + if err != nil { + return nil, err + } + + client, ok := poolClient.(*querierClient) + if !ok { + return nil, fmt.Errorf("invalid client type from pool") + } + + d := &DistributedRemoteExecution{ + client: client, + + mint: queryOpts.Start.UnixMilli(), + maxt: queryOpts.End.UnixMilli(), + step: queryOpts.Step.Milliseconds(), + currentStep: queryOpts.Start.UnixMilli(), + numSteps: queryOpts.NumSteps(), + + batchSize: 1000, + fragmentKey: fragmentKey, + addr: childIDToAddr[fragmentKey.fragmentID], + buffer: []model.StepVector{}, + bufferIndex: 0, + initialized: false, + } + + if d.step == 0 { + d.step = 1 + } + + return d, nil +} + +func (d *DistributedRemoteExecution) Series(ctx context.Context) ([]labels.Labels, error) { + + if d.series != nil { + return d.series, nil + } + + req := &querierpb.SeriesRequest{ + QueryID: d.fragmentKey.queryID, + FragmentID: d.fragmentKey.fragmentID, + Batchsize: d.batchSize, + } + + stream, err := d.client.Series(ctx, req) + if err != nil { + return nil, err + } + + var series []labels.Labels + + for { + seriesBatch, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + return nil, err + } + + for _, s := range seriesBatch.OneSeries { + oneSeries := make(map[string]string, len(s.Labels)) + for _, l := range s.Labels { + oneSeries[l.Name] = l.Value + } + series = append(series, labels.FromMap(oneSeries)) + } + } + + d.series = series + return series, nil +} + +func (d *DistributedRemoteExecution) Next(ctx context.Context) ([]model.StepVector, error) { + + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + if d.currentStep > d.maxt { + return nil, nil + } + + ts := d.currentStep + numVectorsNeeded := 0 + for currStep := 0; currStep < d.numSteps && ts <= d.maxt; currStep++ { + numVectorsNeeded++ + ts += d.step + } + + // return from buffer first + if d.buffer != nil && d.bufferIndex < len(d.buffer) { + end := d.bufferIndex + int(d.numSteps) + if end > len(d.buffer) { + end = len(d.buffer) + } + result := d.buffer[d.bufferIndex:end] + d.bufferIndex = end + + if d.bufferIndex >= len(d.buffer) { + d.buffer = nil + d.bufferIndex = 0 + } + + return result, nil + } + + // initialize stream if haven't + if !d.initialized { + req := &querierpb.NextRequest{ + QueryID: d.fragmentKey.queryID, + FragmentID: d.fragmentKey.fragmentID, + Batchsize: d.batchSize, + } + stream, err := d.client.Next(ctx, req) + if err != nil { + return nil, fmt.Errorf("failed to initialize stream: %w", err) + } + d.stream = stream + d.initialized = true + } + + // get new batch from server + batch, err := d.stream.Recv() + if err == io.EOF { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("error receiving from stream: %w", err) + } + + // return new batch and save it + d.buffer = make([]model.StepVector, len(batch.StepVectors)) + for i, sv := range batch.StepVectors { + d.buffer[i] = model.StepVector{ + T: sv.T, + SampleIDs: sv.Sample_IDs, + Samples: sv.Samples, + HistogramIDs: sv.Histogram_IDs, + Histograms: floatHistogramProtoToFloatHistograms(sv.Histograms), + } + } + + end := d.numSteps + if end > len(d.buffer) { + end = len(d.buffer) + } + result := d.buffer[:end] + d.bufferIndex = end + + if d.bufferIndex >= len(d.buffer) { + 
d.buffer = nil + d.bufferIndex = 0 + } + + d.currentStep += d.step * int64(len(result)) + + return result, nil +} + +func (d *DistributedRemoteExecution) Close() error { + if d.stream != nil { + + if err := d.stream.CloseSend(); err != nil { + return fmt.Errorf("error closing stream: %w", err) + } + } + d.buffer = nil + d.bufferIndex = 0 + d.initialized = false + return nil +} + +func (d DistributedRemoteExecution) GetPool() *model.VectorPool { + return d.vectors +} + +func (d DistributedRemoteExecution) Explain() (next []model.VectorOperator) { + return []model.VectorOperator{} +} + +func (d DistributedRemoteExecution) String() string { + return "DistributedRemoteExecution(" + d.addr + ")" +} diff --git a/pkg/querier/worker/scheduler_processor.go b/pkg/querier/worker/scheduler_processor.go index 3bba5980442..ad6ef1b161f 100644 --- a/pkg/querier/worker/scheduler_processor.go +++ b/pkg/querier/worker/scheduler_processor.go @@ -18,6 +18,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/health/grpc_health_v1" + "github.com/cortexproject/cortex/pkg/distributed_execution" "github.com/cortexproject/cortex/pkg/frontend/v2/frontendv2pb" querier_stats "github.com/cortexproject/cortex/pkg/querier/stats" "github.com/cortexproject/cortex/pkg/ring/client" @@ -158,7 +159,9 @@ func (sp *schedulerProcessor) querierLoop(c schedulerpb.SchedulerForQuerier_Quer if request.StatsEnabled { level.Info(logger).Log("msg", "started running request") } - sp.runRequest(ctx, logger, request.QueryID, request.FrontendAddress, request.StatsEnabled, request.HttpRequest) + + ctx = distributed_execution.InjectFragmentMetaData(ctx, request.FragmentID, request.QueryID, request.IsRoot, request.ChildIDtoAddrs) + sp.runRequest(ctx, logger, request.QueryID, request.FrontendAddress, request.StatsEnabled, request.HttpRequest, request.IsRoot) if err = ctx.Err(); err != nil { return @@ -172,7 +175,7 @@ func (sp *schedulerProcessor) querierLoop(c schedulerpb.SchedulerForQuerier_Quer } } -func (sp *schedulerProcessor) runRequest(ctx context.Context, logger log.Logger, queryID uint64, frontendAddress string, statsEnabled bool, request *httpgrpc.HTTPRequest) { +func (sp *schedulerProcessor) runRequest(ctx context.Context, logger log.Logger, queryID uint64, frontendAddress string, statsEnabled bool, request *httpgrpc.HTTPRequest, isRoot bool) { var stats *querier_stats.QueryStats if statsEnabled { stats, ctx = querier_stats.ContextWithEmptyStats(ctx) @@ -189,6 +192,10 @@ func (sp *schedulerProcessor) runRequest(ctx context.Context, logger log.Logger, } } } + if !isRoot { + return + } + if statsEnabled { level.Info(logger).Log("msg", "finished request", "status_code", response.Code, "response_size", len(response.GetBody())) }
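
The interplay between the child querier's result cache and the serving path is easiest to see end to end. The sketch below is illustrative only and is not part of the patch: it walks one fragment through the QueryTracker lifecycle added above — register with InitWriting, publish with SetComplete (or SetError), check readiness, then evict by query ID or TTL. The MakeFragmentKey argument order (queryID first, then fragmentID) and the string payload are assumptions for illustration; the real server caches a *v1.QueryData as the tests show.

package distributed_execution

import "fmt"

// exampleFragmentLifecycle is an illustrative sketch, not part of the patch.
// It exercises the QueryTracker API with hypothetical IDs.
func exampleFragmentLifecycle() {
	tracker := NewQueryTracker()
	key := MakeFragmentKey(42, 1) // assumed order: queryID=42, fragmentID=1

	// A child querier registers the fragment before execution starts.
	tracker.InitWriting(key)
	fmt.Println(tracker.IsReady(key)) // false: status is still StatusWriting

	// On success the result payload is cached; the querier gRPC server streams
	// it to the parent fragment from here. (Real callers store *v1.QueryData.)
	tracker.SetComplete(key, "fragment result placeholder")
	fmt.Println(tracker.GetFragmentStatus(key) == StatusDone) // true

	if result, ok := tracker.Get(key); ok {
		fmt.Println(result.Data) // "fragment result placeholder"
	}

	// Failures are recorded too, so parents can tell "not ready yet" apart
	// from "failed"; a different fragment key keeps the happy path intact.
	tracker.SetError(MakeFragmentKey(42, 2))

	// Entries expire after DefaultTTL; CleanExpired drops stale ones, and
	// ClearQuery removes every fragment of a finished query at once.
	tracker.CleanExpired()
	tracker.ClearQuery(42)
	fmt.Println(tracker.Size()) // 0
}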
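
On the root side, the Remote node becomes a DistributedRemoteExecution operator that the Thanos promql-engine drives like any other VectorOperator: Series once to learn the output shape, then Next repeatedly until the child fragment is exhausted (Next returns nil). The helper below is a sketch of that call pattern, not code from the patch; it assumes only the model.VectorOperator interface from thanos-io/promql-engine that MakeExecutionOperator already returns.

package distributed_execution

import (
	"context"

	"github.com/prometheus/prometheus/model/labels"
	"github.com/thanos-io/promql-engine/execution/model"
)

// drainOperator is an illustrative helper showing how the engine consumes the
// operator returned by (*Remote).MakeExecutionOperator. Not part of the patch.
func drainOperator(ctx context.Context, op model.VectorOperator) ([]labels.Labels, int, error) {
	// Series is fetched once and cached by DistributedRemoteExecution, so the
	// gRPC Series stream is only opened on the first call.
	series, err := op.Series(ctx)
	if err != nil {
		return nil, 0, err
	}

	steps := 0
	for {
		// Next serves from the local buffer when possible and otherwise pulls
		// the next StepVectorBatch from the child querier's Next stream.
		vectors, err := op.Next(ctx)
		if err != nil {
			return series, steps, err
		}
		if vectors == nil { // nil means the remote fragment is exhausted (stream hit io.EOF)
			return series, steps, nil
		}
		steps += len(vectors)
	}
}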