diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoTransportActionTests.java index 9f64afebb..d07be2070 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoTransportActionTests.java @@ -11,98 +11,50 @@ package org.opensearch.knn.plugin.transport; -import org.junit.After; -import org.junit.Before; -import org.opensearch.core.action.ActionListener; -import org.opensearch.action.index.IndexResponse; -import org.opensearch.core.index.shard.ShardId; -import org.opensearch.knn.KNNSingleNodeTestCase; -import org.opensearch.knn.indices.Model; -import org.opensearch.knn.indices.ModelDao; -import org.opensearch.knn.indices.ModelMetadata; -import org.opensearch.knn.indices.ModelState; -import org.opensearch.knn.training.TrainingJob; +import org.mockito.MockedStatic; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.training.TrainingJobRunner; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; -import java.io.IOException; -import java.util.concurrent.*; +import java.util.Collections; import static org.mockito.Mockito.any; -import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.when; -import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; -import static org.opensearch.knn.common.KNNConstants.TRAIN_THREAD_POOL; - -public class TrainingJobRouteDecisionInfoTransportActionTests extends KNNSingleNodeTestCase { - - ExecutorService executorService; - - @Before - public void setup() { - executorService = Executors.newSingleThreadExecutor(); - } - - @After - public void teardown() { - executorService.shutdown(); - } - - @SuppressWarnings("unchecked") - public void testNodeOperation() throws IOException, InterruptedException, ExecutionException { - // Ensure initial value of train job count is 0 - TrainingJobRouteDecisionInfoTransportAction action = node().injector() - .getInstance(TrainingJobRouteDecisionInfoTransportAction.class); - - TrainingJobRouteDecisionInfoNodeRequest request = new TrainingJobRouteDecisionInfoNodeRequest(); - - TrainingJobRouteDecisionInfoNodeResponse response1 = action.nodeOperation(request); - assertEquals(0, response1.getTrainingJobCount().intValue()); - - // Setup mocked training job - String modelId = "model-id"; - Model model = mock(Model.class); - ModelMetadata modelMetadata = mock(ModelMetadata.class); - when(modelMetadata.getState()).thenReturn(ModelState.TRAINING); - when(model.getModelMetadata()).thenReturn(modelMetadata); - TrainingJob trainingJob = mock(TrainingJob.class); - when(trainingJob.getModelId()).thenReturn(modelId); - when(trainingJob.getModel()).thenReturn(model); - doAnswer(invocationOnMock -> null).when(trainingJob).run(); - - ModelDao modelDao = mock(ModelDao.class); - when(modelDao.get(modelId)).thenReturn(model); - - // Here we check to make sure there is a running job - doAnswer(invocationOnMock -> { - TrainingJobRouteDecisionInfoNodeResponse response2 = action.nodeOperation(request); - assertEquals(1, response2.getTrainingJobCount().intValue()); - - IndexResponse indexResponse = new IndexResponse(new ShardId(MODEL_INDEX_NAME, "uuid", 0), modelId, 0, 0, 0, true); - ((ActionListener) invocationOnMock.getArguments()[1]).onResponse(indexResponse); - return null; - }).when(modelDao).put(any(Model.class), any(ActionListener.class)); - - // Set up the rest of the training logic - final CountDownLatch inProgressLatch = new CountDownLatch(1); - ActionListener responseListener = ActionListener.wrap( - indexResponse -> { inProgressLatch.countDown(); }, - e -> fail("Failure should not have occurred") - ); - - doAnswer(invocationOnMock -> { - responseListener.onResponse(mock(IndexResponse.class)); - return null; - }).when(modelDao).update(model, responseListener); +public class TrainingJobRouteDecisionInfoTransportActionTests extends KNNTestCase { + public void testNodeOperation() { + // Initialize mocked variables for the class + DiscoveryNode node = mock(DiscoveryNode.class); + when(clusterService.localNode()).thenReturn(node); ThreadPool threadPool = mock(ThreadPool.class); - when(threadPool.executor(TRAIN_THREAD_POOL)).thenReturn(executorService); - - // Initialize runner and execute job - TrainingJobRunner.initialize(threadPool, modelDao); - TrainingJobRunner.getInstance().execute(trainingJob, responseListener); - - assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + TransportService transportService = mock(TransportService.class); + doNothing().when(transportService).registerRequestHandler(any(), any(), any(), any()); + ActionFilters actionFilters = new ActionFilters(Collections.emptySet()); + + TrainingJobRouteDecisionInfoTransportAction trainingJobRouteDecisionInfoTransportAction = + new TrainingJobRouteDecisionInfoTransportAction(threadPool, clusterService, transportService, actionFilters); + + try (MockedStatic mockedTrainingJobRunnerStatic = mockStatic(TrainingJobRunner.class)) { + // Ensure the job count is correct + int initialJobCount = 4; + final TrainingJobRunner mockedTrainingJobRunner = mock(TrainingJobRunner.class); + when(mockedTrainingJobRunner.getJobCount()).thenReturn(initialJobCount); + mockedTrainingJobRunnerStatic.when(TrainingJobRunner::getInstance).thenReturn(mockedTrainingJobRunner); + + TrainingJobRouteDecisionInfoNodeRequest request = new TrainingJobRouteDecisionInfoNodeRequest(); + TrainingJobRouteDecisionInfoNodeResponse response = trainingJobRouteDecisionInfoTransportAction.nodeOperation(request); + assertEquals(initialJobCount, response.getTrainingJobCount().intValue()); + + int resetJobCount = 0; + when(mockedTrainingJobRunner.getJobCount()).thenReturn(resetJobCount); + response = trainingJobRouteDecisionInfoTransportAction.nodeOperation(request); + assertEquals(resetJobCount, response.getTrainingJobCount().intValue()); + } } }