Skip to content

Commit

Permalink
Refactor RetryableWorkflowStep to avoid recursion and subclass specifics (#298)
Browse files Browse the repository at this point in the history

* Refactor RetryableWorkflowStep to avoid recursion and subclass specifics

Signed-off-by: Daniel Widdis <[email protected]>

* Fix tests

Signed-off-by: Daniel Widdis <[email protected]>

* Support function_name field in Register Local Model Step

Signed-off-by: Daniel Widdis <[email protected]>

* Better handling of provisioning exception message on cancellation

Signed-off-by: Daniel Widdis <[email protected]>

* Throw an exception with message on cancellation.

Signed-off-by: Daniel Widdis <[email protected]>

* Add test coverage

Signed-off-by: Daniel Widdis <[email protected]>

---------

Signed-off-by: Daniel Widdis <[email protected]>
(cherry picked from commit 39c3f48)
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] committed Jan 2, 2024
1 parent 1b2a219 commit 2b237e9
Show file tree
Hide file tree
Showing 10 changed files with 216 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ public Collection<Object> createComponents(
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService, encryptorUtils);
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(
settings,
threadPool,
clusterService,
client,
mlClient,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -213,12 +214,19 @@ private void executeWorkflow(List<ProcessNode> workflowSequence, String workflow
);
} catch (Exception ex) {
logger.error("Provisioning failed for workflow: {}", workflowId, ex);
String errorMessage;
if (ex instanceof CancellationException) {
errorMessage = "A step in the workflow was cancelled.";

Check warning on line 219 in src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java#L219

Added line #L219 was not covered by tests
} else if (ex.getCause() != null) {
errorMessage = ex.getCause().getMessage();

Check warning on line 221 in src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java#L221

Added line #L221 was not covered by tests
} else {
errorMessage = ex.getMessage();

Check warning on line 223 in src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java#L223

Added line #L223 was not covered by tests
}
flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc(
workflowId,
Map.ofEntries(
Map.entry(STATE_FIELD, State.FAILED),
// TODO: potentially improve the error message here
Map.entry(ERROR_FIELD, ex.getMessage()),
Map.entry(ERROR_FIELD, errorMessage),

Check warning on line 229 in src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java#L229

Added line #L229 was not covered by tests
Map.entry(PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.FAILED),
Map.entry(PROVISION_END_TIME_FIELD, Instant.now().toEpochMilli())
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,16 @@
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTask;
import org.opensearch.threadpool.ThreadPool;

import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;
import java.util.concurrent.atomic.AtomicInteger;

import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL;
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY;
import static org.opensearch.flowframework.common.WorkflowResources.DEPLOY_MODEL;
import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep;

/**
Expand All @@ -39,20 +40,24 @@ public abstract class AbstractRetryableWorkflowStep implements WorkflowStep {
protected volatile Integer maxRetry;
private final MachineLearningNodeClient mlClient;
private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;
private ThreadPool threadPool;

/**
* Instantiates a new Retryable workflow step
* @param settings Environment settings
* @param threadPool The OpenSearch thread pool
* @param clusterService the cluster service
* @param mlClient machine learning client
* @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices
*/
protected AbstractRetryableWorkflowStep(
Settings settings,
ThreadPool threadPool,
ClusterService clusterService,
MachineLearningNodeClient mlClient,
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler
) {
this.threadPool = threadPool;
this.maxRetry = MAX_GET_TASK_REQUEST_RETRY.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_GET_TASK_REQUEST_RETRY, it -> maxRetry = it);
this.mlClient = mlClient;
Expand All @@ -65,82 +70,93 @@ protected AbstractRetryableWorkflowStep(
* @param nodeId the workflow node id
* @param future the workflow step future
* @param taskId the ml task id
* @param retries the current number of request retries
* @param workflowStep the workflow step which requires a retry get ml task functionality
*/
protected void retryableGetMlTask(
String workflowId,
String nodeId,
CompletableFuture<WorkflowData> future,
String taskId,
int retries,
String workflowStep
) {
mlClient.getTask(taskId, ActionListener.wrap(response -> {
MLTaskState currentState = response.getState();
if (currentState != MLTaskState.COMPLETED) {
if (Stream.of(MLTaskState.FAILED, MLTaskState.COMPLETED_WITH_ERROR).anyMatch(x -> x == currentState)) {
// Model registration failed or completed with errors
String errorMessage = workflowStep + " failed with error : " + response.getError();
AtomicInteger retries = new AtomicInteger();
CompletableFuture.runAsync(() -> {
while (retries.getAndIncrement() < this.maxRetry && !future.isDone()) {
mlClient.getTask(taskId, ActionListener.wrap(response -> {
switch (response.getState()) {
case COMPLETED:
try {
String resourceName = getResourceByWorkflowStep(getName());
String id = getResourceId(response);
logger.info("{} successful for {} and {} {}", workflowStep, workflowId, resourceName, id);
flowFrameworkIndicesHandler.updateResourceInStateIndex(
workflowId,
nodeId,
getName(),
id,
ActionListener.wrap(updateResponse -> {
logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex());
future.complete(
new WorkflowData(
Map.ofEntries(
Map.entry(resourceName, id),
Map.entry(REGISTER_MODEL_STATUS, response.getState().name())
),
workflowId,
nodeId
)
);
}, exception -> {
logger.error("Failed to update new created resource", exception);
future.completeExceptionally(
new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))

Check warning on line 112 in src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java#L110-L112

Added lines #L110 - L112 were not covered by tests
);
})

Check warning on line 114 in src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java#L114

Added line #L114 was not covered by tests
);
} catch (Exception e) {
logger.error("Failed to parse and update new created resource", e);
future.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));

Check warning on line 118 in src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java#L116-L118

Added lines #L116 - L118 were not covered by tests
}
break;

Check warning on line 120 in src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java#L120

Added line #L120 was not covered by tests
case FAILED:
case COMPLETED_WITH_ERROR:
String errorMessage = workflowStep + " failed with error : " + response.getError();
logger.error(errorMessage);
future.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST));
break;
case CANCELLED:
errorMessage = workflowStep + " task was cancelled.";
logger.error(errorMessage);
future.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.REQUEST_TIMEOUT));
break;

Check warning on line 131 in src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java#L128-L131

Added lines #L128 - L131 were not covered by tests
default:
// Task started or running, do nothing
}
}, exception -> {
String errorMessage = workflowStep + " failed with error : " + exception.getMessage();

Check warning on line 136 in src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java#L136

Added line #L136 was not covered by tests
logger.error(errorMessage);
future.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST));
} else {
// Task still in progress, attempt retry
throw new IllegalStateException(workflowStep + " is not yet completed");
}
} else {
try {
logger.info(workflowStep + " successful for {} and modelId {}", workflowId, response.getModelId());
String resourceName = getResourceByWorkflowStep(getName());
String id;
if (getName().equals(DEPLOY_MODEL.getWorkflowStep())) {
id = response.getModelId();
} else {
id = response.getTaskId();
}
flowFrameworkIndicesHandler.updateResourceInStateIndex(
workflowId,
nodeId,
getName(),
id,
ActionListener.wrap(updateResponse -> {
logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex());
future.complete(
new WorkflowData(
Map.ofEntries(
Map.entry(resourceName, response.getModelId()),
Map.entry(REGISTER_MODEL_STATUS, response.getState().name())
),
workflowId,
nodeId
)
);
}, exception -> {
logger.error("Failed to update new created resource", exception);
future.completeExceptionally(
new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))
);
})
);
} catch (Exception e) {
logger.error("Failed to parse and update new created resource", e);
future.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
}
}, exception -> {
if (retries < maxRetry) {
// Sleep thread prior to retrying request
}));

Check warning on line 139 in src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java#L139

Added line #L139 was not covered by tests
// Wait long enough for future to possibly complete
try {
Thread.sleep(5000);
} catch (Exception e) {
} catch (InterruptedException e) {
FutureUtils.cancel(future);
Thread.currentThread().interrupt();
}
retryableGetMlTask(workflowId, nodeId, future, taskId, retries + 1, workflowStep);
} else {
logger.error("Failed to retrieve" + workflowStep + ",maximum retries exceeded");
future.completeExceptionally(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)));
}
}));
if (!future.isDone()) {
String errorMessage = workflowStep + " did not complete after " + maxRetry + " retries";
logger.error(errorMessage);
future.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.REQUEST_TIMEOUT));

Check warning on line 151 in src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java#L149-L151

Added lines #L149 - L151 were not covered by tests
}
}, threadPool.executor(PROVISION_THREAD_POOL));
}

/**
* Returns the resourceId associated with the task.
* <p>
* Called by {@code retryableGetMlTask} when the polled task reports the COMPLETED state. The returned id is
* passed to the state index update for this workflow node and is also placed in the resulting
* {@code WorkflowData} under the resource name of this step.
* @param response The Task response
* @return the resource ID, such as a model id
*/
protected abstract String getResourceId(MLTask response);
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.threadpool.ThreadPool;

import java.util.Collections;
import java.util.Map;
Expand All @@ -42,17 +44,19 @@ public class DeployModelStep extends AbstractRetryableWorkflowStep {
/**
* Instantiate this class
* @param settings The OpenSearch settings
* @param threadPool The OpenSearch thread pool
* @param clusterService The cluster service
* @param mlClient client to instantiate MLClient
* @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices
*/
public DeployModelStep(
Settings settings,
ThreadPool threadPool,
ClusterService clusterService,
MachineLearningNodeClient mlClient,
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler
) {
super(settings, clusterService, mlClient, flowFrameworkIndicesHandler);
super(settings, threadPool, clusterService, mlClient, flowFrameworkIndicesHandler);
this.mlClient = mlClient;
this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
}
Expand All @@ -74,7 +78,7 @@ public void onResponse(MLDeployModelResponse mlDeployModelResponse) {
String taskId = mlDeployModelResponse.getTaskId();

// Attempt to retrieve the model ID
retryableGetMlTask(currentNodeInputs.getWorkflowId(), currentNodeId, deployModelFuture, taskId, 0, "Deploy model");
retryableGetMlTask(currentNodeInputs.getWorkflowId(), currentNodeId, deployModelFuture, taskId, "Deploy model");
}

@Override
Expand Down Expand Up @@ -105,6 +109,11 @@ public void onFailure(Exception e) {
return deployModelFuture;
}

/**
* Uses the model id reported by the ML task as the resource id for this step.
* @param response the ML task response
* @return the value of {@code response.getModelId()}
*/
@Override
protected String getResourceId(MLTask response) {
return response.getModelId();
}

@Override
public String getName() {
return NAME;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
Expand All @@ -26,6 +28,7 @@
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.threadpool.ThreadPool;

import java.util.Map;
import java.util.Set;
Expand All @@ -35,6 +38,7 @@
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.EMBEDDING_DIMENSION;
import static org.opensearch.flowframework.common.CommonValue.FRAMEWORK_TYPE;
import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME;
import static org.opensearch.flowframework.common.CommonValue.MODEL_CONTENT_HASH_VALUE;
import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT;
import static org.opensearch.flowframework.common.CommonValue.MODEL_TYPE;
Expand All @@ -60,17 +64,19 @@ public class RegisterLocalModelStep extends AbstractRetryableWorkflowStep {
/**
* Instantiate this class
* @param settings The OpenSearch settings
* @param threadPool The OpenSearch thread pool
* @param clusterService The cluster service
* @param mlClient client to instantiate MLClient
* @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices
*/
public RegisterLocalModelStep(
Settings settings,
ThreadPool threadPool,
ClusterService clusterService,
MachineLearningNodeClient mlClient,
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler
) {
super(settings, clusterService, mlClient, flowFrameworkIndicesHandler);
super(settings, threadPool, clusterService, mlClient, flowFrameworkIndicesHandler);
this.mlClient = mlClient;
this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
}
Expand Down Expand Up @@ -98,7 +104,6 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) {
currentNodeId,
registerLocalModelFuture,
taskId,
0,
"Local model registration"
);
}
Expand All @@ -120,7 +125,7 @@ public void onFailure(Exception e) {
MODEL_CONTENT_HASH_VALUE,
URL
);
Set<String> optionalKeys = Set.of(DESCRIPTION_FIELD, MODEL_GROUP_ID, ALL_CONFIG);
Set<String> optionalKeys = Set.of(DESCRIPTION_FIELD, MODEL_GROUP_ID, ALL_CONFIG, FUNCTION_NAME);

try {
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
Expand All @@ -142,6 +147,7 @@ public void onFailure(Exception e) {
FrameworkType frameworkType = FrameworkType.from((String) inputs.get(FRAMEWORK_TYPE));
String allConfig = (String) inputs.get(ALL_CONFIG);
String url = (String) inputs.get(URL);
String functionName = (String) inputs.get(FUNCTION_NAME);

// Create Model configuration
TextEmbeddingModelConfigBuilder modelConfigBuilder = TextEmbeddingModelConfig.builder()
Expand All @@ -158,13 +164,18 @@ public void onFailure(Exception e) {
.modelName(modelName)
.version(modelVersion)
.modelFormat(modelFormat)
.modelGroupId(modelGroupId)
.hashValue(modelContentHashValue)
.modelConfig(modelConfig)
.url(url);
if (description != null) {
mlInputBuilder.description(description);
}
if (modelGroupId != null) {
mlInputBuilder.modelGroupId(modelGroupId);
}
if (functionName != null) {
mlInputBuilder.functionName(FunctionName.from(functionName));
}

MLRegisterModelInput mlInput = mlInputBuilder.build();

Expand All @@ -175,6 +186,11 @@ public void onFailure(Exception e) {
return registerLocalModelFuture;
}

/**
* Uses the model id reported by the ML task as the resource id for this step.
* @param response the ML task response
* @return the value of {@code response.getModelId()}
*/
@Override
protected String getResourceId(MLTask response) {
return response.getModelId();
}

@Override
public String getName() {
return NAME;
Expand Down
Loading

0 comments on commit 2b237e9

Please sign in to comment.