diff --git a/openapi3.yaml b/openapi3.yaml index 5a1b7953..c1b94729 100644 --- a/openapi3.yaml +++ b/openapi3.yaml @@ -1572,6 +1572,21 @@ paths: application/json: schema: $ref: '#/components/schemas/taskNotFoundResponse' + '409': + description: >- + Race condition detected: task was claimed by another worker. + This occurs when multiple workers attempt to dequeue the same task simultaneously. + The client should retry the dequeue operation to get a different task. + content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/baseErrorResponse' + - type: object + properties: + code: + enum: + - TASK_STATUS_UPDATE_FAILED '500': description: Internal server error or invalid state transition content: @@ -1820,6 +1835,21 @@ paths: application/json: schema: $ref: '#/components/schemas/taskNotFoundResponse' + '409': + description: >- + Race condition detected: task status was modified by another request. + This occurs when multiple workers attempt to update the same task simultaneously. + The current state of the task has changed since it was retrieved. 
+ content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/baseErrorResponse' + - type: object + properties: + code: + enum: + - TASK_STATUS_UPDATE_FAILED '500': description: Internal server error content: diff --git a/src/api/v1/tasks/controller.ts b/src/api/v1/tasks/controller.ts index c170808f..17e5a933 100644 --- a/src/api/v1/tasks/controller.ts +++ b/src/api/v1/tasks/controller.ts @@ -117,9 +117,11 @@ export class TaskControllerV1 { } catch (err) { if (err instanceof TaskNotFoundError) { (err as HttpError).status = httpStatus.NOT_FOUND; + } else if (err instanceof TaskStatusUpdateFailedError) { + // Race condition: resource was modified by another request + (err as HttpError).status = httpStatus.CONFLICT; } else if (badRequestErrors.some((e) => err instanceof e)) { (err as HttpError).status = httpStatus.BAD_REQUEST; - this.logger.error({ msg: `Task status update failed: invalid status transition`, status: req.body.status, err }); } return next(err); @@ -134,6 +136,9 @@ export class TaskControllerV1 { } catch (err) { if (err instanceof TaskNotFoundError) { (err as HttpError).status = httpStatus.NOT_FOUND; + } else if (err instanceof TaskStatusUpdateFailedError) { + // Race condition: another worker already dequeued this task + (err as HttpError).status = httpStatus.CONFLICT; } else if (internalErrors.some((e) => err instanceof e)) { (err as HttpError).status = httpStatus.INTERNAL_SERVER_ERROR; } diff --git a/src/common/constants.ts b/src/common/constants.ts index 0af9d8f5..5ff67efd 100644 --- a/src/common/constants.ts +++ b/src/common/constants.ts @@ -10,6 +10,7 @@ type SuccessMessagesObj = { export const SERVICE_NAME = readPackageJsonSync().name ?? 
'unknown_service'; export const DEFAULT_SERVER_PORT = 80; export const DB_CONNECTION_TIMEOUT = 5000; +export const TX_TIMEOUT_MS = 15000; export const NODE_VERSION = process.versions.node; export const IGNORED_OUTGOING_TRACE_ROUTES = [/^.*\/v1\/metrics.*$/]; diff --git a/src/db/createConnection.ts b/src/db/createConnection.ts index 7f06ceb8..21cce3d0 100644 --- a/src/db/createConnection.ts +++ b/src/db/createConnection.ts @@ -3,6 +3,7 @@ import { hostname } from 'node:os'; import { commonDbFullV1Type } from '@map-colonies/schemas'; import type { PoolConfig } from 'pg'; import { PrismaPg } from '@prisma/adapter-pg'; +import { TX_TIMEOUT_MS } from '@src/common/constants'; import { PrismaClient } from '../db/prisma/generated/client'; interface SchemaExistsResult { @@ -37,7 +38,12 @@ export const createConnectionOptions = (dbConfig: DbConfig): PoolConfig => { // eslint-disable-next-line @typescript-eslint/explicit-function-return-type export function createPrismaClient(poolConfig: PoolConfig, schema: string) { const adapter = new PrismaPg(poolConfig, { schema }); - const prisma = new PrismaClient({ adapter }).$extends({ + const prisma = new PrismaClient({ + adapter, + transactionOptions: { + timeout: TX_TIMEOUT_MS, + }, + }).$extends({ query: { // eslint-disable-next-line @typescript-eslint/promise-function-async $allOperations({ args, query }) { diff --git a/src/openapi.d.ts b/src/openapi.d.ts index afe23ccc..d881c319 100644 --- a/src/openapi.d.ts +++ b/src/openapi.d.ts @@ -1839,6 +1839,18 @@ export interface operations { 'application/json': components['schemas']['taskNotFoundResponse']; }; }; + /** @description Race condition detected: task was claimed by another worker. This occurs when multiple workers attempt to dequeue the same task simultaneously. The client should retry the dequeue operation to get a different task. 
*/ + 409: { + headers: { + [name: string]: unknown; + }; + content: { + 'application/json': components['schemas']['baseErrorResponse'] & { + /** @enum {unknown} */ + code?: 'TASK_STATUS_UPDATE_FAILED'; + }; + }; + }; /** @description Internal server error or invalid state transition */ 500: { headers: { @@ -2065,6 +2077,18 @@ export interface operations { 'application/json': components['schemas']['taskNotFoundResponse']; }; }; + /** @description Race condition detected: task status was modified by another request. This occurs when multiple workers attempt to update the same task simultaneously. The current state of the task has changed since it was retrieved. */ + 409: { + headers: { + [name: string]: unknown; + }; + content: { + 'application/json': components['schemas']['baseErrorResponse'] & { + /** @enum {unknown} */ + code?: 'TASK_STATUS_UPDATE_FAILED'; + }; + }; + }; /** @description Internal server error */ 500: { headers: { diff --git a/src/stages/models/manager.ts b/src/stages/models/manager.ts index 7a2a3aac..8defa689 100644 --- a/src/stages/models/manager.ts +++ b/src/stages/models/manager.ts @@ -7,7 +7,7 @@ import { INFRA_CONVENTIONS } from '@map-colonies/semantic-conventions'; import type { PrismaClient } from '@prismaClient'; import { JobOperationStatus, Prisma, StageOperationStatus } from '@prismaClient'; import { JobManager } from '@src/jobs/models/manager'; -import { SERVICES, XSTATE_DONE_STATE } from '@common/constants'; +import { SERVICES, TX_TIMEOUT_MS, XSTATE_DONE_STATE } from '@common/constants'; import { resolveTraceContext } from '@src/common/utils/tracingHelpers'; import { jobStateMachine } from '@src/jobs/models/jobStateMachine'; import { illegalStatusTransitionErrorMessage, prismaKnownErrors } from '@src/common/errors'; @@ -242,9 +242,12 @@ export class StageManager { }); if (!tx) { - return this.prisma.$transaction(async (newTx) => { - await this.executeUpdateStatus(stageId, status, newTx); - }); + return this.prisma.$transaction( + async 
(newTx) => { + await this.executeUpdateStatus(stageId, status, newTx); + }, + { timeout: TX_TIMEOUT_MS } + ); } await this.executeUpdateStatus(stageId, status, tx); @@ -304,6 +307,8 @@ export class StageManager { // update stage status if it was initialized by first task // and the stage is not already in progress + // Race condition protection: Only transition if stage is PENDING + // Multiple concurrent tasks may trigger this check simultaneously if (updatedSummary.inProgress > 0 && stage.status === StageOperationStatus.PENDING) { await this.updateStatus(stageId, StageOperationStatus.IN_PROGRESS, tx); trace.getActiveSpan()?.addEvent('Stage set to IN_PROGRESS because first task started', { stageId }); @@ -311,17 +316,31 @@ export class StageManager { } @withSpanAsyncV4 - private async executeUpdateStatus(stageId: string, status: StageOperationStatus, tx: PrismaTransaction): Promise { + private async executeUpdateStatus(stageId: string, targetStatus: StageOperationStatus, tx: PrismaTransaction): Promise { const stage = await this.getStageEntityById(stageId, { includeJob: true, tx }); if (!stage) { throw new StageNotFoundError(stagesErrorMessages.stageNotFound); } + + // Idempotent status update: if already in target status, no-op + // This prevents errors during race conditions where multiple workers + // try to set the same status (e.g., multiple tasks setting stage to IN_PROGRESS) + /* v8 ignore next 4 -- @preserve */ + if (stage.status === targetStatus) { + this.logger.debug({ + msg: 'Stage already in target status, skipping transition', + stageId, + targetStatus, + }); + return; + } + //#region validate status transition rules const previousStageOrder = stage.order - 1; // can't move to PENDING if previous stage is not COMPLETED - if (status === StageOperationStatus.PENDING && previousStageOrder > 0) { + if (targetStatus === StageOperationStatus.PENDING && previousStageOrder > 0) { const previousStage = await tx.stage.findFirst({ where: { jobId: stage.jobId, @@ 
-334,12 +353,12 @@ export class StageManager { } } - const nextStatusChange = OperationStatusMapper[status]; + const nextStatusChange = OperationStatusMapper[targetStatus]; const updateActor = createActor(stageStateMachine, { snapshot: stage.xstate }).start(); const isValidStatus = updateActor.getSnapshot().can({ type: nextStatusChange }); if (!isValidStatus) { - throw new IllegalStageStatusTransitionError(illegalStatusTransitionErrorMessage(stage.status, status)); + throw new IllegalStageStatusTransitionError(illegalStatusTransitionErrorMessage(stage.status, targetStatus)); } //#endregion updateActor.send({ type: nextStatusChange }); @@ -350,7 +369,7 @@ export class StageManager { id: stageId, }, data: { - status, + status: targetStatus, xstate: newPersistedSnapshot, }, }; @@ -360,7 +379,7 @@ export class StageManager { //#region update related entities // Update job completion when a stage is completed // If the stage is marked as completed, and there is a next stage in the job, update the next stage status to PENDING - if (status === StageOperationStatus.COMPLETED) { + if (targetStatus === StageOperationStatus.COMPLETED) { const nextStageOrder = stage.order + 1; const nextStage = await tx.stage.findFirst({ where: { @@ -386,11 +405,11 @@ export class StageManager { } } - if (status === StageOperationStatus.IN_PROGRESS && stage.job.status === JobOperationStatus.PENDING) { + if (targetStatus === StageOperationStatus.IN_PROGRESS && stage.job.status === JobOperationStatus.PENDING) { // Update job status to IN_PROGRESS await this.jobManager.updateStatus(stage.job.id, JobOperationStatus.IN_PROGRESS, tx); trace.getActiveSpan()?.addEvent('Job status set to IN_PROGRESS because first stage is being processed', { jobId: stage.jobId }); - } else if (status === StageOperationStatus.FAILED) { + } else if (targetStatus === StageOperationStatus.FAILED) { // Update job status to FAILED await this.jobManager.updateStatus(stage.jobId, JobOperationStatus.FAILED, tx); 
trace.getActiveSpan()?.addEvent('Job set to FAILED because its stage failed', { jobId: stage.jobId }); diff --git a/src/tasks/DAL/taskRepository.ts b/src/tasks/DAL/taskRepository.ts new file mode 100644 index 00000000..e313a602 --- /dev/null +++ b/src/tasks/DAL/taskRepository.ts @@ -0,0 +1,53 @@ +import { inject, Lifecycle, scoped } from 'tsyringe'; +import { type Logger } from '@map-colonies/js-logger'; +import { PrismaClient, Task } from '@prismaClient'; +import { SERVICES } from '@src/common/constants'; +import type { PrismaTransaction } from '@src/db/types'; +import type { TaskPrismaObject } from '../models/models'; + +@scoped(Lifecycle.ContainerScoped) +export class TaskRepository { + public constructor( + @inject(SERVICES.LOGGER) private readonly logger: Logger, + @inject(SERVICES.PRISMA) private readonly prisma: PrismaClient + ) {} + + /** + * Finds and locks the next available high-priority task for processing. + * Uses a row-level lock with `SKIP LOCKED` to allow multiple concurrent + * workers to claim different tasks without blocking each other. + * @param stageType - The stage category to pull tasks from. + * @param tx - The current database transaction. + * @returns The locked task or null if no eligible tasks are found. 
+ */ + public async findAndLockTaskForDequeue(stageType: string, tx: PrismaTransaction): Promise { + this.logger.debug({ msg: 'Finding task for dequeue', stageType }); + + const tasks = await tx.$queryRaw` + SELECT t.* + FROM "job_manager"."task" t + INNER JOIN "job_manager"."stage" s ON t."stage_id" = s.id + INNER JOIN "job_manager"."job" j ON s."job_id" = j.id + WHERE s.type = ${stageType} + AND t.status IN ('Pending', 'Retried') + AND s.status IN ('Pending', 'In-Progress') + AND j.status IN ('Pending', 'In-Progress') + ORDER BY j.priority ASC + LIMIT 1 + FOR UPDATE OF t SKIP LOCKED + `; + + if (tasks.length === 0) { + return null; + } + + // Note: $queryRaw returns raw database values, not Prisma-mapped values + // We need to re-fetch the task using Prisma to get properly mapped enum values + const rawTask = tasks[0]!; + const task = await tx.task.findUnique({ + where: { id: rawTask.id }, + }); + + return task; + } +} diff --git a/src/tasks/models/manager.ts b/src/tasks/models/manager.ts index a4b78fa9..d147a09c 100644 --- a/src/tasks/models/manager.ts +++ b/src/tasks/models/manager.ts @@ -5,8 +5,8 @@ import { trace, type Tracer } from '@opentelemetry/api'; import { withSpanAsyncV4 } from '@map-colonies/tracing-utils'; import { subMinutes } from 'date-fns'; import { INFRA_CONVENTIONS } from '@map-colonies/semantic-conventions'; -import { JobOperationStatus, Prisma, StageOperationStatus, Task, TaskOperationStatus, type PrismaClient } from '@prismaClient'; -import { SERVICES, XSTATE_DONE_STATE } from '@common/constants'; +import { Prisma, StageOperationStatus, Task, TaskOperationStatus, type PrismaClient } from '@prismaClient'; +import { SERVICES, TX_TIMEOUT_MS, XSTATE_DONE_STATE } from '@common/constants'; import { resolveTraceContext } from '@src/common/utils/tracingHelpers'; import { StageManager } from '@src/stages/models/manager'; import { prismaKnownErrors } from '@src/common/errors'; @@ -24,59 +24,11 @@ import { TaskStatusUpdateFailedError, } from 
'@src/common/generated/errors'; import { ATTR_MESSAGING_DESTINATION_NAME, ATTR_MESSAGING_MESSAGE_ID } from '@src/common/semconv'; +import { TaskRepository } from '../DAL/taskRepository'; import type { TasksFindCriteriaArg, TaskModel, TaskPrismaObject, TaskCreateModel } from './models'; import { errorMessages as tasksErrorMessages } from './errors'; import { convertArrayPrismaTaskToTaskResponse, convertPrismaToTaskResponse } from './helper'; -// eslint-disable-next-line @typescript-eslint/explicit-function-return-type -function generatePrioritizedTaskQuery(stageType: string) { - // Define valid states for filtering - const validTaskStatuses = [TaskOperationStatus.PENDING, TaskOperationStatus.RETRIED]; - const validStageStatuses = [StageOperationStatus.PENDING, StageOperationStatus.IN_PROGRESS]; - const validJobStatuses = [JobOperationStatus.PENDING, JobOperationStatus.IN_PROGRESS]; - - const queryBody = { - where: { - stage: { - type: stageType, - status: { - in: validStageStatuses, - }, - job: { - status: { - in: validJobStatuses, - }, - }, - }, - status: { - in: validTaskStatuses, - }, - }, - include: { - stage: { - include: { - job: { - select: { - priority: true, - id: true, - status: true, - }, - }, - }, - }, - }, - orderBy: { - stage: { - job: { - priority: Prisma.SortOrder.asc, - }, - }, - }, - } satisfies Prisma.TaskFindFirstArgs; - - return queryBody; -} - @injectable() export class TaskManager { public constructor( @@ -84,7 +36,8 @@ export class TaskManager { @inject(SERVICES.PRISMA) private readonly prisma: PrismaClient, @inject(SERVICES.TRACER) public readonly tracer: Tracer, @inject(StageManager) private readonly stageManager: StageManager, - @inject(SERVICES.CONFIG) private readonly config: ConfigType + @inject(SERVICES.CONFIG) private readonly config: ConfigType, + @inject(TaskRepository) private readonly taskRepository: TaskRepository ) {} @withSpanAsyncV4 @@ -253,11 +206,14 @@ export class TaskManager { [INFRA_CONVENTIONS.infra.jobnik.stage.status]: 
status, }); - /* v8 ignore next 6 -- @preserve */ + /* v8 ignore next 8 -- @preserve */ if (!tx) { - return this.prisma.$transaction(async (newTx) => { - return this.executeUpdateStatus(taskId, status, newTx); - }); + return this.prisma.$transaction( + async (newTx) => { + return this.executeUpdateStatus(taskId, status, newTx); + }, + { timeout: TX_TIMEOUT_MS } + ); } /* v8 ignore next -- @preserve */ return this.executeUpdateStatus(taskId, status, tx); @@ -276,11 +232,14 @@ export class TaskManager { [ATTR_MESSAGING_DESTINATION_NAME]: stageType, }); - /* v8 ignore next 5 -- @preserve */ + /* v8 ignore next 7 -- @preserve */ if (tx === undefined) { - return this.prisma.$transaction(async (newTx) => { - return this.executeDequeue(stageType, newTx); - }); + return this.prisma.$transaction( + async (newTx) => { + return this.executeDequeue(stageType, newTx); + }, + { timeout: TX_TIMEOUT_MS } + ); } /* v8 ignore next -- @preserve */ @@ -365,6 +324,8 @@ export class TaskManager { /** * Executes the dequeue operation within a transaction. + * Uses SELECT FOR UPDATE to lock the task row, preventing race conditions + * when multiple workers try to dequeue simultaneously. 
* @param stageType - The type of stage to dequeue a task from * @param tx - The transaction object * @returns The dequeued task @@ -373,11 +334,9 @@ export class TaskManager { private async executeDequeue(stageType: string, tx: PrismaTransaction): Promise { const spanActive = trace.getActiveSpan(); - const queryBody = generatePrioritizedTaskQuery(stageType); - - const task = await tx.task.findFirst(queryBody); + const task = await this.taskRepository.findAndLockTaskForDequeue(stageType, tx); - if (task === null) { + if (!task) { throw new TaskNotFoundError(tasksErrorMessages.taskNotFound); } @@ -426,11 +385,14 @@ export class TaskManager { [INFRA_CONVENTIONS.infra.jobnik.stage.id]: task.stageId, }); - /* v8 ignore next 5 -- @preserve */ + /* v8 ignore next 7 -- @preserve */ if (!tx) { - return this.prisma.$transaction(async (newTx) => { - return this.executeUpdateAndValidateStatus(task, status, newTx); - }); + return this.prisma.$transaction( + async (newTx) => { + return this.executeUpdateAndValidateStatus(task, status, newTx); + }, + { timeout: TX_TIMEOUT_MS } + ); } return this.executeUpdateAndValidateStatus(task, status, tx); @@ -453,6 +415,15 @@ export class TaskManager { const previousStatus = task.status; const { nextStatus, taskDataToUpdate } = this.determineNextStatus(task, status); + this.logger.debug({ + msg: 'Attempting task status update', + taskId: task.id, + stageId: task.stageId, + currentStatus: previousStatus, + requestedStatus: status, + nextStatus, + }); + const newPersistedSnapshot = updateTaskMachineState(nextStatus, task.xstate); const startTime: Date | undefined = nextStatus === TaskOperationStatus.IN_PROGRESS ? 
new Date() : undefined; @@ -461,12 +432,21 @@ export class TaskManager { // Create update query with race condition protection for IN_PROGRESS const updateQueryBody = { - where: this.createUpdateWhereClause(task.id, nextStatus, previousStatus), + where: this.createUpdateWhereClause(task.id, previousStatus), data: { ...taskDataToUpdate, status: nextStatus, xstate: newPersistedSnapshot, startTime, endTime }, }; const updatedTasks = await tx.task.updateManyAndReturn(updateQueryBody); if (updatedTasks[0] === undefined) { + // Race condition detected: another process already modified this task + this.logger.warn({ + msg: 'Task status update failed - race condition detected', + taskId: task.id, + stageId: task.stageId, + attemptedTransition: `${previousStatus} -> ${nextStatus}`, + expectedStatus: previousStatus, + reason: 'Task status was changed by another worker before this update could complete', + }); throw new TaskStatusUpdateFailedError(tasksErrorMessages.taskStatusUpdateFailed); } @@ -528,28 +508,18 @@ export class TaskManager { } /** - * Creates the where clause for task updates, with race condition protection. - * @param taskId - The ID of the task to update - * @param nextStatus - The target status - * @param previousStatus - The current status - * @returns The where clause object for the update query + * Generates the query filter for task updates using optimistic locking. + * + * By including `previousStatus` in the WHERE clause, we ensure that state + * transitions (e.g., PENDING → IN_PROGRESS) only occur if no other worker + * has modified the task in the interim. + * + * @param taskId - The ID of the task to update. + * @param previousStatus - The expected current status to prevent race conditions. + * @returns The filter object for the update query. 
*/ - private createUpdateWhereClause( - taskId: string, - nextStatus: TaskOperationStatus, - previousStatus: TaskOperationStatus - ): { - id: string; - status?: TaskOperationStatus; - } { - const whereClause = { id: taskId }; - - // Add status check to prevent race conditions when setting to IN_PROGRESS - if (nextStatus === TaskOperationStatus.IN_PROGRESS) { - return { ...whereClause, status: previousStatus }; - } - - return whereClause; + private createUpdateWhereClause(taskId: string, previousStatus: TaskOperationStatus): { id: string; status: TaskOperationStatus } { + return { id: taskId, status: previousStatus }; } private async updateStageSummary( diff --git a/tests/integration/tasks/tasks.spec.ts b/tests/integration/tasks/tasks.spec.ts index 9e884ac4..ffcb2240 100644 --- a/tests/integration/tasks/tasks.spec.ts +++ b/tests/integration/tasks/tasks.spec.ts @@ -10,7 +10,7 @@ import type { paths, operations } from '@openapi'; import { JobOperationStatus, Priority, Prisma, StageOperationStatus, TaskOperationStatus, type PrismaClient } from '@prismaClient'; import type { PrismaTransaction } from '@src/db/types'; import { getApp } from '@src/app'; -import { SERVICES } from '@common/constants'; +import { SERVICES, TX_TIMEOUT_MS } from '@common/constants'; import { initConfig } from '@src/common/config'; import { errorMessages as tasksErrorMessages } from '@src/tasks/models/errors'; import { errorMessages as stagesErrorMessages } from '@src/stages/models/errors'; @@ -1455,9 +1455,7 @@ describe('task', function () { const transactionSpy = createProxyMock(prisma, '$transaction'); transactionSpy.mockImplementationOnce(async (callback: (tx: PrismaTransaction) => Promise): Promise => { const mockTx = { - task: { - findFirst: vi.fn().mockRejectedValueOnce(error), - }, + $queryRaw: vi.fn().mockRejectedValueOnce(error), } as unknown as PrismaTransaction; await callback(mockTx); @@ -1480,9 +1478,7 @@ describe('task', function () { const transactionSpy = createProxyMock(prisma, 
'$transaction'); transactionSpy.mockImplementationOnce(async (callback: (tx: PrismaTransaction) => Promise): Promise => { const mockTx = { - task: { - findFirst: vi.fn().mockRejectedValueOnce(error), - }, + $queryRaw: vi.fn().mockRejectedValueOnce(error), } as unknown as PrismaTransaction; await callback(mockTx); @@ -1536,7 +1532,7 @@ describe('task', function () { expect(getJobResponse.body).toHaveProperty('status', JobOperationStatus.PENDING); }); - it('should return 500 and prevent multiple dequeue of the same task', async function () { + it('should prevent multiple dequeue of the same task using database-level locking', async function () { expect.assertions(4); const initialSummary = { ...defaultStatusCounts, pending: 1, total: 1 }; @@ -1555,52 +1551,215 @@ describe('task', function () { const stageId = stage.id; const taskId = tasks[0]!.id; - let continueUpdateFirstTask: (value?: unknown) => void; - let continueUpdateSecondTask: (value?: unknown) => void; - const updateTaskHolderFirst = new Promise((resolve) => { - continueUpdateFirstTask = resolve; + // With FOR UPDATE SKIP LOCKED, concurrent dequeues are handled at database level + // The first transaction locks the row, second transaction skips it and finds no tasks + const dequeueFirstPromise = requestSender.dequeueTaskV1({ + pathParams: { stageType: 'SOME_TEST_TYPE_PREVENT_MULTIPLE_DEQUEUE' }, + }); + const dequeueSecondPromise = requestSender.dequeueTaskV1({ + pathParams: { stageType: 'SOME_TEST_TYPE_PREVENT_MULTIPLE_DEQUEUE' }, + }); + + const [firstResponse, secondResponse] = await Promise.all([dequeueFirstPromise, dequeueSecondPromise]); + + // First call will success and pull task + expect(firstResponse).toSatisfyApiSpec(); + expect(firstResponse).toMatchObject({ + status: StatusCodes.OK, + body: { + id: taskId, + status: TaskOperationStatus.IN_PROGRESS, + stageId: stageId, + }, }); - const updateTaskHolderSecond = new Promise((resolve) => { - continueUpdateSecondTask = resolve; + + // Second call 
will fail with 404 status code because task was locked by first transaction + expect(secondResponse).toSatisfyApiSpec(); + expect(secondResponse).toMatchObject({ + status: StatusCodes.NOT_FOUND, + body: { + message: tasksErrorMessages.taskNotFound, + code: 'TASK_NOT_FOUND', + }, }); - const original = prisma.task.findFirst.bind(prisma.task); - const spy = createProxyMock(prisma.task, 'findFirst'); - spy.mockImplementationOnce(async (...args: Parameters) => { - const res = await original(...args); - await updateTaskHolderFirst; // prevent updating the task until the second dequeue is called + }); + + it( + 'should handle concurrent dequeue and updateStatus operations with race condition protection', + { timeout: TX_TIMEOUT_MS }, + async function () { + expect.assertions(3); + + const initialSummary = { ...defaultStatusCounts, pending: 1, total: 1 }; + + const { tasks } = await createJobnikTree( + prisma, + { status: JobOperationStatus.IN_PROGRESS, xstate: inProgressStageXstatePersistentSnapshot, traceparent: DEFAULT_TRACEPARENT }, + { + status: StageOperationStatus.IN_PROGRESS, + xstate: inProgressStageXstatePersistentSnapshot, + summary: initialSummary, + type: 'SOME_TEST_TYPE_DEQUEUE_UPDATE_RACE', + }, + [{ status: TaskOperationStatus.PENDING, xstate: pendingStageXstatePersistentSnapshot }] + ); + + const taskId = tasks[0]!.id; + + // Test that dequeue and updateStatus handle race conditions properly + // Start both operations - one will succeed, the other should fail with 409 + const dequeuePromise = requestSender.dequeueTaskV1({ + pathParams: { stageType: 'SOME_TEST_TYPE_DEQUEUE_UPDATE_RACE' }, + }); + const updateStatusPromise = requestSender.updateTaskStatusV1({ + pathParams: { taskId }, + requestBody: { status: TaskOperationStatus.COMPLETED }, + }); + + const [dequeueResponse, updateStatusResponse] = await Promise.allSettled([dequeuePromise, updateStatusPromise]); + + // One should succeed, one should fail with conflict + const successCount = 
[dequeueResponse, updateStatusResponse].filter( + (r) => r.status === 'fulfilled' && (r.value.status as StatusCodes) === StatusCodes.OK + ).length; + const conflictCount = [dequeueResponse, updateStatusResponse].filter( + (r) => r.status === 'fulfilled' && (r.value.status as StatusCodes) === StatusCodes.CONFLICT + ).length; + + // Exactly one operation should succeed + expect(successCount).toBe(1); + // The other should get a conflict or one might complete + expect(successCount + conflictCount).toBeGreaterThanOrEqual(1); + + // Verify the task ended up in a valid state + const finalTaskResponse = await requestSender.getTaskByIdV1({ pathParams: { taskId } }); + expect(finalTaskResponse).toMatchObject({ body: { status: TaskOperationStatus.IN_PROGRESS }, status: StatusCodes.OK }); + } + ); + + it('should return 409 CONFLICT when dequeue encounters race condition during task update', async function () { + const initialSummary = { ...defaultStatusCounts, pending: 1, total: 1 }; + + await createJobnikTree( + prisma, + { status: JobOperationStatus.IN_PROGRESS, xstate: inProgressStageXstatePersistentSnapshot, traceparent: DEFAULT_TRACEPARENT }, + { + status: StageOperationStatus.IN_PROGRESS, + xstate: inProgressStageXstatePersistentSnapshot, + summary: initialSummary, + type: 'SOME_TEST_TYPE_DEQUEUE_RACE_CONFLICT', + }, + [{ status: TaskOperationStatus.PENDING, xstate: pendingStageXstatePersistentSnapshot }] + ); + + // Mock updateManyAndReturn to simulate race condition where task was modified between lock and update + const transactionSpy = createProxyMock(prisma, '$transaction'); + transactionSpy.mockImplementationOnce(async (callback: (tx: PrismaTransaction) => Promise): Promise => { + const mockTx = { + ...prisma, + $queryRaw: prisma.$queryRaw.bind(prisma), + task: { + ...prisma.task, + findUnique: prisma.task.findUnique.bind(prisma.task), + updateManyAndReturn: vi.fn().mockResolvedValue([]), // Simulate race condition - no rows updated + }, + } as unknown as 
PrismaTransaction; + + return callback(mockTx); + }); + + const response = await requestSender.dequeueTaskV1({ + pathParams: { stageType: 'SOME_TEST_TYPE_DEQUEUE_RACE_CONFLICT' }, + }); + + expect(response).toSatisfyApiSpec(); + expect(response).toMatchObject({ + status: StatusCodes.CONFLICT, + body: { + message: tasksErrorMessages.taskStatusUpdateFailed, + code: 'TASK_STATUS_UPDATE_FAILED', + }, + }); + }); + + it('should handle multiple concurrent updateStatus operations with race condition protection', async function () { + expect.assertions(4); + const initialSummary = { ...defaultStatusCounts, inProgress: 1, total: 1 }; + + const { tasks } = await createJobnikTree( + prisma, + { status: JobOperationStatus.IN_PROGRESS, xstate: inProgressStageXstatePersistentSnapshot, traceparent: DEFAULT_TRACEPARENT }, + { + status: StageOperationStatus.IN_PROGRESS, + xstate: inProgressStageXstatePersistentSnapshot, + summary: initialSummary, + type: 'SOME_TEST_TYPE_MULTIPLE_UPDATE_RACE', + }, + [{ status: TaskOperationStatus.IN_PROGRESS, xstate: inProgressStageXstatePersistentSnapshot }] + ); + + const taskId = tasks[0]!.id; + let continueFirstUpdate: (value?: unknown) => void; + let continueSecondUpdate: (value?: unknown) => void; + const firstUpdateHolder = new Promise((resolve) => { + continueFirstUpdate = resolve; + }); + const secondUpdateHolder = new Promise((resolve) => { + continueSecondUpdate = resolve; + }); + + // Mock task.findUnique for both update calls + const originalFindUnique = prisma.task.findUnique.bind(prisma.task); + const findUniqueSpy = createProxyMock(prisma.task, 'findUnique'); + + // First update call - pause before updating + findUniqueSpy.mockImplementationOnce(async (...args: Parameters) => { + const res = await originalFindUnique(...args); + await firstUpdateHolder; // Pause first update return res; }); - spy.mockImplementationOnce(async (...args: Parameters) => { - const res = await original(...args); - continueUpdateFirstTask(); // release the 
first dequeue update process - await updateTaskHolderSecond; // prevent updating the task until first dequeue release it (after his updating) + // Second update call - pause before updating + findUniqueSpy.mockImplementationOnce(async (...args: Parameters) => { + const res = await originalFindUnique(...args); + continueFirstUpdate(); // Allow first update to proceed + await secondUpdateHolder; // Pause second update return res; }); - const dequeueFirstPromise = requestSender.dequeueTaskV1({ - pathParams: { stageType: 'SOME_TEST_TYPE_PREVENT_MULTIPLE_DEQUEUE' }, + + // Start both update operations concurrently (simulating 2 workers completing the same task) + const firstUpdatePromise = requestSender.updateTaskStatusV1({ + pathParams: { taskId }, + requestBody: { status: TaskOperationStatus.COMPLETED }, }); - const dequeueSecondPromise = requestSender.dequeueTaskV1({ - pathParams: { stageType: 'SOME_TEST_TYPE_PREVENT_MULTIPLE_DEQUEUE' }, + const secondUpdatePromise = requestSender.updateTaskStatusV1({ + pathParams: { taskId }, + requestBody: { status: TaskOperationStatus.COMPLETED }, }); - const firstResponse = await dequeueFirstPromise; + + // Wait for first update to complete + const firstResponse = await firstUpdatePromise; + + // Allow second update to proceed // @ts-expect-error not recognized initialization - continueUpdateSecondTask(); //release to update second call - const secondResponse = await dequeueSecondPromise; - // first call will success and pull task + continueSecondUpdate(); + const secondResponse = await secondUpdatePromise; + + // First update should succeed - task transitioned from IN_PROGRESS to COMPLETED expect(firstResponse).toSatisfyApiSpec(); expect(firstResponse).toMatchObject({ status: StatusCodes.OK, body: { id: taskId, - status: TaskOperationStatus.IN_PROGRESS, - stageId: stageId, + status: TaskOperationStatus.COMPLETED, }, }); - //second call will fail with 500 status code due to race condition protection + + // Second update should 
fail because task is no longer IN_PROGRESS + // The optimistic locking prevents duplicate completion expect(secondResponse).toSatisfyApiSpec(); expect(secondResponse).toMatchObject({ - status: StatusCodes.INTERNAL_SERVER_ERROR, + status: StatusCodes.CONFLICT, body: { message: tasksErrorMessages.taskStatusUpdateFailed, code: 'TASK_STATUS_UPDATE_FAILED', diff --git a/tests/unit/jobs/jobs.spec.ts b/tests/unit/jobs/jobs.spec.ts index f7f609d1..95fff7fd 100644 --- a/tests/unit/jobs/jobs.spec.ts +++ b/tests/unit/jobs/jobs.spec.ts @@ -1,5 +1,5 @@ -import { describe, beforeEach, afterEach, it, expect, vi } from 'vitest'; -import { jsLogger } from '@map-colonies/js-logger'; +import { describe, beforeEach, afterEach, it, expect, vi, beforeAll } from 'vitest'; +import { jsLogger, Logger } from '@map-colonies/js-logger'; import { trace } from '@opentelemetry/api'; import { mockDeep, type DeepMockProxy } from 'vitest-mock-extended'; import type { PrismaClient } from '@prismaClient'; @@ -18,9 +18,15 @@ const tracer = trace.getTracer(SERVICE_NAME); const jobNotFoundError = new Prisma.PrismaClientKnownRequestError('RECORD_NOT_FOUND', { code: prismaKnownErrors.recordNotFound, clientVersion: '1' }); describe('JobManager', () => { + let logger: Logger; + + beforeAll(function () { + logger = jsLogger({ enabled: false }); + }); + beforeEach(function () { prisma = mockDeep(); - jobManager = new JobManager(jsLogger({ enabled: false }), prisma, tracer); + jobManager = new JobManager(logger, prisma, tracer); }); afterEach(function () { diff --git a/tests/unit/stages/stages.spec.ts b/tests/unit/stages/stages.spec.ts index 051d3f08..ac10923d 100644 --- a/tests/unit/stages/stages.spec.ts +++ b/tests/unit/stages/stages.spec.ts @@ -1,6 +1,6 @@ /* eslint-disable @typescript-eslint/naming-convention */ -import { describe, beforeEach, afterEach, it, expect, vi } from 'vitest'; -import { jsLogger } from '@map-colonies/js-logger'; +import { describe, beforeEach, afterEach, it, expect, vi, beforeAll 
} from 'vitest'; +import { jsLogger, Logger } from '@map-colonies/js-logger'; import { faker } from '@faker-js/faker'; import { trace } from '@opentelemetry/api'; import { mockDeep, type DeepMockProxy } from 'vitest-mock-extended'; @@ -38,11 +38,17 @@ type StageAggregateResult = Prisma.GetStageAggregateType { + let logger: Logger; + + beforeAll(function () { + logger = jsLogger({ enabled: false }); + }); + beforeEach(function () { prisma = mockDeep(); - jobManager = new JobManager(jsLogger({ enabled: false }), prisma, tracer); - stageRepository = new StageRepository(jsLogger({ enabled: false }), prisma); - stageManager = new StageManager(jsLogger({ enabled: false }), prisma, tracer, stageRepository, jobManager); + jobManager = new JobManager(logger, prisma, tracer); + stageRepository = new StageRepository(logger, prisma); + stageManager = new StageManager(logger, prisma, tracer, stageRepository, jobManager); }); afterEach(function () { diff --git a/tests/unit/tasks/taskRepository.spec.ts b/tests/unit/tasks/taskRepository.spec.ts new file mode 100644 index 00000000..ad99eed0 --- /dev/null +++ b/tests/unit/tasks/taskRepository.spec.ts @@ -0,0 +1,134 @@ +/* eslint-disable @typescript-eslint/naming-convention */ +import { describe, beforeEach, it, expect, vi } from 'vitest'; +import { jsLogger } from '@map-colonies/js-logger'; +import { faker } from '@faker-js/faker'; +import { mockDeep, type DeepMockProxy } from 'vitest-mock-extended'; +import type { PrismaClient } from '@prismaClient'; +import { TaskOperationStatus } from '@prismaClient'; +import { TaskRepository } from '@src/tasks/DAL/taskRepository'; +import { createTaskEntity } from '../generator'; + +let taskRepository: TaskRepository; +let prisma: DeepMockProxy; + +describe('TaskRepository', () => { + beforeEach(function () { + prisma = mockDeep(); + taskRepository = new TaskRepository(jsLogger({ enabled: false }), prisma); + }); + + describe('#findAndLockTaskForDequeue', () => { + describe('#HappyPath', () => { 
+ it('should find and lock a task for dequeue', async function () { + const stageType = 'SOME_STAGE_TYPE'; + const taskId = faker.string.uuid(); + const stageId = faker.string.uuid(); + + const rawTaskEntity = { + id: taskId, + stage_id: stageId, + status: TaskOperationStatus.PENDING, + attempts: 0, + max_attempts: 3, + data: {}, + user_metadata: {}, + xstate: {}, + creation_time: new Date(), + update_time: new Date(), + start_time: null, + end_time: null, + traceparent: null, + tracestate: null, + }; + + const taskEntity = createTaskEntity({ + id: taskId, + stageId, + status: TaskOperationStatus.PENDING, + }); + + const mockTx = { + $queryRaw: vi.fn().mockResolvedValue([rawTaskEntity]), + task: { + findUnique: vi.fn().mockResolvedValue(taskEntity), + }, + } as unknown as Parameters[1]; + + const result = await taskRepository.findAndLockTaskForDequeue(stageType, mockTx); + + expect(result).toEqual(taskEntity); + expect(mockTx.$queryRaw).toHaveBeenCalledOnce(); + }); + + it('should return null when no tasks are available', async function () { + const stageType = 'SOME_STAGE_TYPE'; + + const mockTx = { + $queryRaw: vi.fn().mockResolvedValue([]), + } as unknown as Parameters[1]; + + const result = await taskRepository.findAndLockTaskForDequeue(stageType, mockTx); + + expect(result).toBeNull(); + expect(mockTx.$queryRaw).toHaveBeenCalledOnce(); + }); + + it('should return null when task findUnique returns null', async function () { + const stageType = 'SOME_STAGE_TYPE'; + const taskId = faker.string.uuid(); + + const rawTaskEntity = { + id: taskId, + stage_id: faker.string.uuid(), + status: TaskOperationStatus.PENDING, + }; + + const mockTx = { + $queryRaw: vi.fn().mockResolvedValue([rawTaskEntity]), + task: { + findUnique: vi.fn().mockResolvedValue(null), + }, + } as unknown as Parameters[1]; + + const result = await taskRepository.findAndLockTaskForDequeue(stageType, mockTx); + + expect(result).toBeNull(); + }); + }); + + describe('#SadPath', () => { + it('should 
throw error when database query fails', async function () { + const stageType = 'SOME_STAGE_TYPE'; + const error = new Error('Database connection error'); + + const mockTx = { + $queryRaw: vi.fn().mockRejectedValue(error), + } as unknown as Parameters[1]; + + await expect(taskRepository.findAndLockTaskForDequeue(stageType, mockTx)).rejects.toThrow('Database connection error'); + }); + + it('should throw error when findUnique fails', async function () { + const stageType = 'SOME_STAGE_TYPE'; + const taskId = faker.string.uuid(); + + const rawTaskEntity = { + id: taskId, + stage_id: faker.string.uuid(), + status: TaskOperationStatus.PENDING, + }; + + const error = new Error('Database connection error'); + + const mockTx = { + $queryRaw: vi.fn().mockResolvedValue([rawTaskEntity]), + task: { + findUnique: vi.fn().mockRejectedValue(error), + }, + } as unknown as Parameters[1]; + + await expect(taskRepository.findAndLockTaskForDequeue(stageType, mockTx)).rejects.toThrow('Database connection error'); + }); + }); + }); +}); diff --git a/tests/unit/tasks/tasks.spec.ts b/tests/unit/tasks/tasks.spec.ts index 5fb73552..7bcb7274 100644 --- a/tests/unit/tasks/tasks.spec.ts +++ b/tests/unit/tasks/tasks.spec.ts @@ -1,6 +1,6 @@ /* eslint-disable @typescript-eslint/naming-convention */ import { describe, beforeEach, afterEach, it, expect, beforeAll, vi } from 'vitest'; -import { jsLogger } from '@map-colonies/js-logger'; +import { jsLogger, type Logger } from '@map-colonies/js-logger'; import { faker } from '@faker-js/faker'; import { trace } from '@opentelemetry/api'; import { subHours, subMinutes } from 'date-fns'; @@ -15,6 +15,7 @@ import { TaskManager } from '@src/tasks/models/manager'; import { prismaKnownErrors } from '@src/common/errors'; import { TaskCreateModel } from '@src/tasks/models/models'; import { StageRepository } from '@src/stages/DAL/stageRepository'; +import { TaskRepository } from '@src/tasks/DAL/taskRepository'; import { SERVICE_NAME } from 
'@src/common/constants'; import { IllegalTaskStatusTransitionError, NotAllowedToAddTasksToInProgressStageError, StageInFiniteStateError } from '@src/common/generated/errors'; import { getConfig, initConfig } from '@src/common/config'; @@ -41,6 +42,7 @@ let jobManager: JobManager; let stageManager: StageManager; let taskManager: TaskManager; let stageRepository: StageRepository; +let taskRepository: TaskRepository; let prisma: DeepMockProxy; const tracer = trace.getTracer(SERVICE_NAME); @@ -50,17 +52,21 @@ let config: ReturnType; const notFoundError = new Prisma.PrismaClientKnownRequestError('RECORD_NOT_FOUND', { code: prismaKnownErrors.recordNotFound, clientVersion: '1' }); describe('JobManager', () => { + let logger: Logger; + beforeAll(async function () { + logger = jsLogger({ enabled: false }); await initConfig(true); }); beforeEach(function () { config = getConfig(); prisma = mockDeep(); - jobManager = new JobManager(jsLogger({ enabled: false }), prisma, tracer); - stageRepository = new StageRepository(jsLogger({ enabled: false }), prisma); - stageManager = new StageManager(jsLogger({ enabled: false }), prisma, tracer, stageRepository, jobManager); - taskManager = new TaskManager(jsLogger({ enabled: false }), prisma, tracer, stageManager, config); + jobManager = new JobManager(logger, prisma, tracer); + stageRepository = new StageRepository(logger, prisma); + taskRepository = new TaskRepository(logger, prisma); + stageManager = new StageManager(logger, prisma, tracer, stageRepository, jobManager); + taskManager = new TaskManager(logger, prisma, tracer, stageManager, config, taskRepository); }); afterEach(function () { @@ -586,10 +592,12 @@ describe('JobManager', () => { xstate: pendingStageXstatePersistentSnapshot, }); + vi.spyOn(taskRepository, 'findAndLockTaskForDequeue').mockResolvedValue(taskEntity); + prisma.$transaction.mockImplementationOnce(async (callback) => { const mockTx = { task: { - findFirst: vi.fn().mockResolvedValue(taskEntity), + findUnique: 
vi.fn().mockResolvedValue(taskEntity), updateManyAndReturn: vi.fn().mockResolvedValue([taskEntity]), }, stage: { @@ -608,12 +616,10 @@ describe('JobManager', () => { describe('#BadPath', () => { it('should get code 404 not found for no available tasks to dequeue', async function () { + vi.spyOn(taskRepository, 'findAndLockTaskForDequeue').mockResolvedValue(null); + prisma.$transaction.mockImplementationOnce(async (callback) => { - const mockTx = { - task: { - findFirst: vi.fn().mockResolvedValue(null), - }, - } as unknown as Omit; + const mockTx = {} as unknown as Omit; return callback(mockTx); }); @@ -624,12 +630,10 @@ describe('JobManager', () => { describe('#SadPath', () => { it('should fail with a database error when adding tasks', async function () { + vi.spyOn(taskRepository, 'findAndLockTaskForDequeue').mockRejectedValue(new Error('db connection error')); + prisma.$transaction.mockImplementationOnce(async (callback) => { - const mockTx = { - task: { - findFirst: vi.fn().mockRejectedValue(new Error('db connection error')), - }, - } as unknown as Omit; + const mockTx = {} as unknown as Omit; return callback(mockTx); }); @@ -651,10 +655,12 @@ describe('JobManager', () => { xstate: pendingStageXstatePersistentSnapshot, }); + vi.spyOn(taskRepository, 'findAndLockTaskForDequeue').mockResolvedValue(taskEntity); + prisma.$transaction.mockImplementationOnce(async (callback) => { const mockTx = { task: { - findFirst: vi.fn().mockResolvedValue(taskEntity), + findUnique: vi.fn().mockResolvedValue(taskEntity), updateManyAndReturn: vi.fn().mockResolvedValue([]), }, } as unknown as Omit; diff --git a/vitest.config.mts b/vitest.config.mts index b74ddba0..64eccc5a 100644 --- a/vitest.config.mts +++ b/vitest.config.mts @@ -28,6 +28,11 @@ export default defineConfig({ setupFiles: ['./tests/configurations/initJestOpenapi.setup.ts', './tests/configurations/vite.setup.ts'], include: ['tests/unit/**/*.spec.ts'], environment: 'node', + server: { + deps: { + external: ['node-cron'], 
+ }, + }, }, resolve: { alias: pathAlias, @@ -40,6 +45,11 @@ export default defineConfig({ setupFiles: ['./tests/configurations/initJestOpenapi.setup.ts', './tests/configurations/vite.setup.ts'], include: ['tests/integration/**/*.spec.ts'], environment: 'node', + server: { + deps: { + external: ['node-cron'], + }, + }, }, resolve: { alias: pathAlias,