Skip to content

Commit

Permalink
Passing tenantId to mlclient getTask for multiTenancy feature
Browse files Browse the repository at this point in the history
Signed-off-by: Siddhartha Bingi <[email protected]>
  • Loading branch information
Siddhartha Bingi committed Feb 4, 2025
1 parent 20a1d40 commit 28b57f3
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ protected void retryableGetMlTask(
) {
CompletableFuture.runAsync(() -> {
do {
mlClient.getTask(taskId, ActionListener.wrap(response -> {
mlClient.getTask(taskId, tenantId, ActionListener.wrap(response -> {
String resourceName = getResourceByWorkflowStep(getName());
String id = getResourceId(response);
switch (response.getState()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ public void testDeployModel() throws ExecutionException, InterruptedException, I
MLTask output = MLTask.builder().taskId(taskId).modelId(modelId).state(MLTaskState.COMPLETED).async(false).build();
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).getTask(any(), any());
}).when(machineLearningNodeClient).getTask(any(), nullable(String.class), any());

doAnswer(invocation -> {
ActionListener<WorkflowData> updateResponseListener = invocation.getArgument(5);
Expand All @@ -152,7 +152,7 @@ public void testDeployModel() throws ExecutionException, InterruptedException, I
future.actionGet();

verify(machineLearningNodeClient, times(1)).deploy(any(String.class), nullable(String.class), any());
verify(machineLearningNodeClient, times(1)).getTask(any(), any());
verify(machineLearningNodeClient, times(1)).getTask(any(), nullable(String.class), any());

assertEquals(modelId, future.get().getContent().get(MODEL_ID));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.nullable;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
Expand Down Expand Up @@ -151,7 +152,7 @@ public void testRegisterLocalCustomModelSuccess() throws Exception {
MLTask output = MLTask.builder().taskId(taskId).modelId(modelId).state(MLTaskState.COMPLETED).async(false).build();
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).getTask(any(), any());
}).when(machineLearningNodeClient).getTask(any(), nullable(String.class), any());

doAnswer(invocation -> {
ActionListener<WorkflowData> updateResponseListener = invocation.getArgument(5);
Expand All @@ -172,7 +173,7 @@ public void testRegisterLocalCustomModelSuccess() throws Exception {
future.actionGet();

verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any());
verify(machineLearningNodeClient, times(1)).getTask(any(), any());
verify(machineLearningNodeClient, times(1)).getTask(any(), nullable(String.class), any());

assertEquals(modelId, future.get().getContent().get(MODEL_ID));
assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.nullable;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
Expand Down Expand Up @@ -144,7 +145,7 @@ public void testRegisterLocalPretrainedModelSuccess() throws Exception {
MLTask output = MLTask.builder().taskId(taskId).modelId(modelId).state(MLTaskState.COMPLETED).async(false).build();
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).getTask(any(), any());
}).when(machineLearningNodeClient).getTask(any(), nullable(String.class), any());

doAnswer(invocation -> {
ActionListener<WorkflowData> updateResponseListener = invocation.getArgument(5);
Expand All @@ -165,7 +166,7 @@ public void testRegisterLocalPretrainedModelSuccess() throws Exception {
future.actionGet();

verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any());
verify(machineLearningNodeClient, times(1)).getTask(any(), any());
verify(machineLearningNodeClient, times(1)).getTask(any(), nullable(String.class), any());

assertEquals(modelId, future.get().getContent().get(MODEL_ID));
assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.nullable;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
Expand Down Expand Up @@ -147,7 +148,7 @@ public void testRegisterLocalSparseEncodingModelSuccess() throws Exception {
MLTask output = MLTask.builder().taskId(taskId).modelId(modelId).state(MLTaskState.COMPLETED).async(false).build();
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).getTask(any(), any());
}).when(machineLearningNodeClient).getTask(any(), nullable(String.class), any());

doAnswer(invocation -> {
ActionListener<WorkflowData> updateResponseListener = invocation.getArgument(5);
Expand All @@ -168,7 +169,7 @@ public void testRegisterLocalSparseEncodingModelSuccess() throws Exception {
future.actionGet();

verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any());
verify(machineLearningNodeClient, times(1)).getTask(any(), any());
verify(machineLearningNodeClient, times(1)).getTask(any(), nullable(String.class), any());

assertEquals(modelId, future.get().getContent().get(MODEL_ID));
assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS));
Expand Down

0 comments on commit 28b57f3

Please sign in to comment.