diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java index c251e44c..68b17caf 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java @@ -82,7 +82,7 @@ protected void retryableGetMlTask( ) { CompletableFuture.runAsync(() -> { do { - mlClient.getTask(taskId, ActionListener.wrap(response -> { + mlClient.getTask(taskId, tenantId, ActionListener.wrap(response -> { String resourceName = getResourceByWorkflowStep(getName()); String id = getResourceId(response); switch (response.getState()) { diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java index 4cb3f8fa..9a6c40ac 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java @@ -131,7 +131,7 @@ public void testDeployModel() throws ExecutionException, InterruptedException, I MLTask output = MLTask.builder().taskId(taskId).modelId(modelId).state(MLTaskState.COMPLETED).async(false).build(); actionListener.onResponse(output); return null; - }).when(machineLearningNodeClient).getTask(any(), any()); + }).when(machineLearningNodeClient).getTask(any(), nullable(String.class), any()); doAnswer(invocation -> { ActionListener updateResponseListener = invocation.getArgument(5); @@ -152,7 +152,7 @@ public void testDeployModel() throws ExecutionException, InterruptedException, I future.actionGet(); verify(machineLearningNodeClient, times(1)).deploy(any(String.class), nullable(String.class), any()); - verify(machineLearningNodeClient, times(1)).getTask(any(), any()); + verify(machineLearningNodeClient, times(1)).getTask(any(), nullable(String.class), any()); assertEquals(modelId, future.get().getContent().get(MODEL_ID)); } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java index 26a6bc67..af5a698e 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java @@ -54,6 +54,7 @@ import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -151,7 +152,7 @@ public void testRegisterLocalCustomModelSuccess() throws Exception { MLTask output = MLTask.builder().taskId(taskId).modelId(modelId).state(MLTaskState.COMPLETED).async(false).build(); actionListener.onResponse(output); return null; - }).when(machineLearningNodeClient).getTask(any(), any()); + }).when(machineLearningNodeClient).getTask(any(), nullable(String.class), any()); doAnswer(invocation -> { ActionListener updateResponseListener = invocation.getArgument(5); @@ -172,7 +173,7 @@ public void testRegisterLocalCustomModelSuccess() throws Exception { future.actionGet(); verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); - verify(machineLearningNodeClient, times(1)).getTask(any(), any()); + verify(machineLearningNodeClient, times(1)).getTask(any(), nullable(String.class), any()); assertEquals(modelId, future.get().getContent().get(MODEL_ID)); assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java index 3b82dde9..51527ddc 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java @@ -53,6 +53,7 @@ import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -144,7 +145,7 @@ public void testRegisterLocalPretrainedModelSuccess() throws Exception { MLTask output = MLTask.builder().taskId(taskId).modelId(modelId).state(MLTaskState.COMPLETED).async(false).build(); actionListener.onResponse(output); return null; - }).when(machineLearningNodeClient).getTask(any(), any()); + }).when(machineLearningNodeClient).getTask(any(), nullable(String.class), any()); doAnswer(invocation -> { ActionListener updateResponseListener = invocation.getArgument(5); @@ -165,7 +166,7 @@ public void testRegisterLocalPretrainedModelSuccess() throws Exception { future.actionGet(); verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); - verify(machineLearningNodeClient, times(1)).getTask(any(), any()); + verify(machineLearningNodeClient, times(1)).getTask(any(), nullable(String.class), any()); assertEquals(modelId, future.get().getContent().get(MODEL_ID)); assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java index 97c8a9b0..dac9c1ce 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java @@ -53,6 +53,7 @@ import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -147,7 +148,7 @@ public void testRegisterLocalSparseEncodingModelSuccess() throws Exception { MLTask output = MLTask.builder().taskId(taskId).modelId(modelId).state(MLTaskState.COMPLETED).async(false).build(); actionListener.onResponse(output); return null; - }).when(machineLearningNodeClient).getTask(any(), any()); + }).when(machineLearningNodeClient).getTask(any(), nullable(String.class), any()); doAnswer(invocation -> { ActionListener updateResponseListener = invocation.getArgument(5); @@ -168,7 +169,7 @@ public void testRegisterLocalSparseEncodingModelSuccess() throws Exception { future.actionGet(); verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); - verify(machineLearningNodeClient, times(1)).getTask(any(), any()); + verify(machineLearningNodeClient, times(1)).getTask(any(), nullable(String.class), any()); assertEquals(modelId, future.get().getContent().get(MODEL_ID)); assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS));