diff --git a/CHANGELOG.md b/CHANGELOG.md index d0058d3f1739e..3658fa04699ba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Removed ### Fixed +- Fix illegal argument exception when creating a PIT ([#16781](https://github.com/opensearch-project/OpenSearch/pull/16781)) ### Security diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhase.java b/server/src/main/java/org/opensearch/action/search/SearchPhase.java index 0890e9f5de8d4..351c23fec3d80 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchPhase.java +++ b/server/src/main/java/org/opensearch/action/search/SearchPhase.java @@ -37,6 +37,7 @@ import java.io.IOException; import java.util.Locale; import java.util.Objects; +import java.util.Optional; /** * Base class for all individual search phases like collecting distributed frequencies, fetching documents, querying shards. @@ -69,11 +70,26 @@ public String getName() { } /** - * Returns the SearchPhase name as {@link SearchPhaseName}. Exception will come if SearchPhase name is not defined + * Returns the SearchPhase name as {@link SearchPhaseName}. Exception will come if SearchPhase name is not defined. + * @deprecated Use getSearchPhaseNameOptional() to avoid possible exceptions. * in {@link SearchPhaseName} * @return {@link SearchPhaseName} */ + @Deprecated public SearchPhaseName getSearchPhaseName() { return SearchPhaseName.valueOf(name.toUpperCase(Locale.ROOT)); } + + /** + * Returns an Optional of the SearchPhase name as {@link SearchPhaseName}. If there's not a matching SearchPhaseName, + * returns an empty Optional. + * @return {@link Optional} + */ + public Optional getSearchPhaseNameOptional() { + try { + return Optional.of(SearchPhaseName.valueOf(name.toUpperCase(Locale.ROOT))); + } catch (IllegalArgumentException e) { + return Optional.empty(); + } + } } diff --git a/server/src/main/java/org/opensearch/action/search/SearchRequestStats.java b/server/src/main/java/org/opensearch/action/search/SearchRequestStats.java index 94200d29a4f21..dd3b6838ab5da 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchRequestStats.java +++ b/server/src/main/java/org/opensearch/action/search/SearchRequestStats.java @@ -73,20 +73,22 @@ public long getTookMetric() { @Override protected void onPhaseStart(SearchPhaseContext context) { - phaseStatsMap.get(context.getCurrentPhase().getSearchPhaseName()).current.inc(); + context.getCurrentPhase().getSearchPhaseNameOptional().ifPresent(name -> phaseStatsMap.get(name).current.inc()); } @Override protected void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) { - StatsHolder phaseStats = phaseStatsMap.get(context.getCurrentPhase().getSearchPhaseName()); - phaseStats.current.dec(); - phaseStats.total.inc(); - phaseStats.timing.inc(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - context.getCurrentPhase().getStartTimeInNanos())); + context.getCurrentPhase().getSearchPhaseNameOptional().ifPresent(name -> { + StatsHolder phaseStats = phaseStatsMap.get(name); + phaseStats.current.dec(); + phaseStats.total.inc(); + phaseStats.timing.inc(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - context.getCurrentPhase().getStartTimeInNanos())); + }); } @Override protected void onPhaseFailure(SearchPhaseContext context, Throwable cause) { - phaseStatsMap.get(context.getCurrentPhase().getSearchPhaseName()).current.dec(); + context.getCurrentPhase().getSearchPhaseNameOptional().ifPresent(name -> phaseStatsMap.get(name).current.dec()); } @Override diff --git a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java index 27336e86e52b0..b0fab3b7a3556 100644 --- a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java @@ -399,29 +399,29 @@ public void testOnPhaseFailureAndVerifyListeners() { final List requestOperationListeners = List.of(testListener, assertingListener); SearchQueryThenFetchAsyncAction action = createSearchQueryThenFetchAsyncAction(requestOperationListeners); action.start(); - assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName())); + assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseNameOptional().get())); action.onPhaseFailure(new SearchPhase("test") { @Override public void run() { } }, "message", null); - assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName())); - assertEquals(0, testListener.getPhaseTotal(action.getSearchPhaseName())); + assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseNameOptional().get())); + assertEquals(0, testListener.getPhaseTotal(action.getSearchPhaseNameOptional().get())); SearchDfsQueryThenFetchAsyncAction searchDfsQueryThenFetchAsyncAction = createSearchDfsQueryThenFetchAsyncAction( requestOperationListeners ); searchDfsQueryThenFetchAsyncAction.start(); - assertEquals(1, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); + assertEquals(1, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseNameOptional().get())); searchDfsQueryThenFetchAsyncAction.onPhaseFailure(new SearchPhase("test") { @Override public void run() { } }, "message", null); - assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName())); - assertEquals(0, testListener.getPhaseTotal(action.getSearchPhaseName())); + assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseNameOptional().get())); + assertEquals(0, testListener.getPhaseTotal(action.getSearchPhaseNameOptional().get())); FetchSearchPhase fetchPhase = createFetchSearchPhase(); ShardId shardId = new ShardId(randomAlphaOfLengthBetween(5, 10), randomAlphaOfLength(10), randomInt()); @@ -430,15 +430,15 @@ public void run() { action.skipShard(searchShardIterator); action.start(); action.executeNextPhase(action, fetchPhase); - assertEquals(1, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName())); + assertEquals(1, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseNameOptional().get())); action.onPhaseFailure(new SearchPhase("test") { @Override public void run() { } }, "message", null); - assertEquals(0, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName())); - assertEquals(0, testListener.getPhaseTotal(fetchPhase.getSearchPhaseName())); + assertEquals(0, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseNameOptional().get())); + assertEquals(0, testListener.getPhaseTotal(fetchPhase.getSearchPhaseNameOptional().get())); } public void testOnPhaseFailure() { @@ -722,7 +722,7 @@ public void testOnPhaseListenersWithQueryAndThenFetchType() throws InterruptedEx action.start(); // Verify queryPhase current metric - assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName())); + assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseNameOptional().get())); TimeUnit.MILLISECONDS.sleep(delay); FetchSearchPhase fetchPhase = createFetchSearchPhase(); @@ -733,12 +733,12 @@ public void testOnPhaseListenersWithQueryAndThenFetchType() throws InterruptedEx action.executeNextPhase(action, fetchPhase); // Verify queryPhase total, current and latency metrics - assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName())); - assertThat(testListener.getPhaseMetric(action.getSearchPhaseName()), greaterThanOrEqualTo(delay)); - assertEquals(1, testListener.getPhaseTotal(action.getSearchPhaseName())); + assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseNameOptional().get())); + assertThat(testListener.getPhaseMetric(action.getSearchPhaseNameOptional().get()), greaterThanOrEqualTo(delay)); + assertEquals(1, testListener.getPhaseTotal(action.getSearchPhaseNameOptional().get())); // Verify fetchPhase current metric - assertEquals(1, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName())); + assertEquals(1, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseNameOptional().get())); TimeUnit.MILLISECONDS.sleep(delay); ExpandSearchPhase expandPhase = createExpandSearchPhase(); @@ -746,18 +746,18 @@ public void testOnPhaseListenersWithQueryAndThenFetchType() throws InterruptedEx TimeUnit.MILLISECONDS.sleep(delay); // Verify fetchPhase total, current and latency metrics - assertThat(testListener.getPhaseMetric(fetchPhase.getSearchPhaseName()), greaterThanOrEqualTo(delay)); - assertEquals(1, testListener.getPhaseTotal(fetchPhase.getSearchPhaseName())); - assertEquals(0, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName())); + assertThat(testListener.getPhaseMetric(fetchPhase.getSearchPhaseNameOptional().get()), greaterThanOrEqualTo(delay)); + assertEquals(1, testListener.getPhaseTotal(fetchPhase.getSearchPhaseNameOptional().get())); + assertEquals(0, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseNameOptional().get())); - assertEquals(1, testListener.getPhaseCurrent(expandPhase.getSearchPhaseName())); + assertEquals(1, testListener.getPhaseCurrent(expandPhase.getSearchPhaseNameOptional().get())); action.executeNextPhase(expandPhase, fetchPhase); action.onPhaseDone(); /* finish phase since we don't have reponse being sent */ - assertThat(testListener.getPhaseMetric(expandPhase.getSearchPhaseName()), greaterThanOrEqualTo(delay)); - assertEquals(1, testListener.getPhaseTotal(expandPhase.getSearchPhaseName())); - assertEquals(0, testListener.getPhaseCurrent(expandPhase.getSearchPhaseName())); + assertThat(testListener.getPhaseMetric(expandPhase.getSearchPhaseNameOptional().get()), greaterThanOrEqualTo(delay)); + assertEquals(1, testListener.getPhaseTotal(expandPhase.getSearchPhaseNameOptional().get())); + assertEquals(0, testListener.getPhaseCurrent(expandPhase.getSearchPhaseNameOptional().get())); } public void testOnPhaseListenersWithDfsType() throws InterruptedException { @@ -772,7 +772,7 @@ public void testOnPhaseListenersWithDfsType() throws InterruptedException { FetchSearchPhase fetchPhase = createFetchSearchPhase(); searchDfsQueryThenFetchAsyncAction.start(); - assertEquals(1, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); + assertEquals(1, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseNameOptional().get())); TimeUnit.MILLISECONDS.sleep(delay); ShardId shardId = new ShardId(randomAlphaOfLengthBetween(5, 10), randomAlphaOfLength(10), randomInt()); SearchShardIterator searchShardIterator = new SearchShardIterator(null, shardId, Collections.emptyList(), OriginalIndices.NONE); @@ -786,9 +786,12 @@ public void testOnPhaseListenersWithDfsType() throws InterruptedException { null ); /* finalizing the fetch phase since we do adhoc phase lifecycle calls */ - assertThat(testListener.getPhaseMetric(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName()), greaterThanOrEqualTo(delay)); - assertEquals(1, testListener.getPhaseTotal(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); - assertEquals(0, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); + assertThat( + testListener.getPhaseMetric(searchDfsQueryThenFetchAsyncAction.getSearchPhaseNameOptional().get()), + greaterThanOrEqualTo(delay) + ); + assertEquals(1, testListener.getPhaseTotal(searchDfsQueryThenFetchAsyncAction.getSearchPhaseNameOptional().get())); + assertEquals(0, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseNameOptional().get())); } private SearchDfsQueryThenFetchAsyncAction createSearchDfsQueryThenFetchAsyncAction( diff --git a/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerTests.java b/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerTests.java index 990ed95f1aebc..29561e938bf6c 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerTests.java @@ -14,6 +14,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -30,18 +31,18 @@ public void testListenersAreExecuted() { @Override public void onPhaseStart(SearchPhaseContext context) { - searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseName()).current.inc(); + searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseNameOptional().get()).current.inc(); } @Override public void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) { - searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseName()).current.dec(); - searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseName()).total.inc(); + searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseNameOptional().get()).current.dec(); + searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseNameOptional().get()).total.inc(); } @Override public void onPhaseFailure(SearchPhaseContext context, Throwable cause) { - searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseName()).current.dec(); + searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseNameOptional().get()).current.dec(); } }; @@ -61,7 +62,7 @@ public void onPhaseFailure(SearchPhaseContext context, Throwable cause) { for (SearchPhaseName searchPhaseName : SearchPhaseName.values()) { when(ctx.getCurrentPhase()).thenReturn(searchPhase); - when(searchPhase.getSearchPhaseName()).thenReturn(searchPhaseName); + when(searchPhase.getSearchPhaseNameOptional()).thenReturn(Optional.of(searchPhaseName)); compositeListener.onPhaseStart(ctx); assertEquals(totalListeners, searchPhaseMap.get(searchPhaseName).current.count()); } diff --git a/server/src/test/java/org/opensearch/action/search/SearchRequestStatsTests.java b/server/src/test/java/org/opensearch/action/search/SearchRequestStatsTests.java index 3bad3ec3e7d21..7c2a3435afd6d 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchRequestStatsTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchRequestStatsTests.java @@ -16,6 +16,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Phaser; import java.util.concurrent.TimeUnit; @@ -68,7 +69,7 @@ public void testSearchRequestPhaseFailure() { when(ctx.getCurrentPhase()).thenReturn(mockSearchPhase); for (SearchPhaseName searchPhaseName : SearchPhaseName.values()) { - when(mockSearchPhase.getSearchPhaseName()).thenReturn(searchPhaseName); + when(mockSearchPhase.getSearchPhaseNameOptional()).thenReturn(Optional.of(searchPhaseName)); testRequestStats.onPhaseStart(ctx); assertEquals(1, testRequestStats.getPhaseCurrent(searchPhaseName)); testRequestStats.onPhaseFailure(ctx, new Throwable()); @@ -85,7 +86,7 @@ public void testSearchRequestStats() { when(ctx.getCurrentPhase()).thenReturn(mockSearchPhase); for (SearchPhaseName searchPhaseName : SearchPhaseName.values()) { - when(mockSearchPhase.getSearchPhaseName()).thenReturn(searchPhaseName); + when(mockSearchPhase.getSearchPhaseNameOptional()).thenReturn(Optional.of(searchPhaseName)); long tookTimeInMillis = randomIntBetween(1, 10); testRequestStats.onPhaseStart(ctx); long startTime = System.nanoTime() - TimeUnit.MILLISECONDS.toNanos(tookTimeInMillis); @@ -116,7 +117,7 @@ public void testSearchRequestStatsOnPhaseStartConcurrently() throws InterruptedE SearchPhaseContext ctx = mock(SearchPhaseContext.class); SearchPhase mockSearchPhase = mock(SearchPhase.class); when(ctx.getCurrentPhase()).thenReturn(mockSearchPhase); - when(mockSearchPhase.getSearchPhaseName()).thenReturn(searchPhaseName); + when(mockSearchPhase.getSearchPhaseNameOptional()).thenReturn(Optional.of(searchPhaseName)); for (int i = 0; i < numTasks; i++) { threads[i] = new Thread(() -> { phaser.arriveAndAwaitAdvance(); @@ -145,7 +146,7 @@ public void testSearchRequestStatsOnPhaseEndConcurrently() throws InterruptedExc SearchPhaseContext ctx = mock(SearchPhaseContext.class); SearchPhase mockSearchPhase = mock(SearchPhase.class); when(ctx.getCurrentPhase()).thenReturn(mockSearchPhase); - when(mockSearchPhase.getSearchPhaseName()).thenReturn(searchPhaseName); + when(mockSearchPhase.getSearchPhaseNameOptional()).thenReturn(Optional.of(searchPhaseName)); long tookTimeInMillis = randomIntBetween(1, 10); long startTime = System.nanoTime() - TimeUnit.MILLISECONDS.toNanos(tookTimeInMillis); when(mockSearchPhase.getStartTimeInNanos()).thenReturn(startTime); @@ -188,7 +189,7 @@ public void testSearchRequestStatsOnPhaseFailureConcurrently() throws Interrupte SearchPhaseContext ctx = mock(SearchPhaseContext.class); SearchPhase mockSearchPhase = mock(SearchPhase.class); when(ctx.getCurrentPhase()).thenReturn(mockSearchPhase); - when(mockSearchPhase.getSearchPhaseName()).thenReturn(searchPhaseName); + when(mockSearchPhase.getSearchPhaseNameOptional()).thenReturn(Optional.of(searchPhaseName)); for (int i = 0; i < numTasks; i++) { threads[i] = new Thread(() -> { phaser.arriveAndAwaitAdvance(); @@ -205,4 +206,51 @@ public void testSearchRequestStatsOnPhaseFailureConcurrently() throws Interrupte assertEquals(0, testRequestStats.getPhaseCurrent(searchPhaseName)); } } + + public void testUnrecognizedPhaseNamesAreIgnored() { + // Unrecognized phase names producing an empty optional should not throw any error and no stats should be incremented. + ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + SearchRequestStats testRequestStats = new SearchRequestStats(clusterSettings); + SearchPhaseContext ctx = mock(SearchPhaseContext.class); + SearchPhase mockSearchPhase = mock(SearchPhase.class); + when(ctx.getCurrentPhase()).thenReturn(mockSearchPhase); + + when(mockSearchPhase.getSearchPhaseNameOptional()).thenReturn(Optional.empty()); + testRequestStats.onPhaseStart(ctx); + int minTimeNanos = 10; + long startTime = System.nanoTime() - TimeUnit.MILLISECONDS.toNanos(minTimeNanos); + when(mockSearchPhase.getStartTimeInNanos()).thenReturn(startTime); + + for (SearchPhaseName name : SearchPhaseName.values()) { + assertEquals(0, testRequestStats.getPhaseCurrent(name)); + } + + testRequestStats.onPhaseEnd( + ctx, + new SearchRequestContext( + new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), + new SearchRequest(), + () -> null + ) + ); + + for (SearchPhaseName name : SearchPhaseName.values()) { + assertEquals(0, testRequestStats.getPhaseCurrent(name)); + assertEquals(0, testRequestStats.getPhaseTotal(name)); + assertEquals(0, testRequestStats.getPhaseMetric(name)); + } + } + + public void testUnrecognizedSearchPhaseReturnsEmptyOptional() { + // Test search phases with unrecognized names return Optional.empty() when getSearchPhaseNameOptional() is called. + // These may exist, for example, "create_pit". + String unrecognizedName = "unrecognized_name"; + SearchPhase dummyPhase = new SearchPhase(unrecognizedName) { + @Override + public void run() {} + }; + + assertEquals(unrecognizedName, dummyPhase.getName()); + assertEquals(Optional.empty(), dummyPhase.getSearchPhaseNameOptional()); + } } diff --git a/server/src/test/java/org/opensearch/index/search/stats/SearchStatsTests.java b/server/src/test/java/org/opensearch/index/search/stats/SearchStatsTests.java index 594700ea60b3e..65e8997d75403 100644 --- a/server/src/test/java/org/opensearch/index/search/stats/SearchStatsTests.java +++ b/server/src/test/java/org/opensearch/index/search/stats/SearchStatsTests.java @@ -44,6 +44,7 @@ import java.util.HashMap; import java.util.Map; +import java.util.Optional; import java.util.concurrent.TimeUnit; import static org.hamcrest.Matchers.greaterThanOrEqualTo; @@ -86,7 +87,7 @@ public void testShardLevelSearchGroupStats() throws Exception { SearchPhase mockSearchPhase = mock(SearchPhase.class); when(ctx.getCurrentPhase()).thenReturn(mockSearchPhase); when(mockSearchPhase.getStartTimeInNanos()).thenReturn(System.nanoTime() - TimeUnit.SECONDS.toNanos(paramValue)); - when(mockSearchPhase.getSearchPhaseName()).thenReturn(searchPhaseName); + when(mockSearchPhase.getSearchPhaseNameOptional()).thenReturn(Optional.of(searchPhaseName)); for (int iterator = 0; iterator < paramValue; iterator++) { onPhaseStart(testRequestStats, ctx); onPhaseEnd(testRequestStats, ctx);