diff --git a/.changeset/fifty-pigs-sit.md b/.changeset/fifty-pigs-sit.md new file mode 100644 index 0000000000000..eee4d5aae730d --- /dev/null +++ b/.changeset/fifty-pigs-sit.md @@ -0,0 +1,9 @@ +--- +"@medusajs/workflow-engine-inmemory": patch +"@medusajs/workflow-engine-redis": patch +"@medusajs/orchestration": patch +"@medusajs/workflows-sdk": patch +"@medusajs/utils": patch +--- + +fix: workflow retry interval race condition diff --git a/integration-tests/modules/__tests__/workflow-engine/tests.ts b/integration-tests/modules/__tests__/workflow-engine/tests.ts index 8855260d15131..12e3c872756ae 100644 --- a/integration-tests/modules/__tests__/workflow-engine/tests.ts +++ b/integration-tests/modules/__tests__/workflow-engine/tests.ts @@ -185,7 +185,7 @@ export const workflowEngineTestSuite = ( hasAsyncSteps: true, hasFailedSteps: false, hasSkippedSteps: false, - hasWaitingSteps: false, + hasWaitingSteps: true, hasRevertedSteps: false, }), context: expect.objectContaining({ @@ -236,6 +236,13 @@ export const workflowEngineTestSuite = ( workflow_id: "my-workflow-name", transaction_id: "trx_123", state: "done", + execution: expect.objectContaining({ + hasAsyncSteps: true, + hasFailedSteps: false, + hasSkippedSteps: false, + hasWaitingSteps: false, + hasRevertedSteps: false, + }), context: expect.objectContaining({ data: expect.objectContaining({ invoke: expect.objectContaining({ diff --git a/packages/core/orchestration/src/transaction/distributed-transaction.ts b/packages/core/orchestration/src/transaction/distributed-transaction.ts index afa5d280eb2f7..e312aff39f3e7 100644 --- a/packages/core/orchestration/src/transaction/distributed-transaction.ts +++ b/packages/core/orchestration/src/transaction/distributed-transaction.ts @@ -2,6 +2,7 @@ import { isDefined } from "@medusajs/utils" import { EventEmitter } from "events" import { IDistributedTransactionStorage } from "./datastore/abstract-storage" import { BaseInMemoryDistributedTransactionStorage } from "./datastore/base-in-memory-storage" +import { NonSerializableCheckPointError } from "./errors" import { TransactionOrchestrator } from "./transaction-orchestrator" import { TransactionStep, TransactionStepHandler } from "./transaction-step" import { @@ -9,7 +10,6 @@ import { TransactionHandlerType, TransactionState, } from "./types" -import { NonSerializableCheckPointError } from "./errors" /** * @typedef TransactionMetadata @@ -229,6 +229,7 @@ class DistributedTransaction extends EventEmitter { ) const options = TransactionOrchestrator.getWorkflowOptions(modelId) + const loadedData = await DistributedTransaction.keyValueStore.get( key, options @@ -248,7 +249,6 @@ class DistributedTransaction extends EventEmitter { return } - await this.saveCheckpoint() await DistributedTransaction.keyValueStore.scheduleRetry( this, step, @@ -267,7 +267,6 @@ class DistributedTransaction extends EventEmitter { return } - await this.saveCheckpoint() await DistributedTransaction.keyValueStore.scheduleTransactionTimeout( this, Date.now(), diff --git a/packages/core/orchestration/src/transaction/errors.ts b/packages/core/orchestration/src/transaction/errors.ts index 8a3ff7bb7f265..081f067db5eae 100644 --- a/packages/core/orchestration/src/transaction/errors.ts +++ b/packages/core/orchestration/src/transaction/errors.ts @@ -84,3 +84,12 @@ export class NonSerializableCheckPointError extends Error { this.name = "NonSerializableCheckPointError" } } + +export class SkipExecutionError extends Error { + static isSkipExecutionError(error: Error): error is SkipExecutionError { + return ( + error instanceof SkipExecutionError || + error?.name === "SkipExecutionError" + ) + } +} diff --git a/packages/core/orchestration/src/transaction/transaction-orchestrator.ts b/packages/core/orchestration/src/transaction/transaction-orchestrator.ts index 7bb4ebd241cfb..a2d70e9889e4d 100644 --- a/packages/core/orchestration/src/transaction/transaction-orchestrator.ts +++ b/packages/core/orchestration/src/transaction/transaction-orchestrator.ts @@ -28,6 +28,7 @@ import { import { EventEmitter } from "events" import { PermanentStepFailureError, + SkipExecutionError, SkipStepResponse, TransactionStepTimeoutError, TransactionTimeoutError, @@ -54,7 +55,7 @@ export class TransactionOrchestrator extends EventEmitter { } = {} public static getWorkflowOptions(modelId: string): TransactionOptions { - return this.workflowOptions[modelId] + return TransactionOrchestrator.workflowOptions[modelId] } /** @@ -239,6 +240,7 @@ export class TransactionOrchestrator extends EventEmitter { ) { const flow = transaction.getFlow() let hasTimedOut = false + if (!flow.timedOutAt && this.hasExpired({ transaction }, Date.now())) { flow.timedOutAt = Date.now() @@ -252,8 +254,6 @@ export class TransactionOrchestrator extends EventEmitter { ) } - await transaction.saveCheckpoint() - this.emit(DistributedTransactionEvent.TIMEOUT, { transaction }) hasTimedOut = true @@ -281,8 +281,6 @@ export class TransactionOrchestrator extends EventEmitter { ) hasTimedOut = true - await transaction.saveCheckpoint() - this.emit(DistributedTransactionEvent.TIMEOUT, { transaction }) } return hasTimedOut @@ -457,7 +455,9 @@ export class TransactionOrchestrator extends EventEmitter { transaction: DistributedTransactionType, step: TransactionStep, response: unknown - ): Promise { + ): Promise<{ + stopExecution: boolean + }> { const hasStepTimedOut = step.getStates().state === TransactionStepState.TIMEOUT @@ -471,9 +471,6 @@ export class TransactionOrchestrator extends EventEmitter { ) } - const flow = transaction.getFlow() - const options = TransactionOrchestrator.getWorkflowOptions(flow.modelId) - if (!hasStepTimedOut) { step.changeStatus(TransactionStepStatus.OK) } @@ -484,8 +481,11 @@ export class TransactionOrchestrator extends EventEmitter { step.changeState(TransactionStepState.DONE) } - if (step.definition.async || options?.storeExecution) { + let shouldEmit = true + try { await transaction.saveCheckpoint() + } catch (error) { + shouldEmit = false } const cleaningUp: Promise[] = [] @@ -498,29 +498,41 @@ export class TransactionOrchestrator extends EventEmitter { await promiseAll(cleaningUp) - const eventName = step.isCompensating() - ? DistributedTransactionEvent.COMPENSATE_STEP_SUCCESS - : DistributedTransactionEvent.STEP_SUCCESS - transaction.emit(eventName, { step, transaction }) + if (shouldEmit) { + const eventName = step.isCompensating() + ? DistributedTransactionEvent.COMPENSATE_STEP_SUCCESS + : DistributedTransactionEvent.STEP_SUCCESS + transaction.emit(eventName, { step, transaction }) + } + + return { + stopExecution: !shouldEmit, + } } private static async skipStep( transaction: DistributedTransactionType, step: TransactionStep - ): Promise { + ): Promise<{ + stopExecution: boolean + }> { const hasStepTimedOut = step.getStates().state === TransactionStepState.TIMEOUT - const flow = transaction.getFlow() - const options = TransactionOrchestrator.getWorkflowOptions(flow.modelId) - if (!hasStepTimedOut) { step.changeStatus(TransactionStepStatus.OK) step.changeState(TransactionStepState.SKIPPED) } - if (step.definition.async || options?.storeExecution) { + let shouldEmit = true + try { await transaction.saveCheckpoint() + } catch (error) { + if (SkipExecutionError.isSkipExecutionError(error)) { + shouldEmit = false + } else { + throw error + } } const cleaningUp: Promise[] = [] @@ -533,8 +545,14 @@ export class TransactionOrchestrator extends EventEmitter { await promiseAll(cleaningUp) - const eventName = DistributedTransactionEvent.STEP_SKIPPED - transaction.emit(eventName, { step, transaction }) + if (shouldEmit) { + const eventName = DistributedTransactionEvent.STEP_SKIPPED + transaction.emit(eventName, { step, transaction }) + } + + return { + stopExecution: !shouldEmit, + } } private static async setStepTimeout( @@ -589,7 +607,15 @@ export class TransactionOrchestrator extends EventEmitter { maxRetries: number = TransactionOrchestrator.DEFAULT_RETRIES, isTimeout = false, timeoutError?: TransactionStepTimeoutError | TransactionTimeoutError - ): Promise { + ): Promise<{ + stopExecution: boolean + }> { + if (SkipExecutionError.isSkipExecutionError(error)) { + return { + stopExecution: false, + } + } + step.failures++ if (isErrorLike(error)) { @@ -604,7 +630,6 @@ export class TransactionOrchestrator extends EventEmitter { } const flow = transaction.getFlow() - const options = TransactionOrchestrator.getWorkflowOptions(flow.modelId) const cleaningUp: Promise[] = [] @@ -653,8 +678,15 @@ export class TransactionOrchestrator extends EventEmitter { } } - if (step.definition.async || options?.storeExecution) { + let shouldEmit = true + try { await transaction.saveCheckpoint() + } catch (error) { + if (SkipExecutionError.isSkipExecutionError(error)) { + shouldEmit = false + } else { + throw error + } } if (step.hasRetryScheduled()) { @@ -663,10 +695,16 @@ export class TransactionOrchestrator extends EventEmitter { await promiseAll(cleaningUp) - const eventName = step.isCompensating() - ? DistributedTransactionEvent.COMPENSATE_STEP_FAILURE - : DistributedTransactionEvent.STEP_FAILURE - transaction.emit(eventName, { step, transaction }) + if (shouldEmit) { + const eventName = step.isCompensating() + ? DistributedTransactionEvent.COMPENSATE_STEP_FAILURE + : DistributedTransactionEvent.STEP_FAILURE + transaction.emit(eventName, { step, transaction }) + } + + return { + stopExecution: !shouldEmit, + } } private async executeNext( @@ -680,7 +718,6 @@ export class TransactionOrchestrator extends EventEmitter { } const flow = transaction.getFlow() - const options = TransactionOrchestrator.getWorkflowOptions(flow.modelId) const nextSteps = await this.checkAllSteps(transaction) const execution: Promise[] = [] @@ -699,11 +736,9 @@ export class TransactionOrchestrator extends EventEmitter { } await transaction.saveCheckpoint() - this.emit(DistributedTransactionEvent.FINISH, { transaction }) } - let hasSyncSteps = false for (const step of nextSteps.next) { const curState = step.getStates() const type = step.isCompensating() @@ -783,19 +818,18 @@ export class TransactionOrchestrator extends EventEmitter { ) } - await TransactionOrchestrator.setStepFailure( + const ret = await TransactionOrchestrator.setStepFailure( transaction, step, error, endRetry ? 0 : step.definition.maxRetries ) - if (isAsync) { - await transaction.scheduleRetry( - step, - step.definition.retryInterval ?? 0 - ) + if (isAsync && !ret.stopExecution) { + await transaction.scheduleRetry(step, 0) } + + return ret } const traceData = { @@ -821,8 +855,6 @@ export class TransactionOrchestrator extends EventEmitter { ] as Parameters if (!isAsync) { - hasSyncSteps = true - const stepHandler = async () => { return await transaction.handler(...handlerArgs) } @@ -875,10 +907,13 @@ export class TransactionOrchestrator extends EventEmitter { endRetry: true, response, }) + return } - await setStepFailure(error, { response }) + await setStepFailure(error, { + response, + }) }) ) } else { @@ -933,10 +968,7 @@ export class TransactionOrchestrator extends EventEmitter { } // check nested flow - await transaction.scheduleRetry( - step, - step.definition.retryInterval ?? 0 - ) + await transaction.scheduleRetry(step, 0) }) .catch(async (error) => { const response = error?.getStepResponse?.() @@ -948,18 +980,27 @@ export class TransactionOrchestrator extends EventEmitter { endRetry: true, response, }) + return } - await setStepFailure(error, { response }) + await setStepFailure(error, { + response, + }) }) }) ) } } - if (hasSyncSteps && options?.storeExecution) { + try { await transaction.saveCheckpoint() + } catch (error) { + if (SkipExecutionError.isSkipExecutionError(error)) { + break + } else { + throw error + } } await promiseAll(execution) @@ -993,11 +1034,9 @@ export class TransactionOrchestrator extends EventEmitter { flow.state = TransactionState.INVOKING flow.startedAt = Date.now() - if (this.getOptions().store) { - await transaction.saveCheckpoint( - flow.hasAsyncSteps ? 0 : TransactionOrchestrator.DEFAULT_TTL - ) - } + await transaction.saveCheckpoint( + flow.hasAsyncSteps ? 0 : TransactionOrchestrator.DEFAULT_TTL + ) if (transaction.hasTimeout()) { await transaction.scheduleTransactionTimeout( @@ -1079,7 +1118,6 @@ export class TransactionOrchestrator extends EventEmitter { isIdempotent ) { this.options.store = true - this.options.storeExecution = true } const parsedOptions = { @@ -1272,11 +1310,7 @@ export class TransactionOrchestrator extends EventEmitter { existingTransaction?.context ) - if ( - newTransaction && - this.getOptions().store && - this.getOptions().storeExecution - ) { + if (newTransaction && this.getOptions().store) { await transaction.saveCheckpoint( modelFlow.hasAsyncSteps ? 0 : TransactionOrchestrator.DEFAULT_TTL ) diff --git a/packages/core/orchestration/src/transaction/types.ts b/packages/core/orchestration/src/transaction/types.ts index 11e04aa370613..66c49b5434ebf 100644 --- a/packages/core/orchestration/src/transaction/types.ts +++ b/packages/core/orchestration/src/transaction/types.ts @@ -111,20 +111,21 @@ export type TransactionModelOptions = { /** * If true, the state of the transaction will be persisted. - * + * * Learn more in [this documentation](https://docs.medusajs.com/learn/fundamentals/workflows/store-executions). */ store?: boolean /** * The number of seconds that the workflow execution should be stored in the database. - * + * * Learn more in [this documentation](https://docs.medusajs.com/learn/fundamentals/workflows/store-executions). */ retentionTime?: number /** * If true, the execution details of each step will be stored. + * @deprecated no longer needed. */ storeExecution?: boolean diff --git a/packages/core/utils/src/modules-sdk/medusa-internal-service.ts b/packages/core/utils/src/modules-sdk/medusa-internal-service.ts index 0a3e64f81a66e..4c6be7e88e0c5 100644 --- a/packages/core/utils/src/modules-sdk/medusa-internal-service.ts +++ b/packages/core/utils/src/modules-sdk/medusa-internal-service.ts @@ -137,9 +137,11 @@ export function MedusaInternalService< const idOrObject_ = Array.isArray(idOrObject) ? idOrObject : [idOrObject] - primaryKeysCriteria = idOrObject_.map((primaryKeyValue) => ({ - $and: primaryKeys.map((key) => ({ [key]: primaryKeyValue[key] })), - })) + primaryKeysCriteria = { + $or: idOrObject_.map((primaryKeyValue) => ({ + $and: primaryKeys.map((key) => ({ [key]: primaryKeyValue[key] })), + })), + } } const queryOptions = buildQuery(primaryKeysCriteria, config) @@ -157,6 +159,8 @@ export function MedusaInternalService< ? idOrObject.map((v) => [isString(v) ? v : Object.values(v)].join(", ") ) + : isObject(idOrObject) + ? Object.values(idOrObject).join(", ") : idOrObject } was not found` ) diff --git a/packages/core/workflows-sdk/src/helper/type.ts b/packages/core/workflows-sdk/src/helper/type.ts index 1dc21d97d8694..8de3f53638a7e 100644 --- a/packages/core/workflows-sdk/src/helper/type.ts +++ b/packages/core/workflows-sdk/src/helper/type.ts @@ -31,9 +31,14 @@ export type FlowRegisterStepFailureOptions = response?: TData } -export type FlowCancelOptions = BaseFlowRunOptions & { +export type FlowCancelOptions = { transaction?: DistributedTransactionType transactionId?: string + context?: Context + throwOnError?: boolean + logOnError?: boolean + events?: DistributedTransactionEvents + container?: LoadedModule[] | MedusaContainer } /** diff --git a/packages/core/workflows-sdk/src/utils/composer/create-step.ts b/packages/core/workflows-sdk/src/utils/composer/create-step.ts index bdf58d6c5f463..9fa20f23641ba 100644 --- a/packages/core/workflows-sdk/src/utils/composer/create-step.ts +++ b/packages/core/workflows-sdk/src/utils/composer/create-step.ts @@ -192,8 +192,6 @@ export function applyStep< ret.__step__ = newStepName WorkflowManager.update(this.workflowId, this.flow, this.handlers) - //const confRef = proxify(ret) - if (global[OrchestrationUtils.SymbolMedusaWorkflowComposerCondition]) { const flagSteps = global[OrchestrationUtils.SymbolMedusaWorkflowComposerCondition].steps @@ -334,11 +332,11 @@ function wrapConditionalStep( * createStep, * StepResponse * } from "@medusajs/framework/workflows-sdk" - * + * * interface CreateProductInput { * title: string * } - * + * * export const createProductStep = createStep( * "createProductStep", * async function ( diff --git a/packages/core/workflows-sdk/src/utils/composer/create-workflow.ts b/packages/core/workflows-sdk/src/utils/composer/create-workflow.ts index 58f5134ed3b24..18d9eecf8d748 100644 --- a/packages/core/workflows-sdk/src/utils/composer/create-workflow.ts +++ b/packages/core/workflows-sdk/src/utils/composer/create-workflow.ts @@ -47,22 +47,22 @@ global[OrchestrationUtils.SymbolMedusaWorkflowComposerContext] = null * createProductStep, * getProductStep, * } from "./steps" - * + * * interface WorkflowInput { * title: string * } - * + * * const myWorkflow = createWorkflow( * "my-workflow", * (input: WorkflowInput) => { * // Everything here will be executed and resolved later * // during the execution. Including the data access. - * + * * const product = createProductStep(input) * return new WorkflowResponse(getProductStep(product.id)) * } * ) - * + * * export async function GET( * req: MedusaRequest, * res: MedusaResponse @@ -73,7 +73,7 @@ global[OrchestrationUtils.SymbolMedusaWorkflowComposerContext] = null * title: "Shirt" * } * }) - * + * * res.json({ * product * }) diff --git a/packages/modules/workflow-engine-inmemory/integration-tests/__tests__/index.spec.ts b/packages/modules/workflow-engine-inmemory/integration-tests/__tests__/index.spec.ts index e52ed640d76d1..bee8c94d41108 100644 --- a/packages/modules/workflow-engine-inmemory/integration-tests/__tests__/index.spec.ts +++ b/packages/modules/workflow-engine-inmemory/integration-tests/__tests__/index.spec.ts @@ -30,7 +30,7 @@ import { } from "../__fixtures__/workflow_event_group_id" import { createScheduled } from "../__fixtures__/workflow_scheduled" -jest.setTimeout(100000) +jest.setTimeout(3000000) moduleIntegrationTestRunner({ moduleName: Modules.WORKFLOW_ENGINE, diff --git a/packages/modules/workflow-engine-inmemory/integration-tests/__tests__/race.spec.ts b/packages/modules/workflow-engine-inmemory/integration-tests/__tests__/race.spec.ts new file mode 100644 index 0000000000000..8da509ad8a62c --- /dev/null +++ b/packages/modules/workflow-engine-inmemory/integration-tests/__tests__/race.spec.ts @@ -0,0 +1,182 @@ +import { IWorkflowEngineService } from "@medusajs/framework/types" +import { Modules } from "@medusajs/framework/utils" +import { + createStep, + createWorkflow, + StepResponse, + transform, + WorkflowResponse, +} from "@medusajs/framework/workflows-sdk" +import { moduleIntegrationTestRunner } from "@medusajs/test-utils" +import { setTimeout as setTimeoutSync } from "timers" +import { setTimeout } from "timers/promises" +import "../__fixtures__" + +jest.setTimeout(3000000) + +moduleIntegrationTestRunner({ + moduleName: Modules.WORKFLOW_ENGINE, + resolve: __dirname + "/../..", + testSuite: ({ service: workflowOrcModule, medusaApp }) => { + describe("Testing race condition of the workflow during retry", () => { + it("should prevent race continuation of the workflow during retryIntervalAwaiting in background execution", (done) => { + const step0InvokeMock = jest.fn() + const step1InvokeMock = jest.fn() + const step2InvokeMock = jest.fn() + const transformMock = jest.fn() + + const step0 = createStep("step0", async (_) => { + step0InvokeMock() + return new StepResponse("result from step 0") + }) + + const step1 = createStep("step1", async (_) => { + step1InvokeMock() + await setTimeout(2000) + return new StepResponse({ isSuccess: true }) + }) + + const step2 = createStep("step2", async (input: any) => { + step2InvokeMock() + return new StepResponse({ result: input }) + }) + + const subWorkflow = createWorkflow("sub-workflow-1", function () { + const status = step1() + return new WorkflowResponse(status) + }) + + createWorkflow("workflow-1", function () { + const build = step0() + + const status = subWorkflow.runAsStep({} as any).config({ + async: true, + compensateAsync: true, + backgroundExecution: true, + retryIntervalAwaiting: 1, + }) + + const transformedResult = transform({ status }, (data) => { + transformMock() + return { + status: data.status, + } + }) + + step2(transformedResult) + return new WorkflowResponse(build) + }) + + void workflowOrcModule.subscribe({ + workflowId: "workflow-1", + + subscriber: (event) => { + if (event.eventType === "onFinish") { + expect(step0InvokeMock).toHaveBeenCalledTimes(1) + expect(step1InvokeMock.mock.calls.length).toBeGreaterThan(1) + expect(step2InvokeMock).toHaveBeenCalledTimes(1) + expect(transformMock).toHaveBeenCalledTimes(1) + setTimeoutSync(done, 500) + } + }, + }) + + workflowOrcModule + .run("workflow-1", { throwOnError: false }) + .then(({ result }) => { + expect(result).toBe("result from step 0") + }) + .catch((e) => e) + }) + + it("should prevent race continuation of the workflow compensation during retryIntervalAwaiting in background execution", (done) => { + const workflowId = "RACE_workflow-1" + + const step0InvokeMock = jest.fn() + const step0CompensateMock = jest.fn() + const step1InvokeMock = jest.fn() + const step1CompensateMock = jest.fn() + const step2InvokeMock = jest.fn() + const transformMock = jest.fn() + + const step0 = createStep( + "RACE_step0", + async (_) => { + step0InvokeMock() + return new StepResponse("result from step 0") + }, + () => { + step0CompensateMock() + } + ) + + const step1 = createStep( + "RACE_step1", + async (_) => { + step1InvokeMock() + await setTimeout(300) + throw new Error("error from step 1") + }, + () => { + step1CompensateMock() + } + ) + + const step2 = createStep("RACE_step2", async (input: any) => { + step2InvokeMock() + return new StepResponse({ result: input }) + }) + + const subWorkflow = createWorkflow("RACE_sub-workflow-1", function () { + const status = step1() + return new WorkflowResponse(status) + }) + + createWorkflow(workflowId, function () { + const build = step0() + + const status = subWorkflow.runAsStep({} as any).config({ + async: true, + compensateAsync: true, + backgroundExecution: true, + retryIntervalAwaiting: 0.1, + }) + + const transformedResult = transform({ status }, (data) => { + transformMock() + return { + status: data.status, + } + }) + + step2(transformedResult) + return new WorkflowResponse(build) + }) + + void workflowOrcModule.subscribe({ + workflowId: workflowId, + subscriber: async (event) => { + if (event.eventType === "onFinish") { + expect(step0InvokeMock).toHaveBeenCalledTimes(1) + expect(step0CompensateMock).toHaveBeenCalledTimes(1) + expect(step1InvokeMock.mock.calls.length).toBeGreaterThan(2) + expect(step1CompensateMock).toHaveBeenCalledTimes(1) + expect(step2InvokeMock).toHaveBeenCalledTimes(0) + expect(transformMock).toHaveBeenCalledTimes(0) + setTimeoutSync(done, 500) + } + }, + }) + + workflowOrcModule + .run(workflowId, { + throwOnError: false, + }) + .then(({ result }) => { + expect(result).toBe("result from step 0") + }) + .catch((e) => e) + }) + }) + }, +}) diff --git a/packages/modules/workflow-engine-inmemory/package.json b/packages/modules/workflow-engine-inmemory/package.json index 3a4039af29f1b..ecaf2fb50bbaf 100644 --- a/packages/modules/workflow-engine-inmemory/package.json +++ b/packages/modules/workflow-engine-inmemory/package.json @@ -29,7 +29,7 @@ "resolve:aliases": "tsc --showConfig -p tsconfig.json > tsconfig.resolved.json && tsc-alias -p tsconfig.resolved.json && rimraf tsconfig.resolved.json", "build": "rimraf dist && tsc --build && npm run resolve:aliases", "test": "jest --passWithNoTests --runInBand --bail --forceExit -- src", - "test:integration": "jest --silent --forceExit -- integration-tests/**/__tests__/**/*.ts", + "test:integration": "jest --forceExit -- integration-tests/**/__tests__/**/*.ts", "migration:initial": " MIKRO_ORM_CLI_CONFIG=./mikro-orm.config.dev.ts medusa-mikro-orm migration:create --initial", "migration:create": " MIKRO_ORM_CLI_CONFIG=./mikro-orm.config.dev.ts medusa-mikro-orm migration:create", "migration:up": " MIKRO_ORM_CLI_CONFIG=./mikro-orm.config.dev.ts medusa-mikro-orm migration:up", diff --git a/packages/modules/workflow-engine-inmemory/src/services/workflow-orchestrator.ts b/packages/modules/workflow-engine-inmemory/src/services/workflow-orchestrator.ts index 1d25d222d4564..02586d8a6eeea 100644 --- a/packages/modules/workflow-engine-inmemory/src/services/workflow-orchestrator.ts +++ b/packages/modules/workflow-engine-inmemory/src/services/workflow-orchestrator.ts @@ -6,15 +6,9 @@ import { TransactionStep, WorkflowScheduler, } from "@medusajs/framework/orchestration" +import { ContainerLike, MedusaContainer } from "@medusajs/framework/types" import { - ContainerLike, - Context, - MedusaContainer, -} from "@medusajs/framework/types" -import { - InjectSharedContext, isString, - MedusaContext, MedusaError, TransactionState, } from "@medusajs/framework/utils" @@ -26,6 +20,7 @@ import { } from "@medusajs/framework/workflows-sdk" import { ulid } from "ulid" import { InMemoryDistributedTransactionStorage } from "../utils" +import { WorkflowOrchestratorCancelOptions } from "@types" export type WorkflowOrchestratorRunOptions = Omit< FlowRunOptions, @@ -125,11 +120,9 @@ export class WorkflowOrchestratorService { } } - @InjectSharedContext() async run( workflowIdOrWorkflow: string | ReturnWorkflow, - options?: WorkflowOrchestratorRunOptions, - @MedusaContext() sharedContext: Context = {} + options?: WorkflowOrchestratorRunOptions ) { const { input, @@ -217,12 +210,116 @@ export class WorkflowOrchestratorService { return { acknowledgement, ...ret } } - @InjectSharedContext() + async cancel( + workflowIdOrWorkflow: string | ReturnWorkflow, + options?: WorkflowOrchestratorCancelOptions + ) { + const { + transactionId, + logOnError, + events: eventHandlers, + container, + } = options ?? {} + + let { throwOnError, context } = options ?? {} + + throwOnError ??= true + context ??= {} + + const workflowId = isString(workflowIdOrWorkflow) + ? workflowIdOrWorkflow + : workflowIdOrWorkflow.getName() + + if (!workflowId) { + throw new Error("Workflow ID is required") + } + + if (!transactionId) { + throw new Error("Transaction ID is required") + } + + const events: FlowRunOptions["events"] = this.buildWorkflowEvents({ + customEventHandlers: eventHandlers, + workflowId, + transactionId: transactionId, + }) + + const exportedWorkflow = MedusaWorkflow.getWorkflow(workflowId) + if (!exportedWorkflow) { + throw new Error(`Workflow with id "${workflowId}" not found.`) + } + + const originalOnFinishHandler = events.onFinish! + delete events.onFinish + + const transaction = await this.getRunningTransaction( + workflowId, + transactionId, + options + ) + + if (!transaction) { + if (!throwOnError) { + return { + acknowledgement: { + transactionId, + workflowId, + exists: false, + }, + } + } + throw new Error("Transaction not found") + } + + const ret = await exportedWorkflow.cancel({ + transaction, + throwOnError: false, + logOnError, + context, + events, + container: container ?? this.container_, + }) + + const hasFinished = ret.transaction.hasFinished() + const metadata = ret.transaction.getFlow().metadata + const { parentStepIdempotencyKey } = metadata ?? {} + + const hasFailed = [TransactionState.FAILED].includes( + ret.transaction.getFlow().state + ) + + const acknowledgement = { + transactionId: context.transactionId, + workflowId: workflowId, + parentStepIdempotencyKey, + hasFinished, + hasFailed, + exists: true, + } + + if (hasFinished) { + const { result, errors } = ret + + await originalOnFinishHandler({ + transaction: ret.transaction, + result, + errors, + }) + + await this.triggerParentStep(ret.transaction, result) + } + + if (throwOnError && ret.thrownError) { + throw ret.thrownError + } + + return { acknowledgement, ...ret } + } + async getRunningTransaction( workflowId: string, transactionId: string, - options?: WorkflowOrchestratorRunOptions, - @MedusaContext() sharedContext: Context = {} + options?: WorkflowOrchestratorRunOptions ): Promise { let { context, container } = options ?? {} @@ -251,19 +348,15 @@ export class WorkflowOrchestratorService { return transaction } - @InjectSharedContext() - async setStepSuccess( - { - idempotencyKey, - stepResponse, - options, - }: { - idempotencyKey: string | IdempotencyKeyParts - stepResponse: unknown - options?: RegisterStepSuccessOptions - }, - @MedusaContext() sharedContext: Context = {} - ) { + async setStepSuccess({ + idempotencyKey, + stepResponse, + options, + }: { + idempotencyKey: string | IdempotencyKeyParts + stepResponse: unknown + options?: RegisterStepSuccessOptions + }) { const { context, logOnError, @@ -321,19 +414,15 @@ export class WorkflowOrchestratorService { return ret } - @InjectSharedContext() - async setStepFailure( - { - idempotencyKey, - stepResponse, - options, - }: { - idempotencyKey: string | IdempotencyKeyParts - stepResponse: unknown - options?: RegisterStepSuccessOptions - }, - @MedusaContext() sharedContext: Context = {} - ) { + async setStepFailure({ + idempotencyKey, + stepResponse, + options, + }: { + idempotencyKey: string | IdempotencyKeyParts + stepResponse: unknown + options?: RegisterStepSuccessOptions + }) { const { context, logOnError, @@ -391,11 +480,12 @@ export class WorkflowOrchestratorService { return ret } - @InjectSharedContext() - subscribe( - { workflowId, transactionId, subscriber, subscriberId }: SubscribeOptions, - @MedusaContext() sharedContext: Context = {} - ) { + subscribe({ + workflowId, + transactionId, + subscriber, + subscriberId, + }: SubscribeOptions) { subscriber._id = subscriberId const subscribers = this.subscribers.get(workflowId) ?? new Map() @@ -427,11 +517,11 @@ export class WorkflowOrchestratorService { this.subscribers.set(workflowId, subscribers) } - @InjectSharedContext() - unsubscribe( - { workflowId, transactionId, subscriberOrId }: UnsubscribeOptions, - @MedusaContext() sharedContext: Context = {} - ) { + unsubscribe({ + workflowId, + transactionId, + subscriberOrId, + }: UnsubscribeOptions) { const subscribers = this.subscribers.get(workflowId) ?? new Map() const filterSubscribers = (handlers: SubscriberHandler[]) => { diff --git a/packages/modules/workflow-engine-inmemory/src/services/workflows-module.ts b/packages/modules/workflow-engine-inmemory/src/services/workflows-module.ts index 5d89c45767f1b..07d1c3aa5997b 100644 --- a/packages/modules/workflow-engine-inmemory/src/services/workflows-module.ts +++ b/packages/modules/workflow-engine-inmemory/src/services/workflows-module.ts @@ -62,7 +62,9 @@ export class WorkflowsModuleService< await this.clearExpiredExecutions() this.clearTimeout_ = setInterval(async () => { - await this.clearExpiredExecutions() + try { + await this.clearExpiredExecutions() + } catch {} }, 1000 * 60 * 60) }, onApplicationShutdown: async () => { @@ -80,11 +82,14 @@ export class WorkflowsModuleService< > = {}, @MedusaContext() context: Context = {} ) { + options ??= {} + options.context ??= context + const ret = await this.workflowOrchestratorService_.run< TWorkflow extends ReturnWorkflow ? UnwrapWorkflowInputDataType : unknown - >(workflowIdOrWorkflow, options, context) + >(workflowIdOrWorkflow, options) return ret as any } @@ -115,14 +120,14 @@ export class WorkflowsModuleService< }, @MedusaContext() context: Context = {} ) { - return await this.workflowOrchestratorService_.setStepSuccess( - { - idempotencyKey, - stepResponse, - options, - } as any, - context - ) + options ??= {} + options.context ??= context + + return await this.workflowOrchestratorService_.setStepSuccess({ + idempotencyKey, + stepResponse, + options, + } as any) } @InjectSharedContext() @@ -138,14 +143,14 @@ export class WorkflowsModuleService< }, @MedusaContext() context: Context = {} ) { - return await this.workflowOrchestratorService_.setStepFailure( - { - idempotencyKey, - stepResponse, - options, - } as any, - context - ) + options ??= {} + options.context ??= context + + return await this.workflowOrchestratorService_.setStepFailure({ + idempotencyKey, + stepResponse, + options, + } as any) } @InjectSharedContext() @@ -158,7 +163,7 @@ export class WorkflowsModuleService< }, @MedusaContext() context: Context = {} ) { - return this.workflowOrchestratorService_.subscribe(args as any, context) + return this.workflowOrchestratorService_.subscribe(args as any) } @InjectSharedContext() @@ -170,7 +175,7 @@ export class WorkflowsModuleService< }, @MedusaContext() context: Context = {} ) { - return this.workflowOrchestratorService_.unsubscribe(args as any, context) + return this.workflowOrchestratorService_.unsubscribe(args as any) } private async clearExpiredExecutions() { diff --git a/packages/modules/workflow-engine-inmemory/src/types/index.ts b/packages/modules/workflow-engine-inmemory/src/types/index.ts index e33cf96478e8d..7a7ac40112a40 100644 --- a/packages/modules/workflow-engine-inmemory/src/types/index.ts +++ b/packages/modules/workflow-engine-inmemory/src/types/index.ts @@ -1,5 +1,14 @@ +import { ContainerLike } from "@medusajs/framework" import { Logger } from "@medusajs/framework/types" +import { FlowCancelOptions } from "@medusajs/framework/workflows-sdk" export type InitializeModuleInjectableDependencies = { logger?: Logger } + +export type WorkflowOrchestratorCancelOptions = Omit< + FlowCancelOptions, + "transaction" | "container" +> & { + container?: ContainerLike +} diff --git a/packages/modules/workflow-engine-inmemory/src/utils/workflow-orchestrator-storage.ts b/packages/modules/workflow-engine-inmemory/src/utils/workflow-orchestrator-storage.ts index 0a3248084e1b7..2e8f81d0cfd20 100644 --- a/packages/modules/workflow-engine-inmemory/src/utils/workflow-orchestrator-storage.ts +++ b/packages/modules/workflow-engine-inmemory/src/utils/workflow-orchestrator-storage.ts @@ -3,12 +3,19 @@ import { IDistributedSchedulerStorage, IDistributedTransactionStorage, SchedulerOptions, + SkipExecutionError, TransactionCheckpoint, + TransactionFlow, TransactionOptions, TransactionStep, } from "@medusajs/framework/orchestration" import { Logger, ModulesSdkTypes } from "@medusajs/framework/types" -import { MedusaError, TransactionState } from "@medusajs/framework/utils" +import { + MedusaError, + TransactionState, + TransactionStepState, + isPresent, +} from "@medusajs/framework/utils" import { WorkflowOrchestratorService } from "@services" import { CronExpression, parseExpression } from "cron-parser" @@ -121,8 +128,6 @@ export class InMemoryDistributedTransactionStorage ttl?: number, options?: TransactionOptions ): Promise { - this.storage.set(key, data) - /** * Store the retention time only if the transaction is done, failed or reverted. * From that moment, this tuple can be later on archived or deleted after the retention time. @@ -135,11 +140,16 @@ export class InMemoryDistributedTransactionStorage const { retentionTime, idempotent } = options ?? {} - if (hasFinished) { - Object.assign(data, { - retention_time: retentionTime, - }) - } + await this.#preventRaceConditionExecutionIfNecessary({ + data, + key, + options, + }) + + Object.assign(data, { + retention_time: retentionTime, + }) + this.storage.set(key, data) if (hasFinished && !retentionTime && !idempotent) { await this.deleteFromDb(data) @@ -152,6 +162,118 @@ export class InMemoryDistributedTransactionStorage } } + async #preventRaceConditionExecutionIfNecessary({ + data, + key, + options, + }: { + data: TransactionCheckpoint + key: string + options?: TransactionOptions + }) { + let isInitialCheckpoint = false + + if (data.flow.state === TransactionState.NOT_STARTED) { + isInitialCheckpoint = true + } + + /** + * In case many execution can succeed simultaneously, we need to ensure that the latest + * execution does continue if a previous execution is considered finished + */ + const currentFlow = data.flow + const { flow: latestUpdatedFlow } = + (await this.get(key, options)) ?? + ({ flow: {} } as { flow: TransactionFlow }) + + if (!isInitialCheckpoint && !isPresent(latestUpdatedFlow)) { + /** + * the initial checkpoint expect no other checkpoint to have been stored. + * In case it is not the initial one and another checkpoint is trying to + * find if a concurrent execution has finished, we skip the execution. + * The already finished execution would have deleted the checkpoint already. + */ + throw new SkipExecutionError("Already finished by another execution") + } + + const currentFlowLastInvokingStepIndex = Object.values( + currentFlow.steps + ).findIndex((step) => { + return [ + TransactionStepState.INVOKING, + TransactionStepState.NOT_STARTED, + ].includes(step.invoke?.state) + }) + + const latestUpdatedFlowLastInvokingStepIndex = !latestUpdatedFlow.steps + ? 1 // There is no other execution, so the current execution is the latest + : Object.values( + (latestUpdatedFlow.steps as Record) ?? {} + ).findIndex((step) => { + return [ + TransactionStepState.INVOKING, + TransactionStepState.NOT_STARTED, + ].includes(step.invoke?.state) + }) + + const currentFlowLastCompensatingStepIndex = Object.values( + currentFlow.steps + ) + .reverse() + .findIndex((step) => { + return [ + TransactionStepState.COMPENSATING, + TransactionStepState.NOT_STARTED, + ].includes(step.compensate?.state) + }) + + const latestUpdatedFlowLastCompensatingStepIndex = !latestUpdatedFlow.steps + ? -1 // There is no other execution, so the current execution is the latest + : Object.values( + (latestUpdatedFlow.steps as Record) ?? {} + ) + .reverse() + .findIndex((step) => { + return [ + TransactionStepState.COMPENSATING, + TransactionStepState.NOT_STARTED, + ].includes(step.compensate?.state) + }) + + const isLatestExecutionFinishedIndex = -1 + const invokeShouldBeSkipped = + (latestUpdatedFlowLastInvokingStepIndex === + isLatestExecutionFinishedIndex || + currentFlowLastInvokingStepIndex < + latestUpdatedFlowLastInvokingStepIndex) && + currentFlowLastInvokingStepIndex !== isLatestExecutionFinishedIndex + + const compensateShouldBeSkipped = + currentFlowLastCompensatingStepIndex < + latestUpdatedFlowLastCompensatingStepIndex && + currentFlowLastCompensatingStepIndex !== isLatestExecutionFinishedIndex && + latestUpdatedFlowLastCompensatingStepIndex !== + isLatestExecutionFinishedIndex + + if ( + (data.flow.state !== TransactionState.COMPENSATING && + invokeShouldBeSkipped) || + (data.flow.state === TransactionState.COMPENSATING && + compensateShouldBeSkipped) || + (latestUpdatedFlow.state === TransactionState.COMPENSATING && + ![TransactionState.REVERTED, TransactionState.FAILED].includes( + currentFlow.state + ) && + currentFlow.state !== latestUpdatedFlow.state) || + (latestUpdatedFlow.state === TransactionState.REVERTED && + currentFlow.state !== TransactionState.REVERTED) || + (latestUpdatedFlow.state === TransactionState.FAILED && + currentFlow.state !== TransactionState.FAILED) + ) { + throw new SkipExecutionError("Already finished by another execution") + } + } + async scheduleRetry( transaction: DistributedTransactionType, step: TransactionStep, diff --git a/packages/modules/workflow-engine-redis/integration-tests/__fixtures__/workflow_step_timeout.ts b/packages/modules/workflow-engine-redis/integration-tests/__fixtures__/workflow_step_timeout.ts index 1d4807366de85..9710701487a9b 100644 --- a/packages/modules/workflow-engine-redis/integration-tests/__fixtures__/workflow_step_timeout.ts +++ b/packages/modules/workflow-engine-redis/integration-tests/__fixtures__/workflow_step_timeout.ts @@ -8,7 +8,7 @@ import { setTimeout } from "timers/promises" const step_1 = createStep( "step_1", jest.fn(async (input) => { - await setTimeout(200) + await setTimeout(1000) return new StepResponse(input, { compensate: 123 }) }) @@ -42,6 +42,8 @@ createWorkflow( createWorkflow( { name: "workflow_step_timeout_async", + idempotent: true, + retentionTime: 5, }, function (input) { const resp = step_1_async(input) diff --git a/packages/modules/workflow-engine-redis/integration-tests/__fixtures__/workflow_transaction_timeout.ts b/packages/modules/workflow-engine-redis/integration-tests/__fixtures__/workflow_transaction_timeout.ts index b713b1d437dea..14ad92ae50a60 100644 --- a/packages/modules/workflow-engine-redis/integration-tests/__fixtures__/workflow_transaction_timeout.ts +++ b/packages/modules/workflow-engine-redis/integration-tests/__fixtures__/workflow_transaction_timeout.ts @@ -33,6 +33,8 @@ createWorkflow( { name: "workflow_transaction_timeout_async", timeout: 0.1, // 0.1 second + idempotent: true, + retentionTime: 5, }, function (input) { const resp = step_1(input).config({ diff --git a/packages/modules/workflow-engine-redis/integration-tests/__tests__/index.spec.ts b/packages/modules/workflow-engine-redis/integration-tests/__tests__/index.spec.ts index d536aeccadb44..a3410fb009b8a 100644 --- a/packages/modules/workflow-engine-redis/integration-tests/__tests__/index.spec.ts +++ b/packages/modules/workflow-engine-redis/integration-tests/__tests__/index.spec.ts @@ -335,7 +335,7 @@ moduleIntegrationTestRunner({ throwOnError: false, }) - await setTimeout(200) + await setTimeout(2000) const { transaction, result, errors } = (await workflowOrcModule.run( "workflow_step_timeout_async", @@ -569,7 +569,6 @@ moduleIntegrationTestRunner({ ) }) - // TODO: investigate why it fails intermittently it.skip("the scheduled workflow should have access to the shared container", async () => { const wait = times(1) sharedContainer_.register("test-value", asValue("test")) diff --git a/packages/modules/workflow-engine-redis/integration-tests/__tests__/race.spec.ts b/packages/modules/workflow-engine-redis/integration-tests/__tests__/race.spec.ts new file mode 100644 index 0000000000000..1304b6f8fc794 --- /dev/null +++ b/packages/modules/workflow-engine-redis/integration-tests/__tests__/race.spec.ts @@ -0,0 +1,203 @@ +import { IWorkflowEngineService } from "@medusajs/framework/types" +import { Modules } from "@medusajs/framework/utils" +import { + createStep, + createWorkflow, + StepResponse, + transform, + WorkflowResponse, +} from "@medusajs/framework/workflows-sdk" +import { moduleIntegrationTestRunner } from "@medusajs/test-utils" +import { setTimeout as setTimeoutSync } from "timers" +import { setTimeout } from "timers/promises" +import "../__fixtures__" + +jest.setTimeout(999900000) + +const failTrap = (done) => { + setTimeoutSync(() => { + // REF:https://stackoverflow.com/questions/78028715/jest-async-test-with-event-emitter-isnt-ending + console.warn( + "Jest is breaking the event emit with its debouncer. This allows to continue the test by managing the timeout of the test manually." + ) + done() + }, 5000) +} + +// REF:https://stackoverflow.com/questions/78028715/jest-async-test-with-event-emitter-isnt-ending + +moduleIntegrationTestRunner({ + moduleName: Modules.WORKFLOW_ENGINE, + resolve: __dirname + "/../..", + moduleOptions: { + redis: { + url: "localhost:6379", + }, + }, + testSuite: ({ service: workflowOrcModule, medusaApp }) => { + describe("Testing race condition of the workflow during retry", () => { + it("should prevent race continuation of the workflow during retryIntervalAwaiting in background execution", (done) => { + const transactionId = "transaction_id" + + const step0InvokeMock = jest.fn() + const step1InvokeMock = jest.fn() + const step2InvokeMock = jest.fn() + const transformMock = jest.fn() + + const step0 = createStep("step0", async (_) => { + step0InvokeMock() + return new StepResponse("result from step 0") + }) + + const step1 = createStep("step1", async (_) => { + step1InvokeMock() + await setTimeout(2000) + return new StepResponse({ isSuccess: true }) + }) + + const step2 = createStep("step2", async (input: any) => { + step2InvokeMock() + return new StepResponse({ result: input }) + }) + + const subWorkflow = createWorkflow("sub-workflow-1", function () { + const status = step1() + return new WorkflowResponse(status) + }) + + createWorkflow("workflow-1", function () { + const build = step0() + + const status = subWorkflow.runAsStep({} as any).config({ + async: true, + compensateAsync: true, + backgroundExecution: true, + retryIntervalAwaiting: 1, + }) + + const transformedResult = transform({ status }, (data) => { + transformMock() + return { + status: data.status, + } + }) + + step2(transformedResult) + return new WorkflowResponse(build) + }) + + void workflowOrcModule.subscribe({ + workflowId: "workflow-1", + transactionId, + subscriber: (event) => { + if (event.eventType === "onFinish") { + expect(step0InvokeMock).toHaveBeenCalledTimes(1) + expect(step1InvokeMock.mock.calls.length).toBeGreaterThan(1) + expect(step2InvokeMock).toHaveBeenCalledTimes(1) + expect(transformMock).toHaveBeenCalledTimes(1) + setTimeoutSync(done, 500) + } + }, + }) + + workflowOrcModule + .run("workflow-1", { transactionId }) + .then(({ result }) => { + expect(result).toBe("result from step 0") + }) + + failTrap(done) + }) + + it("should prevent race continuation of the workflow compensation during retryIntervalAwaiting in background execution", (done) => { + const transactionId = "transaction_id" + const workflowId = "RACE_workflow-1" + + const step0InvokeMock = jest.fn() + const step0CompensateMock = jest.fn() + const step1InvokeMock = jest.fn() + const step1CompensateMock = jest.fn() + const step2InvokeMock = jest.fn() + const transformMock = jest.fn() + + const step0 = createStep( + "RACE_step0", + async (_) => { + step0InvokeMock() + return new StepResponse("result from step 0") + }, + () => { + step0CompensateMock() + } + ) + + const step1 = createStep( + "RACE_step1", + async (_) => { + step1InvokeMock() + await setTimeout(500) + throw new Error("error from step 1") + }, + () => { + step1CompensateMock() + } + ) + + const step2 = createStep("RACE_step2", async (input: any) => { + step2InvokeMock() + return new StepResponse({ result: input }) + }) + + const subWorkflow = createWorkflow("RACE_sub-workflow-1", function () { + const status = step1() + return new WorkflowResponse(status) + }) + + createWorkflow(workflowId, function () { + const build = step0() + + const status = subWorkflow.runAsStep({} as any).config({ + async: true, + compensateAsync: true, + backgroundExecution: true, + retryIntervalAwaiting: 0.1, + }) + + const transformedResult = transform({ status }, (data) => { + transformMock() + return { + status: data.status, + } + }) + + step2(transformedResult) + return new WorkflowResponse(build) + }) + + void workflowOrcModule.subscribe({ + workflowId: workflowId, + transactionId, + subscriber: (event) => { + if (event.eventType === "onFinish") { + expect(step0InvokeMock).toHaveBeenCalledTimes(1) + expect(step0CompensateMock).toHaveBeenCalledTimes(1) + expect(step1InvokeMock.mock.calls.length).toBeGreaterThan(2) + expect(step1CompensateMock).toHaveBeenCalledTimes(1) + expect(step2InvokeMock).toHaveBeenCalledTimes(0) + expect(transformMock).toHaveBeenCalledTimes(0) + done() + } + }, + }) + + workflowOrcModule + .run(workflowId, { transactionId }) + .then(({ result }) => { + expect(result).toBe("result from step 0") + }) + + failTrap(done) + }) + }) + }, +}) diff --git a/packages/modules/workflow-engine-redis/src/services/workflow-orchestrator.ts b/packages/modules/workflow-engine-redis/src/services/workflow-orchestrator.ts index 43788f11884bd..b5293cc4828af 100644 --- a/packages/modules/workflow-engine-redis/src/services/workflow-orchestrator.ts +++ b/packages/modules/workflow-engine-redis/src/services/workflow-orchestrator.ts @@ -12,13 +12,9 @@ import { Logger, MedusaContainer, } from "@medusajs/framework/types" +import { isString, TransactionState } from "@medusajs/framework/utils" import { - InjectSharedContext, - isString, - MedusaContext, - TransactionState, -} from "@medusajs/framework/utils" -import { + FlowCancelOptions, FlowRunOptions, MedusaWorkflow, resolveValue, @@ -37,6 +33,11 @@ export type WorkflowOrchestratorRunOptions = Omit< container?: ContainerLike } +export type WorkflowOrchestratorCancelOptions = Omit< + FlowCancelOptions, + "transaction" +> + type RegisterStepSuccessOptions = Omit< WorkflowOrchestratorRunOptions, "transactionId" | "input" @@ -183,11 +184,9 @@ export class WorkflowOrchestratorService { } } - @InjectSharedContext() async run( workflowIdOrWorkflow: string | ReturnWorkflow, - options?: WorkflowOrchestratorRunOptions, - @MedusaContext() sharedContext: Context = {} + options?: WorkflowOrchestratorRunOptions ) { const { input, @@ -239,6 +238,7 @@ export class WorkflowOrchestratorService { const hasFinished = ret.transaction.hasFinished() const metadata = ret.transaction.getFlow().metadata const { parentStepIdempotencyKey } = metadata ?? {} + const hasFailed = [ TransactionState.REVERTED, TransactionState.FAILED, @@ -271,12 +271,115 @@ export class WorkflowOrchestratorService { return { acknowledgement, ...ret } } - @InjectSharedContext() + async cancel( + workflowIdOrWorkflow: string | ReturnWorkflow, + options?: WorkflowOrchestratorCancelOptions + ) { + const { + transactionId, + logOnError, + events: eventHandlers, + container, + } = options ?? {} + + let { throwOnError, context } = options ?? {} + + throwOnError ??= true + context ??= {} + + const workflowId = isString(workflowIdOrWorkflow) + ? workflowIdOrWorkflow + : workflowIdOrWorkflow.getName() + + if (!workflowId) { + throw new Error("Workflow ID is required") + } + + if (!transactionId) { + throw new Error("Transaction ID is required") + } + + const events: FlowRunOptions["events"] = this.buildWorkflowEvents({ + customEventHandlers: eventHandlers, + workflowId, + transactionId: transactionId, + }) + + const exportedWorkflow = MedusaWorkflow.getWorkflow(workflowId) + if (!exportedWorkflow) { + throw new Error(`Workflow with id "${workflowId}" not found.`) + } + + const originalOnFinishHandler = events.onFinish! + delete events.onFinish + + const transaction = await this.getRunningTransaction( + workflowId, + transactionId, + options + ) + if (!transaction) { + if (!throwOnError) { + return { + acknowledgement: { + transactionId, + workflowId, + exists: false, + }, + } + } + throw new Error("Transaction not found") + } + + const ret = await exportedWorkflow.cancel({ + transaction, + throwOnError: false, + logOnError, + context, + events, + container: container ?? this.container_, + }) + + const hasFinished = ret.transaction.hasFinished() + const metadata = ret.transaction.getFlow().metadata + const { parentStepIdempotencyKey } = metadata ?? {} + + const hasFailed = [TransactionState.FAILED].includes( + ret.transaction.getFlow().state + ) + + const acknowledgement = { + transactionId: context.transactionId, + workflowId: workflowId, + parentStepIdempotencyKey, + hasFinished, + hasFailed, + exists: true, + } + + if (hasFinished) { + const { result, errors } = ret + + await originalOnFinishHandler({ + transaction: ret.transaction, + result, + errors, + }) + + await this.triggerParentStep(ret.transaction, result) + } + + if (throwOnError && ret.thrownError) { + throw ret.thrownError + } + + return { acknowledgement, ...ret } + } + async getRunningTransaction( workflowId: string, transactionId: string, - options?: WorkflowOrchestratorRunOptions, - @MedusaContext() sharedContext: Context = {} + options?: { context?: Context } ): Promise { let { context } = options ?? {} @@ -289,7 +392,6 @@ export class WorkflowOrchestratorService { } context ??= {} - context.transactionId ??= transactionId const exportedWorkflow: any = MedusaWorkflow.getWorkflow(workflowId) if (!exportedWorkflow) { @@ -304,19 +406,15 @@ export class WorkflowOrchestratorService { return transaction } - @InjectSharedContext() - async setStepSuccess( - { - idempotencyKey, - stepResponse, - options, - }: { - idempotencyKey: string | IdempotencyKeyParts - stepResponse: unknown - options?: RegisterStepSuccessOptions - }, - @MedusaContext() sharedContext: Context = {} - ) { + async setStepSuccess({ + idempotencyKey, + stepResponse, + options, + }: { + idempotencyKey: string | IdempotencyKeyParts + stepResponse: unknown + options?: RegisterStepSuccessOptions + }) { const { context, logOnError, @@ -375,19 +473,15 @@ export class WorkflowOrchestratorService { return ret } - @InjectSharedContext() - async setStepFailure( - { - idempotencyKey, - stepResponse, - options, - }: { - idempotencyKey: string | IdempotencyKeyParts - stepResponse: unknown - options?: RegisterStepSuccessOptions - }, - @MedusaContext() sharedContext: Context = {} - ) { + async setStepFailure({ + idempotencyKey, + stepResponse, + options, + }: { + idempotencyKey: string | IdempotencyKeyParts + stepResponse: unknown + options?: RegisterStepSuccessOptions + }) { const { context, logOnError, @@ -446,11 +540,12 @@ export class WorkflowOrchestratorService { return ret } - @InjectSharedContext() - subscribe( - { workflowId, transactionId, subscriber, subscriberId }: SubscribeOptions, - @MedusaContext() sharedContext: Context = {} - ) { + subscribe({ + workflowId, + transactionId, + subscriber, + subscriberId, + }: SubscribeOptions) { subscriber._id = subscriberId const subscribers = this.subscribers.get(workflowId) ?? new Map() @@ -487,11 +582,11 @@ export class WorkflowOrchestratorService { this.subscribers.set(workflowId, subscribers) } - @InjectSharedContext() - unsubscribe( - { workflowId, transactionId, subscriberOrId }: UnsubscribeOptions, - @MedusaContext() sharedContext: Context = {} - ) { + unsubscribe({ + workflowId, + transactionId, + subscriberOrId, + }: UnsubscribeOptions) { const subscribers = this.subscribers.get(workflowId) ?? new Map() const filterSubscribers = (handlers: SubscriberHandler[]) => { diff --git a/packages/modules/workflow-engine-redis/src/services/workflows-module.ts b/packages/modules/workflow-engine-redis/src/services/workflows-module.ts index 4421f659e20f1..2795c4f7e506a 100644 --- a/packages/modules/workflow-engine-redis/src/services/workflows-module.ts +++ b/packages/modules/workflow-engine-redis/src/services/workflows-module.ts @@ -75,7 +75,9 @@ export class WorkflowsModuleService< await this.clearExpiredExecutions() this.clearTimeout_ = setInterval(async () => { - await this.clearExpiredExecutions() + try { + await this.clearExpiredExecutions() + } catch {} }, 1000 * 60 * 60) }, } @@ -90,11 +92,13 @@ export class WorkflowsModuleService< > = {}, @MedusaContext() context: Context = {} ) { + options ??= {} + options.context ??= context const ret = await this.workflowOrchestratorService_.run< TWorkflow extends ReturnWorkflow ? UnwrapWorkflowInputDataType : unknown - >(workflowIdOrWorkflow, options, context) + >(workflowIdOrWorkflow, options) return ret as any } @@ -108,7 +112,7 @@ export class WorkflowsModuleService< return await this.workflowOrchestratorService_.getRunningTransaction( workflowId, transactionId, - context + { context } ) } @@ -125,14 +129,14 @@ export class WorkflowsModuleService< }, @MedusaContext() context: Context = {} ) { - return await this.workflowOrchestratorService_.setStepSuccess( - { - idempotencyKey, - stepResponse, - options, - } as any, - context - ) + options ??= {} + options.context ??= context + + return await this.workflowOrchestratorService_.setStepSuccess({ + idempotencyKey, + stepResponse, + options, + } as any) } @InjectSharedContext() @@ -148,14 +152,14 @@ export class WorkflowsModuleService< }, @MedusaContext() context: Context = {} ) { - return await this.workflowOrchestratorService_.setStepFailure( - { - idempotencyKey, - stepResponse, - options, - } as any, - context - ) + options ??= {} + options.context ??= context + + return await this.workflowOrchestratorService_.setStepFailure({ + idempotencyKey, + stepResponse, + options, + } as any) } @InjectSharedContext() @@ -168,7 +172,7 @@ export class WorkflowsModuleService< }, @MedusaContext() context: Context = {} ) { - return this.workflowOrchestratorService_.subscribe(args as any, context) + return this.workflowOrchestratorService_.subscribe(args as any) } @InjectSharedContext() @@ -180,7 +184,7 @@ export class WorkflowsModuleService< }, @MedusaContext() context: Context = {} ) { - return this.workflowOrchestratorService_.unsubscribe(args as any, context) + return this.workflowOrchestratorService_.unsubscribe(args as any) } private async clearExpiredExecutions() { diff --git a/packages/modules/workflow-engine-redis/src/utils/workflow-orchestrator-storage.ts b/packages/modules/workflow-engine-redis/src/utils/workflow-orchestrator-storage.ts index 8365160509550..80b1c5350b056 100644 --- a/packages/modules/workflow-engine-redis/src/utils/workflow-orchestrator-storage.ts +++ b/packages/modules/workflow-engine-redis/src/utils/workflow-orchestrator-storage.ts @@ -4,15 +4,19 @@ import { IDistributedSchedulerStorage, IDistributedTransactionStorage, SchedulerOptions, + SkipExecutionError, TransactionCheckpoint, + TransactionFlow, TransactionOptions, TransactionStep, } from "@medusajs/framework/orchestration" import { Logger, ModulesSdkTypes } from "@medusajs/framework/types" import { + isPresent, MedusaError, promiseAll, TransactionState, + TransactionStepState, } from "@medusajs/framework/utils" import { WorkflowOrchestratorService } from "@services" import { Queue, Worker } from "bullmq" @@ -28,7 +32,6 @@ enum JobType { export class RedisDistributedTransactionStorage implements IDistributedTransactionStorage, IDistributedSchedulerStorage { - private static TTL_AFTER_COMPLETED = 60 * 2 // 2 minutes private workflowExecutionService_: ModulesSdkTypes.IMedusaInternalService private logger_: Logger private workflowOrchestratorService_: WorkflowOrchestratorService @@ -263,6 +266,12 @@ export class RedisDistributedTransactionStorage const { retentionTime, idempotent } = options ?? {} + await this.#preventRaceConditionExecutionIfNecessary({ + data, + key, + options, + }) + if (hasFinished) { Object.assign(data, { retention_time: retentionTime, @@ -286,12 +295,7 @@ export class RedisDistributedTransactionStorage } if (hasFinished) { - await this.redisClient.set( - key, - stringifiedData, - "EX", - RedisDistributedTransactionStorage.TTL_AFTER_COMPLETED - ) + await this.redisClient.unlink(key) } } @@ -457,4 +461,116 @@ export class RedisDistributedTransactionStorage repeatableJobs.map((job) => this.jobQueue?.removeRepeatableByKey(job.key)) ) } + + async #preventRaceConditionExecutionIfNecessary({ + data, + key, + options, + }: { + data: TransactionCheckpoint + key: string + options?: TransactionOptions + }) { + let isInitialCheckpoint = false + + if (data.flow.state === TransactionState.NOT_STARTED) { + isInitialCheckpoint = true + } + + /** + * In case many execution can succeed simultaneously, we need to ensure that the latest + * execution does continue if a previous execution is considered finished + */ + const currentFlow = data.flow + const { flow: latestUpdatedFlow } = + (await this.get(key, options)) ?? + ({ flow: {} } as { flow: TransactionFlow }) + + if (!isInitialCheckpoint && !isPresent(latestUpdatedFlow)) { + /** + * the initial checkpoint expect no other checkpoint to have been stored. + * In case it is not the initial one and another checkpoint is trying to + * find if a concurrent execution has finished, we skip the execution. + * The already finished execution would have deleted the checkpoint already. + */ + throw new SkipExecutionError("Already finished by another execution") + } + + const currentFlowLastInvokingStepIndex = Object.values( + currentFlow.steps + ).findIndex((step) => { + return [ + TransactionStepState.INVOKING, + TransactionStepState.NOT_STARTED, + ].includes(step.invoke?.state) + }) + + const latestUpdatedFlowLastInvokingStepIndex = !latestUpdatedFlow.steps + ? 1 // There is no other execution, so the current execution is the latest + : Object.values( + (latestUpdatedFlow.steps as Record) ?? {} + ).findIndex((step) => { + return [ + TransactionStepState.INVOKING, + TransactionStepState.NOT_STARTED, + ].includes(step.invoke?.state) + }) + + const currentFlowLastCompensatingStepIndex = Object.values( + currentFlow.steps + ) + .reverse() + .findIndex((step) => { + return [ + TransactionStepState.COMPENSATING, + TransactionStepState.NOT_STARTED, + ].includes(step.compensate?.state) + }) + + const latestUpdatedFlowLastCompensatingStepIndex = !latestUpdatedFlow.steps + ? -1 + : Object.values( + (latestUpdatedFlow.steps as Record) ?? {} + ) + .reverse() + .findIndex((step) => { + return [ + TransactionStepState.COMPENSATING, + TransactionStepState.NOT_STARTED, + ].includes(step.compensate?.state) + }) + + const isLatestExecutionFinishedIndex = -1 + const invokeShouldBeSkipped = + (latestUpdatedFlowLastInvokingStepIndex === + isLatestExecutionFinishedIndex || + currentFlowLastInvokingStepIndex < + latestUpdatedFlowLastInvokingStepIndex) && + currentFlowLastInvokingStepIndex !== isLatestExecutionFinishedIndex + + const compensateShouldBeSkipped = + currentFlowLastCompensatingStepIndex < + latestUpdatedFlowLastCompensatingStepIndex && + currentFlowLastCompensatingStepIndex !== isLatestExecutionFinishedIndex && + latestUpdatedFlowLastCompensatingStepIndex !== + isLatestExecutionFinishedIndex + + if ( + (data.flow.state !== TransactionState.COMPENSATING && + invokeShouldBeSkipped) || + (data.flow.state === TransactionState.COMPENSATING && + compensateShouldBeSkipped) || + (latestUpdatedFlow.state === TransactionState.COMPENSATING && + ![TransactionState.REVERTED, TransactionState.FAILED].includes( + currentFlow.state + ) && + currentFlow.state !== latestUpdatedFlow.state) || + (latestUpdatedFlow.state === TransactionState.REVERTED && + currentFlow.state !== TransactionState.REVERTED) || + (latestUpdatedFlow.state === TransactionState.FAILED && + currentFlow.state !== TransactionState.FAILED) + ) { + throw new SkipExecutionError("Already finished by another execution") + } + } }