diff --git a/pkg/distsql/distsql.go b/pkg/distsql/distsql.go
index 0701ad12b9cba..11bb4c17d588d 100644
--- a/pkg/distsql/distsql.go
+++ b/pkg/distsql/distsql.go
@@ -177,6 +177,12 @@ func SelectWithRuntimeStats(ctx context.Context, dctx *distsqlctx.DistSQLContext
 func Analyze(ctx context.Context, client kv.Client, kvReq *kv.Request, vars any,
 	isRestrict bool, dctx *distsqlctx.DistSQLContext) (SelectResult, error) {
 	ctx = WithSQLKvExecCounterInterceptor(ctx, dctx.KvExecCounter)
+	failpoint.Inject("mockAnalyzeRequestWaitForCancel", func(val failpoint.Value) {
+		if val.(bool) {
+			<-ctx.Done()
+			failpoint.Return(nil, ctx.Err())
+		}
+	})
 	kvReq.RequestSource.RequestSourceInternal = true
 	kvReq.RequestSource.RequestSourceType = kv.InternalTxnStats
 	resp := client.Send(ctx, kvReq, vars, &kv.ClientSendOption{})
diff --git a/pkg/executor/analyze.go b/pkg/executor/analyze.go
index 3858b7f97f953..473ac6d1d8487 100644
--- a/pkg/executor/analyze.go
+++ b/pkg/executor/analyze.go
@@ -46,6 +46,7 @@ import (
 	handleutil "github.com/pingcap/tidb/pkg/statistics/handle/util"
 	"github.com/pingcap/tidb/pkg/util"
 	"github.com/pingcap/tidb/pkg/util/chunk"
+	"github.com/pingcap/tidb/pkg/util/dbterror/exeerrors"
 	"github.com/pingcap/tidb/pkg/util/intest"
 	"github.com/pingcap/tidb/pkg/util/sqlescape"
 	"github.com/pingcap/tipb/go-tipb"
@@ -101,6 +102,8 @@ func (e *AnalyzeExec) Next(ctx context.Context, _ *chunk.Chunk) (err error) {
 	statsHandle := domain.GetDomain(e.Ctx()).StatsHandle()
 	infoSchema := sessiontxn.GetTxnManager(e.Ctx()).GetTxnInfoSchema()
 	sessionVars := e.Ctx().GetSessionVars()
+	ctx, stop := e.buildAnalyzeKillCtx(ctx)
+	defer stop()

 	// Filter the locked tables.
 	tasks, needAnalyzeTableCnt, skippedTables, err := filterAndCollectTasks(e.tasks, statsHandle, infoSchema)
@@ -132,7 +135,7 @@ func (e *AnalyzeExec) Next(ctx context.Context, _ *chunk.Chunk) (err error) {
 	taskCh := make(chan *analyzeTask, buildStatsConcurrency)
 	resultsCh := make(chan *statistics.AnalyzeResults, 1)
 	for range buildStatsConcurrency {
-		e.wg.Run(func() { e.analyzeWorker(taskCh, resultsCh) })
+		e.wg.Run(func() { e.analyzeWorker(ctx, taskCh, resultsCh) })
 	}
 	pruneMode := variable.PartitionPruneMode(sessionVars.PartitionPruneMode.Load())
 	// needGlobalStats used to indicate whether we should merge the partition-level stats to global-level stats.
@@ -152,10 +155,12 @@ func (e *AnalyzeExec) Next(ctx context.Context, _ *chunk.Chunk) (err error) {
 			dom.SysProcTracker().KillSysProcess(id)
 		}
 	})
+	sentTasks := 0
 TASKLOOP:
 	for _, task := range tasks {
 		select {
 		case taskCh <- task:
+			sentTasks++
 		case <-e.errExitCh:
 			break TASKLOOP
 		case <-gctx.Done():
@@ -173,6 +178,17 @@ TASKLOOP:
 	err = e.waitFinish(ctx, g, resultsCh)
 	if err != nil {
+		if stderrors.Is(err, context.Canceled) {
+			if cause := context.Cause(ctx); cause != nil {
+				err = cause
+			}
+		}
+		for task := range taskCh {
+			finishJobWithLog(statsHandle, task.job, err)
+		}
+		for i := sentTasks; i < len(tasks); i++ {
+			finishJobWithLog(statsHandle, tasks[i].job, err)
+		}
 		return err
 	}
@@ -469,6 +485,28 @@ func (e *AnalyzeExec) handleResultsErrorWithConcurrency(
 			break
 		}
 		if results.Err != nil {
+			if intest.InTest && stderrors.Is(results.Err, context.Canceled) {
+				jobInfo := ""
+				dbName := ""
+				tableName := ""
+				partitionName := ""
+				if results.Job != nil {
+					jobInfo = results.Job.JobInfo
+					dbName = results.Job.DBName
+					tableName = results.Job.TableName
+					partitionName = results.Job.PartitionName
+				}
+				statslogutil.StatsLogger().Info("analyze result canceled",
+					zap.Uint32("killSignal", e.Ctx().GetSessionVars().SQLKiller.GetKillSignal()),
+					zap.Uint64("connID", e.Ctx().GetSessionVars().ConnectionID),
+					zap.String("jobInfo", jobInfo),
+					zap.String("dbName", dbName),
+					zap.String("tableName", tableName),
+					zap.String("partitionName", partitionName),
+					zap.Error(results.Err),
+					zap.Stack("stack"),
+				)
+			}
 			err = results.Err
 			if isAnalyzeWorkerPanic(err) {
 				panicCnt++
@@ -503,7 +541,77 @@ func (e *AnalyzeExec) handleResultsErrorWithConcurrency(
 	return err
 }

-func (e *AnalyzeExec) analyzeWorker(taskCh <-chan *analyzeTask, resultsCh chan<- *statistics.AnalyzeResults) {
+func (e *AnalyzeExec) buildAnalyzeKillCtx(parent context.Context) (context.Context, func()) {
+	ctx, cancel := context.WithCancelCause(parent)
+	killer := &e.Ctx().GetSessionVars().SQLKiller
+	killCh := killer.GetKillEventChan()
+	stopCh := make(chan struct{})
+	go func() {
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case <-stopCh:
+				return
+			case <-killCh:
+				status := killer.GetKillSignal()
+				if status == 0 {
+					return
+				}
+				err := killer.HandleSignal()
+				if err == nil {
+					err = exeerrors.ErrQueryInterrupted
+				}
+				cancel(err)
+				return
+			}
+		}
+	}()
+	return ctx, func() {
+		cancel(context.Canceled)
+		close(stopCh)
+	}
+}
+
+func analyzeWorkerExitErr(ctx context.Context, errExitCh <-chan struct{}) error {
+	select {
+	case <-ctx.Done():
+		if err := context.Cause(ctx); err != nil {
+			return err
+		}
+		if err := ctx.Err(); err != nil {
+			return err
+		}
+		return exeerrors.ErrQueryInterrupted
+	case <-errExitCh:
+		return exeerrors.ErrQueryInterrupted
+	default:
+		return nil
+	}
+}
+
+func (e *AnalyzeExec) sendAnalyzeResult(ctx context.Context, statsHandle *handle.Handle, resultsCh chan<- *statistics.AnalyzeResults, result *statistics.AnalyzeResults) {
+	select {
+	case resultsCh <- result:
+		return
+	case <-ctx.Done():
+	case <-e.errExitCh:
+	}
+	err := result.Err
+	if err == nil {
+		err = context.Cause(ctx)
+	}
+	if err == nil {
+		err = ctx.Err()
+	}
+	if err == nil {
+		err = exeerrors.ErrQueryInterrupted
+	}
+	finishJobWithLog(statsHandle, result.Job, err)
+}
+
+// ctx must be from AnalyzeExec.buildAnalyzeKillCtx
+func (e *AnalyzeExec) analyzeWorker(ctx context.Context, taskCh <-chan *analyzeTask, resultsCh chan<- *statistics.AnalyzeResults) {
 	var task *analyzeTask
 	statsHandle := domain.GetDomain(e.Ctx()).StatsHandle()
 	defer func() {
@@ -526,23 +634,22 @@ func (e *AnalyzeExec) analyzeWorker(taskCh <-chan *analyzeTask, resultsCh chan<-
 		var ok bool
 		task, ok = <-taskCh
 		if !ok {
-			break
+			return
+		}
+		if err := analyzeWorkerExitErr(ctx, e.errExitCh); err != nil {
+			finishJobWithLog(statsHandle, task.job, err)
+			return
 		}
 		failpoint.Inject("handleAnalyzeWorkerPanic", nil)
-		statsHandle.StartAnalyzeJob(task.job)
 		switch task.taskType {
 		case colTask:
-			select {
-			case <-e.errExitCh:
-				return
-			case resultsCh <- analyzeColumnsPushDownEntry(e.gp, task.colExec):
-			}
+			statsHandle.StartAnalyzeJob(task.job)
+			result := analyzeColumnsPushDownEntry(ctx, e.gp, task.colExec)
+			e.sendAnalyzeResult(ctx, statsHandle, resultsCh, result)
 		case idxTask:
-			select {
-			case <-e.errExitCh:
-				return
-			case resultsCh <- analyzeIndexPushdown(task.idxExec):
-			}
+			statsHandle.StartAnalyzeJob(task.job)
+			result := analyzeIndexPushdown(ctx, task.idxExec)
+			e.sendAnalyzeResult(ctx, statsHandle, resultsCh, result)
 		}
 	}
 }
diff --git a/pkg/executor/analyze_col.go b/pkg/executor/analyze_col.go
index 2ae1ad2764bf6..0a9ee89d53b85 100644
--- a/pkg/executor/analyze_col.go
+++ b/pkg/executor/analyze_col.go
@@ -16,6 +16,7 @@ package executor

 import (
 	"context"
+	stderrors "errors"
 	"fmt"
 	"math"
 	"strings"
@@ -32,14 +33,17 @@ import (
 	"github.com/pingcap/tidb/pkg/planner/core"
 	plannerutil "github.com/pingcap/tidb/pkg/planner/util"
 	"github.com/pingcap/tidb/pkg/statistics"
+	statslogutil "github.com/pingcap/tidb/pkg/statistics/handle/logutil"
 	handleutil "github.com/pingcap/tidb/pkg/statistics/handle/util"
 	"github.com/pingcap/tidb/pkg/tablecodec"
 	"github.com/pingcap/tidb/pkg/types"
 	"github.com/pingcap/tidb/pkg/util"
+	"github.com/pingcap/tidb/pkg/util/intest"
 	"github.com/pingcap/tidb/pkg/util/memory"
 	"github.com/pingcap/tidb/pkg/util/ranger"
 	"github.com/pingcap/tipb/go-tipb"
 	"github.com/tiancaiamao/gp"
+	"go.uber.org/zap"
 )

 // AnalyzeColumnsExec represents Analyze columns push down executor.
@@ -64,11 +68,31 @@ type AnalyzeColumnsExec struct {
 	memTracker *memory.Tracker
 }

-func analyzeColumnsPushDownEntry(gp *gp.Pool, e *AnalyzeColumnsExec) *statistics.AnalyzeResults {
+func analyzeColumnsPushDownEntry(ctx context.Context, gp *gp.Pool, e *AnalyzeColumnsExec) *statistics.AnalyzeResults {
 	if e.AnalyzeInfo.StatsVersion >= statistics.Version2 {
-		return e.toV2().analyzeColumnsPushDownV2(gp)
+		res := e.toV2().analyzeColumnsPushDownV2(ctx, gp)
+		e.logAnalyzeCanceledInTest(ctx, res.Err, "analyze columns canceled")
+		return res
 	}
-	return e.toV1().analyzeColumnsPushDownV1()
+	res := e.toV1().analyzeColumnsPushDownV1(ctx)
+	e.logAnalyzeCanceledInTest(ctx, res.Err, "analyze columns canceled")
+	return res
+}
+
+func (e *AnalyzeColumnsExec) logAnalyzeCanceledInTest(ctx context.Context, err error, msg string) {
+	if !intest.InTest || err == nil || !stderrors.Is(err, context.Canceled) {
+		return
+	}
+	cause := context.Cause(ctx)
+	ctxErr := ctx.Err()
+	statslogutil.StatsLogger().Info(msg,
+		zap.Uint32("killSignal", e.ctx.GetSessionVars().SQLKiller.GetKillSignal()),
+		zap.Uint64("connID", e.ctx.GetSessionVars().ConnectionID),
+		zap.Error(err),
+		zap.Error(cause),
+		zap.Error(ctxErr),
+		zap.Stack("stack"),
+	)
 }

 func (e *AnalyzeColumnsExec) toV1() *AnalyzeColumnsExecV1 {
@@ -83,12 +107,12 @@ func (e *AnalyzeColumnsExec) toV2() *AnalyzeColumnsExecV2 {
 	}
 }

-func (e *AnalyzeColumnsExec) open(ranges []*ranger.Range) error {
+func (e *AnalyzeColumnsExec) open(ctx context.Context, ranges []*ranger.Range) error {
 	e.memTracker = memory.NewTracker(int(e.ctx.GetSessionVars().PlanID.Load()), -1)
 	e.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker)
 	e.resultHandler = &tableResultHandler{}
 	firstPartRanges, secondPartRanges := distsql.SplitRangesAcrossInt64Boundary(ranges, true, false, !hasPkHist(e.handleCols))
-	firstResult, err := e.buildResp(firstPartRanges)
+	firstResult, err := e.buildResp(ctx, firstPartRanges)
 	if err != nil {
 		return err
 	}
@@ -97,7 +121,7 @@ func (e *AnalyzeColumnsExec) open(ranges []*ranger.Range) error {
 		return nil
 	}
 	var secondResult distsql.SelectResult
-	secondResult, err = e.buildResp(secondPartRanges)
+	secondResult, err = e.buildResp(ctx, secondPartRanges)
 	if err != nil {
 		return err
 	}
@@ -106,7 +130,7 @@ func (e *AnalyzeColumnsExec) open(ranges []*ranger.Range) error {
 	return nil
 }

-func (e *AnalyzeColumnsExec) buildResp(ranges []*ranger.Range) (distsql.SelectResult, error) {
+func (e *AnalyzeColumnsExec) buildResp(ctx context.Context, ranges []*ranger.Range) (distsql.SelectResult, error) {
 	var builder distsql.RequestBuilder
 	reqBuilder := builder.SetHandleRangesForTables(e.ctx.GetDistSQLCtx(), []int64{e.TableID.GetStatisticsID()}, e.handleCols != nil && !e.handleCols.IsInt(), ranges)
 	builder.SetResourceGroupTagger(e.ctx.GetSessionVars().StmtCtx.GetResourceGroupTagger())
@@ -130,16 +154,16 @@ func (e *AnalyzeColumnsExec) buildResp(ranges []*ranger.Range) (distsql.SelectRe
 	if err != nil {
 		return nil, err
 	}
-	ctx := context.TODO()
 	result, err := distsql.Analyze(ctx, e.ctx.GetClient(), kvReq, e.ctx.GetSessionVars().KVVars, e.ctx.GetSessionVars().InRestrictedSQL, e.ctx.GetDistSQLCtx())
 	if err != nil {
+		e.logAnalyzeCanceledInTest(ctx, err, "analyze columns distsql canceled")
 		return nil, err
 	}
 	return result, nil
 }

-func (e *AnalyzeColumnsExec) buildStats(ranges []*ranger.Range, needExtStats bool) (hists []*statistics.Histogram, cms []*statistics.CMSketch, topNs []*statistics.TopN, fms []*statistics.FMSketch, extStats *statistics.ExtendedStatsColl, err error) {
-	if err = e.open(ranges); err != nil {
+func (e *AnalyzeColumnsExec) buildStats(ctx context.Context, ranges []*ranger.Range, needExtStats bool) (hists []*statistics.Histogram, cms []*statistics.CMSketch, topNs []*statistics.TopN, fms []*statistics.FMSketch, extStats *statistics.ExtendedStatsColl, err error) {
+	if err = e.open(ctx, ranges); err != nil {
 		return nil, nil, nil, nil, nil, err
 	}
 	defer func() {
@@ -186,11 +210,19 @@ func (e *AnalyzeColumnsExec) buildStats(ranges []*ranger.Range, needExtStats boo
 			return nil, nil, nil, nil, nil, err
 		}
 		failpoint.Inject("mockSlowAnalyzeV1", func() {
-			time.Sleep(1000 * time.Second)
+			select {
+			case <-ctx.Done():
+				err := context.Cause(ctx)
+				if err == nil {
+					err = ctx.Err()
+				}
+				failpoint.Return(nil, nil, nil, nil, nil, err)
+			case <-time.After(1000 * time.Second):
+			}
 		})
-		data, err1 := e.resultHandler.nextRaw(context.TODO())
+		data, err1 := e.resultHandler.nextRaw(ctx)
 		if err1 != nil {
-			return nil, nil, nil, nil, nil, err1
+			return nil, nil, nil, nil, nil, normalizeCtxErrWithCause(ctx, err1)
 		}
 		if data == nil {
 			break
@@ -310,7 +342,7 @@ type AnalyzeColumnsExecV1 struct {
 	*AnalyzeColumnsExec
 }

-func (e *AnalyzeColumnsExecV1) analyzeColumnsPushDownV1() *statistics.AnalyzeResults {
+func (e *AnalyzeColumnsExecV1) analyzeColumnsPushDownV1(ctx context.Context) *statistics.AnalyzeResults {
 	var ranges []*ranger.Range
 	if hc := e.handleCols; hc != nil {
 		if hc.IsInt() {
@@ -322,7 +354,7 @@ func (e *AnalyzeColumnsExecV1) analyzeColumnsPushDownV1() *statistics.AnalyzeRes
 		ranges = ranger.FullIntRange(false)
 	}
 	collExtStats := e.ctx.GetSessionVars().EnableExtendedStats
-	hists, cms, topNs, fms, extStats, err := e.buildStats(ranges, collExtStats)
+	hists, cms, topNs, fms, extStats, err := e.buildStats(ctx, ranges, collExtStats)
 	if err != nil {
 		return &statistics.AnalyzeResults{Err: err, Job: e.job}
 	}
diff --git a/pkg/executor/analyze_col_v2.go b/pkg/executor/analyze_col_v2.go
index a0734e5b5399e..98f22e299b8eb 100644
--- a/pkg/executor/analyze_col_v2.go
+++ b/pkg/executor/analyze_col_v2.go
@@ -54,7 +54,7 @@ type AnalyzeColumnsExecV2 struct {
 	*AnalyzeColumnsExec
 }

-func (e *AnalyzeColumnsExecV2) analyzeColumnsPushDownV2(gp *gp.Pool) *statistics.AnalyzeResults {
+func (e *AnalyzeColumnsExecV2) analyzeColumnsPushDownV2(ctx context.Context, gp *gp.Pool) *statistics.AnalyzeResults {
 	var ranges []*ranger.Range
 	if hc := e.handleCols; hc != nil {
 		if hc.IsInt() {
@@ -94,8 +94,8 @@ func (e *AnalyzeColumnsExecV2) analyzeColumnsPushDownV2(gp *gp.Pool) *statistics
 		return &statistics.AnalyzeResults{Err: err, Job: e.job}
 	}
 	idxNDVPushDownCh := make(chan analyzeIndexNDVTotalResult, 1)
-	e.handleNDVForSpecialIndexes(specialIndexes, idxNDVPushDownCh, samplingStatsConcurrency)
-	count, hists, topNs, fmSketches, extStats, err := e.buildSamplingStats(gp, ranges, collExtStats, specialIndexesOffsets, idxNDVPushDownCh, samplingStatsConcurrency)
+	e.handleNDVForSpecialIndexes(ctx, specialIndexes, idxNDVPushDownCh, samplingStatsConcurrency)
+	count, hists, topNs, fmSketches, extStats, err := e.buildSamplingStats(ctx, gp, ranges, collExtStats, specialIndexesOffsets, idxNDVPushDownCh, samplingStatsConcurrency)
 	if err != nil {
 		e.memTracker.Release(e.memTracker.BytesConsumed())
 		return &statistics.AnalyzeResults{Err: err, Job: e.job}
@@ -200,6 +200,7 @@ func printAnalyzeMergeCollectorLog(oldRootCount, newRootCount, subCount, tableID
 }

 func (e *AnalyzeColumnsExecV2) buildSamplingStats(
+	ctx context.Context,
 	gp *gp.Pool,
 	ranges []*ranger.Range,
 	needExtStats bool,
@@ -215,7 +216,7 @@ func (e *AnalyzeColumnsExecV2) buildSamplingStats(
 	err error,
 ) {
 	// Open memory tracker and resultHandler.
-	if err = e.open(ranges); err != nil {
+	if err = e.open(ctx, ranges); err != nil {
 		return 0, nil, nil, nil, nil, err
 	}
 	defer func() {
@@ -235,6 +236,8 @@ func (e *AnalyzeColumnsExecV2) buildSamplingStats(
 	// Start workers to merge the result from collectors.
 	mergeResultCh := make(chan *samplingMergeResult, 1)
 	mergeTaskCh := make(chan []byte, 1)
+	taskCtx, taskCancel := context.WithCancelCause(ctx)
+	defer taskCancel(nil)
 	var taskEg errgroup.Group
 	// Start read data from resultHandler and send them to mergeTaskCh.
 	taskEg.Go(func() (err error) {
@@ -243,19 +246,23 @@ func (e *AnalyzeColumnsExecV2) buildSamplingStats(
 				err = getAnalyzePanicErr(r)
 			}
 		}()
-		return readDataAndSendTask(e.ctx, e.resultHandler, mergeTaskCh, e.memTracker)
+		err = readDataAndSendTask(taskCtx, e.ctx, e.resultHandler, mergeTaskCh, e.memTracker)
+		if err != nil {
+			taskCancel(err)
+		}
+		return err
 	})
 	e.samplingMergeWg = &util.WaitGroupWrapper{}
 	e.samplingMergeWg.Add(samplingStatsConcurrency)
+	mergeWorkerPanicCnt := 0
+	mergeEg, mergeCtx := errgroup.WithContext(taskCtx)
 	for i := range samplingStatsConcurrency {
 		id := i
 		gp.Go(func() {
-			e.subMergeWorker(mergeResultCh, mergeTaskCh, l, id)
+			e.subMergeWorker(mergeCtx, taskCtx, mergeResultCh, mergeTaskCh, l, id)
 		})
 	}
 	// Merge the result from collectors.
-	mergeWorkerPanicCnt := 0
-	mergeEg, mergeCtx := errgroup.WithContext(context.Background())
 	mergeEg.Go(func() (err error) {
 		defer func() {
 			if r := recover(); r != nil {
@@ -289,15 +296,36 @@ func (e *AnalyzeColumnsExecV2) buildSamplingStats(
 	})
 	err = taskEg.Wait()
 	if err != nil {
-		mergeCtx.Done()
+		err = normalizeCtxErrWithCause(taskCtx, err)
+		if intest.InTest {
+			cause := context.Cause(taskCtx)
+			ctxErr := taskCtx.Err()
+			logutil.BgLogger().Info("analyze columns read task failed",
+				zap.Uint32("killSignal", e.ctx.GetSessionVars().SQLKiller.GetKillSignal()),
+				zap.Uint64("connID", e.ctx.GetSessionVars().ConnectionID),
+				zap.Error(err),
+				zap.Bool("isCtxCanceled", stderrors.Is(err, context.Canceled)),
+				zap.Error(cause),
+				zap.Error(ctxErr),
+				zap.Stack("stack"),
+			)
+		}
 		if err1 := mergeEg.Wait(); err1 != nil {
-			err = stderrors.Join(err, err1)
+			err1 = normalizeCtxErrWithCause(taskCtx, err1)
+			if !stderrors.Is(err1, err) && err1.Error() != err.Error() {
+				err = stderrors.Join(err, err1)
+			}
 		}
 		return 0, nil, nil, nil, nil, getAnalyzePanicErr(err)
 	}
 	err = mergeEg.Wait()
+	if err != nil {
+		err = normalizeCtxErrWithCause(taskCtx, err)
+	}
 	defer e.memTracker.Release(rootRowCollector.Base().MemSize)
 	if err != nil {
+		taskCancel(err)
+		e.logAnalyzeCanceledInTest(mergeCtx, err, "analyze columns merge canceled")
 		return 0, nil, nil, nil, nil, err
 	}
@@ -356,7 +384,7 @@ func (e *AnalyzeColumnsExecV2) buildSamplingStats(
 	// Start workers to build stats.
 	for range samplingStatsConcurrency {
 		e.samplingBuilderWg.Run(func() {
-			e.subBuildWorker(buildResultChan, buildTaskChan, hists, topns, sampleCollectors, exitCh)
+			e.subBuildWorker(ctx, buildResultChan, buildTaskChan, hists, topns, sampleCollectors, exitCh)
 		})
 	}
 	// Generate tasks for building stats.
@@ -435,7 +463,7 @@ func (e *AnalyzeColumnsExecV2) buildSamplingStats(
 }

 // handleNDVForSpecialIndexes deals with the logic to analyze the index containing the virtual column when the mode is full sampling.
-func (e *AnalyzeColumnsExecV2) handleNDVForSpecialIndexes(indexInfos []*model.IndexInfo, totalResultCh chan analyzeIndexNDVTotalResult, samplingStatsConcurrency int) {
+func (e *AnalyzeColumnsExecV2) handleNDVForSpecialIndexes(ctx context.Context, indexInfos []*model.IndexInfo, totalResultCh chan analyzeIndexNDVTotalResult, samplingStatsConcurrency int) {
 	defer func() {
 		if r := recover(); r != nil {
 			logutil.BgLogger().Warn("analyze ndv for special index panicked", zap.Any("recover", r), zap.Stack("stack"))
@@ -447,8 +475,12 @@ func (e *AnalyzeColumnsExecV2) handleNDVForSpecialIndexes(indexInfos []*model.In
 	}()
 	tasks := e.buildSubIndexJobForSpecialIndex(indexInfos)
 	taskCh := make(chan *analyzeTask, len(tasks))
+	pendingJobs := make(map[uint64]*statistics.AnalyzeJob, len(tasks))
 	for _, task := range tasks {
 		AddNewAnalyzeJob(e.ctx, task.job)
+		if task.job != nil && task.job.ID != nil {
+			pendingJobs[*task.job.ID] = task.job
+		}
 	}
 	resultsCh := make(chan *statistics.AnalyzeResults, len(tasks))
 	if len(tasks) < samplingStatsConcurrency {
@@ -457,7 +489,7 @@ func (e *AnalyzeColumnsExecV2) handleNDVForSpecialIndexes(indexInfos []*model.In
 	var subIndexWorkerWg = NewAnalyzeResultsNotifyWaitGroupWrapper(resultsCh)
 	subIndexWorkerWg.Add(samplingStatsConcurrency)
 	for range samplingStatsConcurrency {
-		subIndexWorkerWg.Run(func() { e.subIndexWorkerForNDV(taskCh, resultsCh) })
+		subIndexWorkerWg.Run(func() { e.subIndexWorkerForNDV(ctx, taskCh, resultsCh) })
 	}
 	for _, task := range tasks {
 		taskCh <- task
@@ -469,10 +501,14 @@ func (e *AnalyzeColumnsExecV2) handleNDVForSpecialIndexes(indexInfos []*model.In
 	}
 	var err error
 	statsHandle := domain.GetDomain(e.ctx).StatsHandle()
+LOOP:
 	for panicCnt < samplingStatsConcurrency {
 		results, ok := <-resultsCh
 		if !ok {
-			break
+			break LOOP
+		}
+		if results.Job != nil && results.Job.ID != nil {
+			delete(pendingJobs, *results.Job.ID)
 		}
 		if results.Err != nil {
 			err = results.Err
@@ -480,11 +516,24 @@ func (e *AnalyzeColumnsExecV2) handleNDVForSpecialIndexes(indexInfos []*model.In
 			if isAnalyzeWorkerPanic(err) {
 				panicCnt++
 			}
-			continue
+			continue LOOP
 		}
 		statsHandle.FinishAnalyzeJob(results.Job, nil, statistics.TableAnalysisJob)
 		totalResult.results[results.Ars[0].Hist[0].ID] = results
 	}
+	if err == nil {
+		if ctxErr := ctx.Err(); ctxErr != nil {
+			err = context.Cause(ctx)
+			if err == nil {
+				err = ctxErr
+			}
+		}
+	}
+	if err != nil && len(pendingJobs) > 0 {
+		for _, job := range pendingJobs {
+			statsHandle.FinishAnalyzeJob(job, err, statistics.TableAnalysisJob)
+		}
+	}
 	if err != nil {
 		totalResult.err = err
 	}
@@ -492,7 +541,7 @@ func (e *AnalyzeColumnsExecV2) handleNDVForSpecialIndexes(indexInfos []*model.In
 }

 // subIndexWorker receive the task for each index and return the result for them.
-func (e *AnalyzeColumnsExecV2) subIndexWorkerForNDV(taskCh chan *analyzeTask, resultsCh chan *statistics.AnalyzeResults) {
+func (e *AnalyzeColumnsExecV2) subIndexWorkerForNDV(ctx context.Context, taskCh chan *analyzeTask, resultsCh chan *statistics.AnalyzeResults) {
 	var task *analyzeTask
 	statsHandle := domain.GetDomain(e.ctx).StatsHandle()
 	defer func() {
@@ -507,10 +556,15 @@ func (e *AnalyzeColumnsExecV2) subIndexWorkerForNDV(taskCh chan *analyzeTask, re
 	}()
 	for {
 		var ok bool
-		task, ok = <-taskCh
-		if !ok {
-			break
+		select {
+		case task, ok = <-taskCh:
+			if !ok {
+				return
+			}
+		case <-ctx.Done():
+			return
 		}
+
 		statsHandle.StartAnalyzeJob(task.job)
 		if task.taskType != idxTask {
 			resultsCh <- &statistics.AnalyzeResults{
@@ -520,7 +574,7 @@ func (e *AnalyzeColumnsExecV2) subIndexWorkerForNDV(taskCh chan *analyzeTask, re
 			continue
 		}
 		task.idxExec.job = task.job
-		resultsCh <- analyzeIndexNDVPushDown(task.idxExec)
+		resultsCh <- analyzeIndexNDVPushDown(ctx, task.idxExec)
 	}
 }

@@ -589,7 +643,7 @@ func (e *AnalyzeColumnsExecV2) buildSubIndexJobForSpecialIndex(indexInfos []*mod
 	return tasks
 }

-func (e *AnalyzeColumnsExecV2) subMergeWorker(resultCh chan<- *samplingMergeResult, taskCh <-chan []byte, l int, index int) {
+func (e *AnalyzeColumnsExecV2) subMergeWorker(ctx context.Context, parentCtx context.Context, resultCh chan<- *samplingMergeResult, taskCh <-chan []byte, l int, index int) {
 	// Only close the resultCh in the first worker.
 	closeTheResultCh := index == 0
 	defer func() {
@@ -625,54 +679,102 @@ func (e *AnalyzeColumnsExecV2) subMergeWorker(resultCh chan<- *samplingMergeResu
 	for range l {
 		retCollector.Base().FMSketches = append(retCollector.Base().FMSketches, statistics.NewFMSketch(statistics.MaxSketchSize))
 	}
+	cleanupCollector := func() {
+		// Ensure collector resources are released on early exit paths.
+		retCollector.DestroyAndPutToPool()
+	}
 	statsHandle := domain.GetDomain(e.ctx).StatsHandle()
 	for {
-		data, ok := <-taskCh
-		if !ok {
-			break
-		}
-
-		// Unmarshal the data.
-		dataSize := int64(cap(data))
-		colResp := &tipb.AnalyzeColumnsResp{}
-		err := colResp.Unmarshal(data)
-		if err != nil {
-			resultCh <- &samplingMergeResult{err: err}
+		select {
+		case data, ok := <-taskCh:
+			if !ok {
+				resultCh <- &samplingMergeResult{collector: retCollector}
+				return
+			}
+			// Unmarshal the data.
+			dataSize := int64(cap(data))
+			colResp := &tipb.AnalyzeColumnsResp{}
+			err := colResp.Unmarshal(data)
+			if err != nil {
+				cleanupCollector()
+				resultCh <- &samplingMergeResult{err: err}
+				return
+			}
+			// Consume the memory of the data.
+			colRespSize := int64(colResp.Size())
+			e.memTracker.Consume(colRespSize)
+
+			// Update processed rows.
+			subCollector := statistics.NewRowSampleCollector(int(e.analyzePB.ColReq.SampleSize), e.analyzePB.ColReq.GetSampleRate(), l)
+			subCollector.Base().FromProto(colResp.RowCollector, e.memTracker)
+			statsHandle.UpdateAnalyzeJobProgress(e.job, subCollector.Base().Count)
+
+			// Print collect log.
+			oldRetCollectorSize := retCollector.Base().MemSize
+			oldRetCollectorCount := retCollector.Base().Count
+			retCollector.MergeCollector(subCollector)
+			newRetCollectorCount := retCollector.Base().Count
+			printAnalyzeMergeCollectorLog(oldRetCollectorCount, newRetCollectorCount, subCollector.Base().Count,
+				e.tableID.TableID, e.tableID.PartitionID, e.TableID.IsPartitionTable(),
+				"merge subCollector in concurrency in AnalyzeColumnsExecV2", index)
+
+			// Consume the memory of the result.
+			newRetCollectorSize := retCollector.Base().MemSize
+			subCollectorSize := subCollector.Base().MemSize
+			e.memTracker.Consume(newRetCollectorSize - oldRetCollectorSize - subCollectorSize)
+			e.memTracker.Release(dataSize + colRespSize)
+			subCollector.DestroyAndPutToPool()
+		case <-ctx.Done():
+			err := context.Cause(ctx)
+			if (err == nil || stderrors.Is(err, context.Canceled)) && parentCtx != nil {
+				parentErr := context.Cause(parentCtx)
+				if parentErr != nil {
+					err = parentErr
+				}
+			}
+			if err != nil {
+				e.logAnalyzeCanceledInTest(ctx, err, "analyze columns subMergeWorker canceled")
+				cleanupCollector()
+				resultCh <- &samplingMergeResult{err: err}
+				return
+			}
+			err = ctx.Err()
+			if err != nil {
+				e.logAnalyzeCanceledInTest(ctx, err, "analyze columns subMergeWorker canceled")
+				cleanupCollector()
+				resultCh <- &samplingMergeResult{err: err}
+				return
+			}
+			if intest.InTest {
+				panic("this ctx should be canceled with the error")
+			}
+			cleanupCollector()
+			resultCh <- &samplingMergeResult{err: errors.New("context canceled without error")}
 			return
 		}
-		// Consume the memory of the data.
-		colRespSize := int64(colResp.Size())
-		e.memTracker.Consume(colRespSize)
-
-		// Update processed rows.
-		subCollector := statistics.NewRowSampleCollector(int(e.analyzePB.ColReq.SampleSize), e.analyzePB.ColReq.GetSampleRate(), l)
-		subCollector.Base().FromProto(colResp.RowCollector, e.memTracker)
-		statsHandle.UpdateAnalyzeJobProgress(e.job, subCollector.Base().Count)
-
-		// Print collect log.
-		oldRetCollectorSize := retCollector.Base().MemSize
-		oldRetCollectorCount := retCollector.Base().Count
-		retCollector.MergeCollector(subCollector)
-		newRetCollectorCount := retCollector.Base().Count
-		printAnalyzeMergeCollectorLog(oldRetCollectorCount, newRetCollectorCount, subCollector.Base().Count,
-			e.tableID.TableID, e.tableID.PartitionID, e.TableID.IsPartitionTable(),
-			"merge subCollector in concurrency in AnalyzeColumnsExecV2", index)
-
-		// Consume the memory of the result.
-		newRetCollectorSize := retCollector.Base().MemSize
-		subCollectorSize := subCollector.Base().MemSize
-		e.memTracker.Consume(newRetCollectorSize - oldRetCollectorSize - subCollectorSize)
-		e.memTracker.Release(dataSize + colRespSize)
-		subCollector.DestroyAndPutToPool()
-	}
-
-	resultCh <- &samplingMergeResult{collector: retCollector}
+	}
+}
+
+func (e *AnalyzeColumnsExecV2) logAnalyzeCanceledInTest(ctx context.Context, err error, msg string) {
+	if !intest.InTest || err == nil || !stderrors.Is(err, context.Canceled) {
+		return
+	}
+	cause := context.Cause(ctx)
+	ctxErr := ctx.Err()
+	logutil.BgLogger().Info(msg,
+		zap.Uint32("killSignal", e.ctx.GetSessionVars().SQLKiller.GetKillSignal()),
+		zap.Uint64("connID", e.ctx.GetSessionVars().ConnectionID),
+		zap.Error(err),
+		zap.Error(cause),
+		zap.Error(ctxErr),
+		zap.Stack("stack"),
+	)
 }

-func (e *AnalyzeColumnsExecV2) subBuildWorker(resultCh chan error, taskCh chan *samplingBuildTask, hists []*statistics.Histogram, topns []*statistics.TopN, collectors []*statistics.SampleCollector, exitCh chan struct{}) {
+func (e *AnalyzeColumnsExecV2) subBuildWorker(ctx context.Context, resultCh chan error, taskCh chan *samplingBuildTask, hists []*statistics.Histogram, topns []*statistics.TopN, collectors []*statistics.SampleCollector, exitCh chan struct{}) {
 	defer func() {
 		if r := recover(); r != nil {
-			logutil.BgLogger().Warn("analyze worker panicked", zap.Any("recover", r), zap.Stack("stack"))
+			logutil.BgLogger().Warn("analyze subBuildWorker panicked", zap.Any("recover", r), zap.Stack("stack"))
 			metrics.PanicCounter.WithLabelValues(metrics.LabelAnalyze).Inc()
 			resultCh <- getAnalyzePanicErr(r)
 		}
@@ -841,6 +943,8 @@ workLoop:
 			releaseCollectorMemory()
 		case <-exitCh:
 			return
+		case <-ctx.Done():
+			return
 		}
 	}
 }
@@ -863,25 +967,47 @@ type samplingBuildTask struct {
 	slicePos int
 }

-func readDataAndSendTask(ctx sessionctx.Context, handler *tableResultHandler, mergeTaskCh chan []byte, memTracker *memory.Tracker) error {
+func readDataAndSendTask(ctx context.Context, sctx sessionctx.Context, handler *tableResultHandler, mergeTaskCh chan []byte, memTracker *memory.Tracker) error {
 	// After all tasks are sent, close the mergeTaskCh to notify the mergeWorker that all tasks have been sent.
 	defer close(mergeTaskCh)
 	for {
 		failpoint.Inject("mockKillRunningV2AnalyzeJob", func() {
-			dom := domain.GetDomain(ctx)
+			dom := domain.GetDomain(sctx)
 			for _, id := range handleutil.GlobalAutoAnalyzeProcessList.All() {
 				dom.SysProcTracker().KillSysProcess(id)
 			}
 		})
-		if err := ctx.GetSessionVars().SQLKiller.HandleSignal(); err != nil {
+		if err := sctx.GetSessionVars().SQLKiller.HandleSignal(); err != nil {
 			return err
 		}
 		failpoint.Inject("mockSlowAnalyzeV2", func() {
-			time.Sleep(1000 * time.Second)
+			select {
+			case <-ctx.Done():
+				err := context.Cause(ctx)
+				if err == nil {
+					err = ctx.Err()
+				}
+				failpoint.Return(err)
+			case <-time.After(1000 * time.Second):
+			}
 		})
-		data, err := handler.nextRaw(context.TODO())
+		data, err := handler.nextRaw(ctx)
 		if err != nil {
+			err = normalizeCtxErrWithCause(ctx, err)
+			if intest.InTest {
+				cause := context.Cause(ctx)
+				ctxErr := ctx.Err()
+				logutil.BgLogger().Info("analyze columns nextRaw failed",
+					zap.Uint32("killSignal", sctx.GetSessionVars().SQLKiller.GetKillSignal()),
+					zap.Uint64("connID", sctx.GetSessionVars().ConnectionID),
+					zap.Error(err),
+					zap.Bool("isCtxCanceled", stderrors.Is(err, context.Canceled)),
+					zap.Error(cause),
+					zap.Error(ctxErr),
+					zap.Stack("stack"),
+				)
+			}
 			return errors.Trace(err)
 		}
 		if data == nil {
diff --git a/pkg/executor/analyze_idx.go b/pkg/executor/analyze_idx.go
index 45834a73bb6d7..a4ec54e6a8e97 100644
--- a/pkg/executor/analyze_idx.go
+++ b/pkg/executor/analyze_idx.go
@@ -16,6 +16,7 @@ package executor

 import (
 	"context"
+	stderrors "errors"
 	"math"
 	"time"

@@ -29,8 +30,10 @@ import (
 	"github.com/pingcap/tidb/pkg/parser/mysql"
 	"github.com/pingcap/tidb/pkg/sessionctx"
 	"github.com/pingcap/tidb/pkg/statistics"
+	statslogutil "github.com/pingcap/tidb/pkg/statistics/handle/logutil"
 	handleutil "github.com/pingcap/tidb/pkg/statistics/handle/util"
 	"github.com/pingcap/tidb/pkg/types"
+	"github.com/pingcap/tidb/pkg/util/intest"
 	"github.com/pingcap/tidb/pkg/util/logutil"
 	"github.com/pingcap/tidb/pkg/util/ranger"
 	"github.com/pingcap/tipb/go-tipb"
@@ -47,7 +50,7 @@ type AnalyzeIndexExec struct {
 	countNullRes distsql.SelectResult
 }

-func analyzeIndexPushdown(idxExec *AnalyzeIndexExec) *statistics.AnalyzeResults {
+func analyzeIndexPushdown(ctx context.Context, idxExec *AnalyzeIndexExec) *statistics.AnalyzeResults {
 	ranges := ranger.FullRange()
 	// For single-column index, we do not load null rows from TiKV, so the built histogram would not include
 	// null values, and its `NullCount` would be set by result of another distsql call to get null rows.
@@ -57,8 +60,9 @@ func analyzeIndexPushdown(idxExec *AnalyzeIndexExec) *statistics.AnalyzeResults
 	if len(idxExec.idxInfo.Columns) == 1 {
 		ranges = ranger.FullNotNullRange()
 	}
-	hist, cms, fms, topN, err := idxExec.buildStats(ranges, true)
+	hist, cms, fms, topN, err := idxExec.buildStats(ctx, ranges, true)
 	if err != nil {
+		idxExec.logAnalyzeCanceledInTest(ctx, err, "analyze index canceled")
 		return &statistics.AnalyzeResults{Err: err, Job: idxExec.job}
 	}
 	var statsVer = statistics.Version1
@@ -95,8 +99,8 @@ func analyzeIndexPushdown(idxExec *AnalyzeIndexExec) *statistics.AnalyzeResults
 	return result
 }

-func (e *AnalyzeIndexExec) buildStats(ranges []*ranger.Range, considerNull bool) (hist *statistics.Histogram, cms *statistics.CMSketch, fms *statistics.FMSketch, topN *statistics.TopN, err error) {
-	if err = e.open(ranges, considerNull); err != nil {
+func (e *AnalyzeIndexExec) buildStats(ctx context.Context, ranges []*ranger.Range, considerNull bool) (hist *statistics.Histogram, cms *statistics.CMSketch, fms *statistics.FMSketch, topN *statistics.TopN, err error) {
+	if err = e.open(ctx, ranges, considerNull); err != nil {
 		return nil, nil, nil, nil, err
 	}
 	defer func() {
@@ -105,12 +109,12 @@ func (e *AnalyzeIndexExec) buildStats(ranges []*ranger.Range, considerNull bool)
 			err = err1
 		}
 	}()
-	hist, cms, fms, topN, err = e.buildStatsFromResult(e.result, true)
+	hist, cms, fms, topN, err = e.buildStatsFromResult(ctx, e.result, true)
 	if err != nil {
 		return nil, nil, nil, nil, err
 	}
 	if e.countNullRes != nil {
-		nullHist, _, _, _, err := e.buildStatsFromResult(e.countNullRes, false)
+		nullHist, _, _, _, err := e.buildStatsFromResult(ctx, e.countNullRes, false)
 		if err != nil {
 			return nil, nil, nil, nil, err
 		}
@@ -122,14 +126,14 @@ func (e *AnalyzeIndexExec) buildStats(ranges []*ranger.Range, considerNull bool)
 	return hist, cms, fms, topN, nil
 }

-func (e *AnalyzeIndexExec) open(ranges []*ranger.Range, considerNull bool) error {
-	err := e.fetchAnalyzeResult(ranges, false)
+func (e *AnalyzeIndexExec) open(ctx context.Context, ranges []*ranger.Range, considerNull bool) error {
+	err := e.fetchAnalyzeResult(ctx, ranges, false)
 	if err != nil {
 		return err
 	}
 	if considerNull && len(e.idxInfo.Columns) == 1 {
 		ranges = ranger.NullRange()
-		err = e.fetchAnalyzeResult(ranges, true)
+		err = e.fetchAnalyzeResult(ctx, ranges, true)
 		if err != nil {
 			return err
 		}
@@ -140,7 +144,7 @@ func (e *AnalyzeIndexExec) open(ranges []*ranger.Range, considerNull bool) error
 // fetchAnalyzeResult builds and dispatches the `kv.Request` from given ranges, and stores the `SelectResult`
 // in corresponding fields based on the input `isNullRange` argument, which indicates if the range is the
 // special null range for single-column index to get the null count.
-func (e *AnalyzeIndexExec) fetchAnalyzeResult(ranges []*ranger.Range, isNullRange bool) error {
+func (e *AnalyzeIndexExec) fetchAnalyzeResult(ctx context.Context, ranges []*ranger.Range, isNullRange bool) error {
 	var builder distsql.RequestBuilder
 	var kvReqBuilder *distsql.RequestBuilder
 	if e.isCommonHandle && e.idxInfo.Primary {
@@ -166,9 +170,9 @@ func (e *AnalyzeIndexExec) fetchAnalyzeResult(ranges []*ranger.Range, isNullRang
 	if err != nil {
 		return err
 	}
-	ctx := context.TODO()
 	result, err := distsql.Analyze(ctx, e.ctx.GetClient(), kvReq, e.ctx.GetSessionVars().KVVars, e.ctx.GetSessionVars().InRestrictedSQL, e.ctx.GetDistSQLCtx())
 	if err != nil {
+		e.logAnalyzeCanceledInTest(ctx, err, "analyze index distsql canceled")
 		return err
 	}
 	if isNullRange {
@@ -179,7 +183,7 @@ func (e *AnalyzeIndexExec) fetchAnalyzeResult(ranges []*ranger.Range, isNullRang
 	return nil
 }

-func (e *AnalyzeIndexExec) buildStatsFromResult(result distsql.SelectResult, needCMS bool) (*statistics.Histogram, *statistics.CMSketch, *statistics.FMSketch, *statistics.TopN, error) {
+func (e *AnalyzeIndexExec) buildStatsFromResult(killerCtx context.Context, result distsql.SelectResult, needCMS bool) (*statistics.Histogram, *statistics.CMSketch, *statistics.FMSketch, *statistics.TopN, error) {
 	failpoint.Inject("buildStatsFromResult", func(val failpoint.Value) {
 		if val.(bool) {
 			failpoint.Return(nil, nil, nil, nil, errors.New("mock buildStatsFromResult error"))
@@ -204,14 +208,30 @@ func (e *AnalyzeIndexExec) buildStatsFromResult(result distsql.SelectResult, nee
 				dom.SysProcTracker().KillSysProcess(id)
 			}
 		})
-		if err := e.ctx.GetSessionVars().SQLKiller.HandleSignal(); err != nil {
+		select {
+		case <-killerCtx.Done():
+			err := context.Cause(killerCtx)
+			if err == nil {
+				err = killerCtx.Err()
+			}
 			return nil, nil, nil, nil, err
+		default:
 		}
 		failpoint.Inject("mockSlowAnalyzeIndex", func() {
-			time.Sleep(1000 * time.Second)
+			select {
+			case <-killerCtx.Done():
+				err := context.Cause(killerCtx)
+				if err == nil {
+					err = killerCtx.Err()
+				}
+				failpoint.Return(nil, nil, nil, nil, err)
+			case <-time.After(1000 * time.Second):
+			}
 		})
-		data, err := result.NextRaw(context.TODO())
+		data, err := result.NextRaw(killerCtx)
 		if err != nil {
+			err = normalizeCtxErrWithCause(killerCtx, err)
+			e.logAnalyzeCanceledInTest(killerCtx, err, "analyze index nextRaw canceled")
 			return nil, nil, nil, nil, err
 		}
 		if data == nil {
@@ -240,8 +260,24 @@ func (e *AnalyzeIndexExec) buildStatsFromResult(result distsql.SelectResult, nee
 	return hist, cms, fms, topn, nil
 }

-func (e *AnalyzeIndexExec) buildSimpleStats(ranges []*ranger.Range, considerNull bool) (fms *statistics.FMSketch, nullHist *statistics.Histogram, err error) {
-	if err = e.open(ranges, considerNull); err != nil {
+func (e *AnalyzeIndexExec) logAnalyzeCanceledInTest(ctx context.Context, err error, msg string) {
+	if !intest.InTest || err == nil || !stderrors.Is(err, context.Canceled) {
+		return
+	}
+	cause := context.Cause(ctx)
+	ctxErr := ctx.Err()
+	statslogutil.StatsLogger().Info(msg,
+		zap.Uint32("killSignal", e.ctx.GetSessionVars().SQLKiller.GetKillSignal()),
+		zap.Uint64("connID", e.ctx.GetSessionVars().ConnectionID),
+		zap.Error(err),
+		zap.Error(cause),
+		zap.Error(ctxErr),
+		zap.Stack("stack"),
+	)
+}
+
+func (e *AnalyzeIndexExec) buildSimpleStats(killerCtx context.Context, ranges []*ranger.Range, considerNull bool) (fms *statistics.FMSketch, nullHist *statistics.Histogram, err error) {
+	if err = e.open(killerCtx, ranges, considerNull); err != nil {
 		return nil, nil, err
 	}
 	defer func() {
@@ -250,9 +286,9 @@ func (e *AnalyzeIndexExec) buildSimpleStats(ranges []*ranger.Range, considerNull
 			err = err1
 		}
 	}()
-	_, _, fms, _, err = e.buildStatsFromResult(e.result, false)
+	_, _, fms, _, err = e.buildStatsFromResult(killerCtx, e.result, false)
 	if e.countNullRes != nil {
-		nullHist, _, _, _, err := e.buildStatsFromResult(e.countNullRes, false)
+		nullHist, _, _, _, err := e.buildStatsFromResult(killerCtx, e.countNullRes, false)
 		if err != nil {
 			return nil, nil, err
 		}
@@ -263,7 +299,7 @@ func (e *AnalyzeIndexExec) buildSimpleStats(ranges []*ranger.Range, considerNull
 	return fms, nil, nil
 }

-func analyzeIndexNDVPushDown(idxExec *AnalyzeIndexExec) *statistics.AnalyzeResults {
+func analyzeIndexNDVPushDown(killerCtx context.Context, idxExec *AnalyzeIndexExec) *statistics.AnalyzeResults {
 	ranges := ranger.FullRange()
 	// For single-column index, we do not load null rows from TiKV, so the built histogram would not include
 	// null values, and its `NullCount` would be set by result of another distsql call to get null rows.
@@ -273,7 +309,7 @@ func analyzeIndexNDVPushDown(idxExec *AnalyzeIndexExec) *statistics.AnalyzeResul
 	if len(idxExec.idxInfo.Columns) == 1 {
 		ranges = ranger.FullNotNullRange()
 	}
-	fms, nullHist, err := idxExec.buildSimpleStats(ranges, len(idxExec.idxInfo.Columns) == 1)
+	fms, nullHist, err := idxExec.buildSimpleStats(killerCtx, ranges, len(idxExec.idxInfo.Columns) == 1)
 	if err != nil {
 		return &statistics.AnalyzeResults{Err: err, Job: idxExec.job}
 	}
diff --git a/pkg/executor/analyze_utils.go b/pkg/executor/analyze_utils.go
index b1b0b45ac2854..db62f30dd37cb 100644
--- a/pkg/executor/analyze_utils.go
+++ b/pkg/executor/analyze_utils.go
@@ -16,6 +16,7 @@ package executor

 import (
 	"context"
+	stderrors "errors"
 	"strconv"
 	"sync"

@@ -108,6 +109,18 @@ func getAnalyzePanicErr(r any) error {
 	return errors.Trace(errAnalyzeWorkerPanic)
 }

+func normalizeCtxErrWithCause(ctx context.Context, err error) error {
+	if err == nil {
+		return nil
+	}
+	if stderrors.Is(err, context.Canceled) || stderrors.Is(err, context.DeadlineExceeded) {
+		if cause := context.Cause(ctx); cause != nil {
+			return cause
+		}
+	}
+	return err
+}
+
 // analyzeResultsNotifyWaitGroupWrapper is a wrapper for sync.WaitGroup
 // Please add all goroutine count when to `Add` to avoid exiting in advance.
 type analyzeResultsNotifyWaitGroupWrapper struct {
diff --git a/pkg/executor/table_reader.go b/pkg/executor/table_reader.go
index dc150673f7bf7..d375e79c01776 100644
--- a/pkg/executor/table_reader.go
+++ b/pkg/executor/table_reader.go
@@ -690,7 +690,7 @@ func (tr *tableResultHandler) nextRaw(ctx context.Context) (data []byte, err err
 	if !tr.optionalFinished {
 		data, err = tr.optionalResult.NextRaw(ctx)
 		if err != nil {
-			return nil, err
+			return nil, normalizeCtxErrWithCause(ctx, err)
 		}
 		if data != nil {
 			return data, nil
@@ -699,7 +699,7 @@ func (tr *tableResultHandler) nextRaw(ctx context.Context) (data []byte, err err
 	}
 	data, err = tr.result.NextRaw(ctx)
 	if err != nil {
-		return nil, err
+		return nil, normalizeCtxErrWithCause(ctx, err)
 	}
 	return data, nil
 }
diff --git a/pkg/executor/test/analyzetest/analyze_test.go b/pkg/executor/test/analyzetest/analyze_test.go
index 4dbee17dc9303..1c18f955c0d8c 100644
--- a/pkg/executor/test/analyzetest/analyze_test.go
+++ b/pkg/executor/test/analyzetest/analyze_test.go
@@ -141,6 +141,43 @@ func TestAnalyzeRestrict(t *testing.T) {
 	rs, err := tk.Session().ExecuteInternal(ctx, "analyze table t")
 	require.Nil(t, err)
 	require.Nil(t, rs)
+	tk.MustExec("truncate table mysql.analyze_jobs")
+	t.Run("cancel_on_ctx", func(t *testing.T) {
+		tk.MustExec("drop table if exists t")
+		tk.MustExec("create table t(a int)")
+		tk.MustExec("insert into t values (1), (2)")
+		tk.MustExec("set @@tidb_analyze_version = 2")
+
+		require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/distsql/mockAnalyzeRequestWaitForCancel", "return(true)"))
+		defer func() {
+			require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/distsql/mockAnalyzeRequestWaitForCancel"))
+		}()
+
+		baseCtx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats)
+		ctx, cancel := context.WithCancel(baseCtx)
+		done := make(chan error, 1)
+		go func() {
+			_, err := tk.Session().ExecuteInternal(ctx, "analyze table t")
+			done <- err
+		}()
+
+		select {
+		case err := <-done:
+			t.Fatalf("analyze finished before cancel, err=%v", err)
+		case <-time.After(50 * time.Millisecond):
+		}
+		cancel()
+
+		select {
+		case <-done:
+			rows := tk.MustQuery("select state, fail_reason from mysql.analyze_jobs where table_name = 't' order by end_time desc limit 1").Rows()
+			require.Len(t, rows, 1)
+			require.Equal(t, "failed", strings.ToLower(rows[0][0].(string)))
+			require.Contains(t, rows[0][1].(string), "context canceled")
+		case <-time.After(5 * time.Second):
+			t.Fatal("analyze does not stop after context canceled")
+		}
+	})
 }

 func TestAnalyzeParameters(t *testing.T) {
@@ -1958,6 +1995,7 @@ func testKillAutoAnalyze(t *testing.T, ver int) {
 		}()
 	}
 	require.True(t, h.HandleAutoAnalyze(), comment)
+	require.NoError(t, h.Update(context.Background(), is))
 	currentVersion := h.GetPhysicalTableStats(tableInfo.ID, tableInfo).Version
 	if status == "finished" {
 		// If we kill a finished job, after kill command the status is still finished and the table stats are updated.
@@ -2034,6 +2072,7 @@ func TestKillAutoAnalyzeIndex(t *testing.T) {
 		}()
 	}
 	require.True(t, h.HandleAutoAnalyze(), comment)
+	require.NoError(t, h.Update(context.Background(), is))
 	currentVersion := h.GetPhysicalTableStats(tblInfo.ID, tblInfo).Version
 	if status == "finished" {
 		// If we kill a finished job, after kill command the status is still finished and the index stats are updated.
diff --git a/pkg/executor/test/analyzetest/memorycontrol/memory_control_test.go b/pkg/executor/test/analyzetest/memorycontrol/memory_control_test.go
index 66462343c208d..bb9ff3cd90525 100644
--- a/pkg/executor/test/analyzetest/memorycontrol/memory_control_test.go
+++ b/pkg/executor/test/analyzetest/memorycontrol/memory_control_test.go
@@ -98,7 +98,7 @@ func TestGlobalMemoryControlForPrepareAnalyze(t *testing.T) {
 	require.NoError(t, err0)
 	_, err1 := tk0.Exec(sqlExecute)
 	// Killed and the WarnMsg is WarnMsgSuffixForInstance instead of WarnMsgSuffixForSingleQuery
-	require.True(t, strings.Contains(err1.Error(), "Your query has been cancelled due to exceeding the allowed memory limit for the tidb-server instance and this query is currently using the most memory. Please try narrowing your query scope or increase the tidb_server_memory_limit and try again."))
+	require.True(t, strings.Contains(err1.Error(), "Your query has been cancelled due to exceeding the allowed memory limit for the tidb-server instance and this query is currently using the most memory. Please try narrowing your query scope or increase the tidb_server_memory_limit and try again."), err1.Error())
 	runtime.GC()
 	require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/util/memory/ReadMemStats"))
 	require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/executor/mockAnalyzeMergeWorkerSlowConsume"))
diff --git a/pkg/statistics/handle/autoanalyze/autoanalyze.go b/pkg/statistics/handle/autoanalyze/autoanalyze.go
index 5d92137e24c53..e83fb3132b1aa 100644
--- a/pkg/statistics/handle/autoanalyze/autoanalyze.go
+++ b/pkg/statistics/handle/autoanalyze/autoanalyze.go
@@ -845,34 +845,48 @@ func finishAnalyzeJob(sctx sessionctx.Context, job *statistics.AnalyzeJob, analy
 	}
 	job.EndTime = time.Now()
-	var sql string
-	var args []any
+	setStartTime := false
+	if job.StartTime.IsZero() {
+		// If the job is killed before it starts, ensure start_time is set for display.
+		job.StartTime = job.EndTime
+		setStartTime = true
+	}
 	// process_id is used to see which process is running the analyze job and kill the analyze job. After the analyze job
 	// is finished(or failed), process_id is useless and we set it to NULL to avoid `kill tidb process_id` wrongly.
+	state := statistics.AnalyzeFinished
+	failReason := ""
 	if analyzeErr != nil {
-		failReason := analyzeErr.Error()
+		state = statistics.AnalyzeFailed
+		failReason = analyzeErr.Error()
 		const textMaxLength = 65535
 		if len(failReason) > textMaxLength {
 			failReason = failReason[:textMaxLength]
 		}
+	}

-		if analyzeType == statistics.TableAnalysisJob {
-			sql = "UPDATE mysql.analyze_jobs SET processed_rows = processed_rows + %?, end_time = CONVERT_TZ(%?, '+00:00', @@TIME_ZONE), state = %?, fail_reason = %?, process_id = NULL WHERE id = %?"
-			args = []any{job.Progress.GetDeltaCount(), job.EndTime.UTC().Format(types.TimeFormat), statistics.AnalyzeFailed, failReason, *job.ID}
-		} else {
-			sql = "UPDATE mysql.analyze_jobs SET end_time = CONVERT_TZ(%?, '+00:00', @@TIME_ZONE), state = %?, fail_reason = %?, process_id = NULL WHERE id = %?"
-			args = []any{job.EndTime.UTC().Format(types.TimeFormat), statistics.AnalyzeFailed, failReason, *job.ID}
-		}
-	} else {
-		if analyzeType == statistics.TableAnalysisJob {
-			sql = "UPDATE mysql.analyze_jobs SET processed_rows = processed_rows + %?, end_time = CONVERT_TZ(%?, '+00:00', @@TIME_ZONE), state = %?, process_id = NULL WHERE id = %?"
-			args = []any{job.Progress.GetDeltaCount(), job.EndTime.UTC().Format(types.TimeFormat), statistics.AnalyzeFinished, *job.ID}
-		} else {
-			sql = "UPDATE mysql.analyze_jobs SET end_time = CONVERT_TZ(%?, '+00:00', @@TIME_ZONE), state = %?, process_id = NULL WHERE id = %?"
-			args = []any{job.EndTime.UTC().Format(types.TimeFormat), statistics.AnalyzeFinished, *job.ID}
-		}
+	setClauses := make([]string, 0, 6)
+	args := make([]any, 0, 6)
+	if setStartTime {
+		setClauses = append(setClauses, "start_time = CONVERT_TZ(%?, '+00:00', @@TIME_ZONE)")
+		args = append(args, job.StartTime.UTC().Format(types.TimeFormat))
 	}
+	if analyzeType == statistics.TableAnalysisJob {
+		setClauses = append(setClauses, "processed_rows = processed_rows + %?")
+		args = append(args, job.Progress.GetDeltaCount())
+	}
+	setClauses = append(setClauses, "end_time = CONVERT_TZ(%?, '+00:00', @@TIME_ZONE)")
+	args = append(args, job.EndTime.UTC().Format(types.TimeFormat))
+	setClauses = append(setClauses, "state = %?")
+	args = append(args, state)
+	if analyzeErr != nil {
+		setClauses = append(setClauses, "fail_reason = %?")
+		args = append(args, failReason)
+	}
+	setClauses = append(setClauses, "process_id = NULL")
+
+	sql := fmt.Sprintf("UPDATE mysql.analyze_jobs SET %s WHERE id = %%?", strings.Join(setClauses, ", "))
+	args = append(args, *job.ID)
 	_, _, err := statsutil.ExecRows(sctx, sql, args...)
 	if err != nil {