Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Refactor TrainingJobRouteDecisionInfo test to use KNNTestCase #1375

Merged
merged 1 commit into from
Jan 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,98 +11,50 @@

package org.opensearch.knn.plugin.transport;

import org.junit.After;
import org.junit.Before;
import org.opensearch.core.action.ActionListener;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.knn.KNNSingleNodeTestCase;
import org.opensearch.knn.indices.Model;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;
import org.opensearch.knn.training.TrainingJob;
import org.mockito.MockedStatic;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.training.TrainingJobRunner;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

import java.io.IOException;
import java.util.concurrent.*;
import java.util.Collections;

import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.when;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME;
import static org.opensearch.knn.common.KNNConstants.TRAIN_THREAD_POOL;

public class TrainingJobRouteDecisionInfoTransportActionTests extends KNNSingleNodeTestCase {

ExecutorService executorService;

@Before
public void setup() {
executorService = Executors.newSingleThreadExecutor();
}

@After
public void teardown() {
executorService.shutdown();
}

@SuppressWarnings("unchecked")
public void testNodeOperation() throws IOException, InterruptedException, ExecutionException {
// Ensure initial value of train job count is 0
TrainingJobRouteDecisionInfoTransportAction action = node().injector()
.getInstance(TrainingJobRouteDecisionInfoTransportAction.class);

TrainingJobRouteDecisionInfoNodeRequest request = new TrainingJobRouteDecisionInfoNodeRequest();

TrainingJobRouteDecisionInfoNodeResponse response1 = action.nodeOperation(request);
assertEquals(0, response1.getTrainingJobCount().intValue());

// Setup mocked training job
String modelId = "model-id";
Model model = mock(Model.class);
ModelMetadata modelMetadata = mock(ModelMetadata.class);
when(modelMetadata.getState()).thenReturn(ModelState.TRAINING);
when(model.getModelMetadata()).thenReturn(modelMetadata);
TrainingJob trainingJob = mock(TrainingJob.class);
when(trainingJob.getModelId()).thenReturn(modelId);
when(trainingJob.getModel()).thenReturn(model);
doAnswer(invocationOnMock -> null).when(trainingJob).run();

ModelDao modelDao = mock(ModelDao.class);
when(modelDao.get(modelId)).thenReturn(model);

// Here we check to make sure there is a running job
doAnswer(invocationOnMock -> {
TrainingJobRouteDecisionInfoNodeResponse response2 = action.nodeOperation(request);
assertEquals(1, response2.getTrainingJobCount().intValue());

IndexResponse indexResponse = new IndexResponse(new ShardId(MODEL_INDEX_NAME, "uuid", 0), modelId, 0, 0, 0, true);
((ActionListener<IndexResponse>) invocationOnMock.getArguments()[1]).onResponse(indexResponse);
return null;
}).when(modelDao).put(any(Model.class), any(ActionListener.class));

// Set up the rest of the training logic
final CountDownLatch inProgressLatch = new CountDownLatch(1);
ActionListener<IndexResponse> responseListener = ActionListener.wrap(
indexResponse -> { inProgressLatch.countDown(); },
e -> fail("Failure should not have occurred")
);

doAnswer(invocationOnMock -> {
responseListener.onResponse(mock(IndexResponse.class));
return null;
}).when(modelDao).update(model, responseListener);

public class TrainingJobRouteDecisionInfoTransportActionTests extends KNNTestCase {
public void testNodeOperation() {
// Initialize mocked variables for the class
DiscoveryNode node = mock(DiscoveryNode.class);
when(clusterService.localNode()).thenReturn(node);
ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.executor(TRAIN_THREAD_POOL)).thenReturn(executorService);

// Initialize runner and execute job
TrainingJobRunner.initialize(threadPool, modelDao);
TrainingJobRunner.getInstance().execute(trainingJob, responseListener);

assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS));
TransportService transportService = mock(TransportService.class);
doNothing().when(transportService).registerRequestHandler(any(), any(), any(), any());
ActionFilters actionFilters = new ActionFilters(Collections.emptySet());

TrainingJobRouteDecisionInfoTransportAction trainingJobRouteDecisionInfoTransportAction =
new TrainingJobRouteDecisionInfoTransportAction(threadPool, clusterService, transportService, actionFilters);

try (MockedStatic<TrainingJobRunner> mockedTrainingJobRunnerStatic = mockStatic(TrainingJobRunner.class)) {
// Ensure the job count is correct
int initialJobCount = 4;
final TrainingJobRunner mockedTrainingJobRunner = mock(TrainingJobRunner.class);
when(mockedTrainingJobRunner.getJobCount()).thenReturn(initialJobCount);
mockedTrainingJobRunnerStatic.when(TrainingJobRunner::getInstance).thenReturn(mockedTrainingJobRunner);

TrainingJobRouteDecisionInfoNodeRequest request = new TrainingJobRouteDecisionInfoNodeRequest();
TrainingJobRouteDecisionInfoNodeResponse response = trainingJobRouteDecisionInfoTransportAction.nodeOperation(request);
assertEquals(initialJobCount, response.getTrainingJobCount().intValue());

int resetJobCount = 0;
when(mockedTrainingJobRunner.getJobCount()).thenReturn(resetJobCount);
response = trainingJobRouteDecisionInfoTransportAction.nodeOperation(request);
assertEquals(resetJobCount, response.getTrainingJobCount().intValue());
}
}
}
Loading