From 63acc3c31a92726d91e0d1bb515823b59a7ef26a Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Thu, 4 Jan 2024 14:09:43 -0800 Subject: [PATCH] Fix TrainingJobRouteDecisionInfo test Recently, we have seen that TrainingJobRouteDecisionInfoTransportActionTests has been having failures on Windows. The failures are related to an unintialized cluster state. This does not have anything to do with the test itself. Most likely, it is the result of state dependence that happens with KNNSingleNodeTestCase. This change refactors the class to use mocks and a lighter weight base class, KNNTestCase. Signed-off-by: John Mazanec --- ...RouteDecisionInfoTransportActionTests.java | 120 ++++++------------ 1 file changed, 36 insertions(+), 84 deletions(-) diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoTransportActionTests.java index 9f64afebb..d07be2070 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoTransportActionTests.java @@ -11,98 +11,50 @@ package org.opensearch.knn.plugin.transport; -import org.junit.After; -import org.junit.Before; -import org.opensearch.core.action.ActionListener; -import org.opensearch.action.index.IndexResponse; -import org.opensearch.core.index.shard.ShardId; -import org.opensearch.knn.KNNSingleNodeTestCase; -import org.opensearch.knn.indices.Model; -import org.opensearch.knn.indices.ModelDao; -import org.opensearch.knn.indices.ModelMetadata; -import org.opensearch.knn.indices.ModelState; -import org.opensearch.knn.training.TrainingJob; +import org.mockito.MockedStatic; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.training.TrainingJobRunner; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; -import java.io.IOException; -import java.util.concurrent.*; +import java.util.Collections; import static org.mockito.Mockito.any; -import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.when; -import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; -import static org.opensearch.knn.common.KNNConstants.TRAIN_THREAD_POOL; - -public class TrainingJobRouteDecisionInfoTransportActionTests extends KNNSingleNodeTestCase { - - ExecutorService executorService; - - @Before - public void setup() { - executorService = Executors.newSingleThreadExecutor(); - } - - @After - public void teardown() { - executorService.shutdown(); - } - - @SuppressWarnings("unchecked") - public void testNodeOperation() throws IOException, InterruptedException, ExecutionException { - // Ensure initial value of train job count is 0 - TrainingJobRouteDecisionInfoTransportAction action = node().injector() - .getInstance(TrainingJobRouteDecisionInfoTransportAction.class); - - TrainingJobRouteDecisionInfoNodeRequest request = new TrainingJobRouteDecisionInfoNodeRequest(); - - TrainingJobRouteDecisionInfoNodeResponse response1 = action.nodeOperation(request); - assertEquals(0, response1.getTrainingJobCount().intValue()); - - // Setup mocked training job - String modelId = "model-id"; - Model model = mock(Model.class); - ModelMetadata modelMetadata = mock(ModelMetadata.class); - when(modelMetadata.getState()).thenReturn(ModelState.TRAINING); - when(model.getModelMetadata()).thenReturn(modelMetadata); - TrainingJob trainingJob = mock(TrainingJob.class); - when(trainingJob.getModelId()).thenReturn(modelId); - when(trainingJob.getModel()).thenReturn(model); - doAnswer(invocationOnMock -> null).when(trainingJob).run(); - - ModelDao modelDao = mock(ModelDao.class); - when(modelDao.get(modelId)).thenReturn(model); - - // Here we check to make sure there is a running job - doAnswer(invocationOnMock -> { - TrainingJobRouteDecisionInfoNodeResponse response2 = action.nodeOperation(request); - assertEquals(1, response2.getTrainingJobCount().intValue()); - - IndexResponse indexResponse = new IndexResponse(new ShardId(MODEL_INDEX_NAME, "uuid", 0), modelId, 0, 0, 0, true); - ((ActionListener) invocationOnMock.getArguments()[1]).onResponse(indexResponse); - return null; - }).when(modelDao).put(any(Model.class), any(ActionListener.class)); - - // Set up the rest of the training logic - final CountDownLatch inProgressLatch = new CountDownLatch(1); - ActionListener responseListener = ActionListener.wrap( - indexResponse -> { inProgressLatch.countDown(); }, - e -> fail("Failure should not have occurred") - ); - - doAnswer(invocationOnMock -> { - responseListener.onResponse(mock(IndexResponse.class)); - return null; - }).when(modelDao).update(model, responseListener); +public class TrainingJobRouteDecisionInfoTransportActionTests extends KNNTestCase { + public void testNodeOperation() { + // Initialize mocked variables for the class + DiscoveryNode node = mock(DiscoveryNode.class); + when(clusterService.localNode()).thenReturn(node); ThreadPool threadPool = mock(ThreadPool.class); - when(threadPool.executor(TRAIN_THREAD_POOL)).thenReturn(executorService); - - // Initialize runner and execute job - TrainingJobRunner.initialize(threadPool, modelDao); - TrainingJobRunner.getInstance().execute(trainingJob, responseListener); - - assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + TransportService transportService = mock(TransportService.class); + doNothing().when(transportService).registerRequestHandler(any(), any(), any(), any()); + ActionFilters actionFilters = new ActionFilters(Collections.emptySet()); + + TrainingJobRouteDecisionInfoTransportAction trainingJobRouteDecisionInfoTransportAction = + new TrainingJobRouteDecisionInfoTransportAction(threadPool, clusterService, transportService, actionFilters); + + try (MockedStatic mockedTrainingJobRunnerStatic = mockStatic(TrainingJobRunner.class)) { + // Ensure the job count is correct + int initialJobCount = 4; + final TrainingJobRunner mockedTrainingJobRunner = mock(TrainingJobRunner.class); + when(mockedTrainingJobRunner.getJobCount()).thenReturn(initialJobCount); + mockedTrainingJobRunnerStatic.when(TrainingJobRunner::getInstance).thenReturn(mockedTrainingJobRunner); + + TrainingJobRouteDecisionInfoNodeRequest request = new TrainingJobRouteDecisionInfoNodeRequest(); + TrainingJobRouteDecisionInfoNodeResponse response = trainingJobRouteDecisionInfoTransportAction.nodeOperation(request); + assertEquals(initialJobCount, response.getTrainingJobCount().intValue()); + + int resetJobCount = 0; + when(mockedTrainingJobRunner.getJobCount()).thenReturn(resetJobCount); + response = trainingJobRouteDecisionInfoTransportAction.nodeOperation(request); + assertEquals(resetJobCount, response.getTrainingJobCount().intValue()); + } } }