diff --git a/spring-batch-core/src/main/java/org/springframework/batch/core/BatchStatus.java b/spring-batch-core/src/main/java/org/springframework/batch/core/BatchStatus.java index fab23ada7e..1a155d6d96 100644 --- a/spring-batch-core/src/main/java/org/springframework/batch/core/BatchStatus.java +++ b/spring-batch-core/src/main/java/org/springframework/batch/core/BatchStatus.java @@ -16,6 +16,9 @@ package org.springframework.batch.core; +import java.util.Arrays; +import java.util.List; + /** * Enumeration representing the status of an Execution. * @@ -39,6 +42,8 @@ public enum BatchStatus { */ COMPLETED, STARTING, STARTED, STOPPING, STOPPED, FAILED, ABANDONED, UNKNOWN; + public static final List RUNNING_STATUSES = Arrays.asList(STARTING, STARTED); + public static BatchStatus max(BatchStatus status1, BatchStatus status2) { return status1.isGreaterThan(status2) ? status1 : status2; } @@ -49,7 +54,7 @@ public static BatchStatus max(BatchStatus status1, BatchStatus status2) { * @return true if the status is STARTING, STARTED */ public boolean isRunning() { - return this == STARTING || this == STARTED; + return RUNNING_STATUSES.contains(this); } /** diff --git a/spring-batch-core/src/main/java/org/springframework/batch/core/explore/JobExplorer.java b/spring-batch-core/src/main/java/org/springframework/batch/core/explore/JobExplorer.java index 586ad67be1..fa636accbe 100644 --- a/spring-batch-core/src/main/java/org/springframework/batch/core/explore/JobExplorer.java +++ b/spring-batch-core/src/main/java/org/springframework/batch/core/explore/JobExplorer.java @@ -15,9 +15,7 @@ */ package org.springframework.batch.core.explore; -import java.util.List; -import java.util.Set; - +import org.springframework.batch.core.BatchStatus; import org.springframework.batch.core.JobExecution; import org.springframework.batch.core.JobInstance; import org.springframework.batch.core.StepExecution; @@ -25,6 +23,10 @@ import org.springframework.batch.item.ExecutionContext; import org.springframework.lang.Nullable; +import java.util.Collection; +import java.util.List; +import java.util.Set; + /** * Entry point for browsing executions of running or historical jobs and steps. * Since the data may be re-hydrated from persistent storage, it may not contain @@ -89,6 +91,14 @@ default JobInstance getLastJobInstance(String jobName) { @Nullable StepExecution getStepExecution(@Nullable Long jobExecutionId, @Nullable Long stepExecutionId); + /** + * Retrieve number of step executions that match the step execution ids and the batch statuses + * @param stepExecutionIds given step execution ids + * @param matchingBatchStatuses given batch statuses to match against + * @return number of {@link StepExecution} matching the criteria + */ + int getStepExecutionCount(Collection stepExecutionIds, Collection matchingBatchStatuses); + /** * @param instanceId {@link Long} id for the jobInstance to obtain. * @return the {@link JobInstance} with this id, or null @@ -164,4 +174,11 @@ default JobExecution getLastJobExecution(JobInstance jobInstance) { */ int getJobInstanceCount(@Nullable String jobName) throws NoSuchJobException; + /** + * Find step executions in bulk + * @param jobExecutionId given job execution id + * @param stepExecutionIds given step execution ids + * @return collection of {@link StepExecution} + */ + Collection getStepExecutions(Long jobExecutionId, Collection stepExecutionIds); } diff --git a/spring-batch-core/src/main/java/org/springframework/batch/core/explore/support/SimpleJobExplorer.java b/spring-batch-core/src/main/java/org/springframework/batch/core/explore/support/SimpleJobExplorer.java index 81a1d9c333..f3b80c5515 100644 --- a/spring-batch-core/src/main/java/org/springframework/batch/core/explore/support/SimpleJobExplorer.java +++ b/spring-batch-core/src/main/java/org/springframework/batch/core/explore/support/SimpleJobExplorer.java @@ -16,6 +16,7 @@ package org.springframework.batch.core.explore.support; +import org.springframework.batch.core.BatchStatus; import org.springframework.batch.core.JobExecution; import org.springframework.batch.core.JobInstance; import org.springframework.batch.core.StepExecution; @@ -27,6 +28,7 @@ import org.springframework.batch.core.repository.dao.StepExecutionDao; import org.springframework.lang.Nullable; +import java.util.Collection; import java.util.List; import java.util.Set; @@ -165,6 +167,14 @@ public StepExecution getStepExecution(@Nullable Long jobExecutionId, @Nullable L return stepExecution; } + @Override + public int getStepExecutionCount(Collection stepExecutionIds, Collection matchingBatchStatuses) { + if (stepExecutionIds.isEmpty() || matchingBatchStatuses.isEmpty()) { + return 0; + } + return stepExecutionDao.countStepExecutions(stepExecutionIds, matchingBatchStatuses); + } + /* * (non-Javadoc) * @@ -221,6 +231,19 @@ public int getJobInstanceCount(@Nullable String jobName) throws NoSuchJobExcepti return jobInstanceDao.getJobInstanceCount(jobName); } + @Override + @Nullable + public Collection getStepExecutions(Long jobExecutionId, Collection stepExecutionIds) { + JobExecution jobExecution = jobExecutionDao.getJobExecution(jobExecutionId); + if (jobExecution == null) { + return null; + } + getJobExecutionDependencies(jobExecution); + Collection stepExecutions = stepExecutionDao.getStepExecutions(jobExecution, stepExecutionIds); + stepExecutions.forEach(this::getStepExecutionDependencies); + return stepExecutions; + } + /* * Find all dependencies for a JobExecution, including JobInstance (which * requires JobParameters) plus StepExecutions diff --git a/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/JdbcStepExecutionDao.java b/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/JdbcStepExecutionDao.java index d5712fa227..c63c9be6c4 100644 --- a/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/JdbcStepExecutionDao.java +++ b/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/JdbcStepExecutionDao.java @@ -16,25 +16,9 @@ package org.springframework.batch.core.repository.dao; -import java.sql.PreparedStatement; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Timestamp; -import java.sql.Types; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Iterator; -import java.util.List; - import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; - -import org.springframework.batch.core.BatchStatus; -import org.springframework.batch.core.ExitStatus; -import org.springframework.batch.core.JobExecution; -import org.springframework.batch.core.JobInstance; -import org.springframework.batch.core.StepExecution; +import org.springframework.batch.core.*; import org.springframework.beans.factory.InitializingBean; import org.springframework.dao.OptimisticLockingFailureException; import org.springframework.jdbc.core.BatchPreparedStatementSetter; @@ -43,6 +27,11 @@ import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import java.sql.*; +import java.util.*; +import java.util.stream.Collectors; +import java.util.stream.Stream; + /** * JDBC implementation of {@link StepExecutionDao}.
* @@ -114,6 +103,15 @@ public class JdbcStepExecutionDao extends AbstractJdbcBatchMetadataDao implement " and SE.JOB_EXECUTION_ID = JE.JOB_EXECUTION_ID " + " and SE.STEP_NAME = ?"; + // need to replace the %STEP_EXECUTION_IDS% with a known number of ?s + private static final String GET_STEP_EXECUTIONS_BY_IDS = GET_RAW_STEP_EXECUTIONS + " and STEP_EXECUTION_ID IN (%STEP_EXECUTION_IDS%)"; + + // need to replace the %STEP_EXECUTION_IDS% and %STEP_STATUSES% with a known number of ?s + private static final String COUNT_STEP_EXECUTIONS_MATCHING_IDS_AND_STATUSES = "SELECT COUNT(*) " + + "from %PREFIX%STEP_EXECUTION SE " + + "where SE.STEP_EXECUTION_ID IN (%STEP_EXECUTION_IDS%) " + + "and SE.STATUS IN (%STEP_STATUSES%)"; + private int exitMessageLength = DEFAULT_EXIT_MESSAGE_LENGTH; private DataFieldMaxValueIncrementer stepExecutionIncrementer; @@ -350,12 +348,31 @@ public StepExecution getLastStepExecution(JobInstance jobInstance, String stepNa } } + @Override + public Collection getStepExecutions(JobExecution jobExecution, Collection stepExecutionIds) { + String sql = createParameterizedQuery(GET_STEP_EXECUTIONS_BY_IDS, "%STEP_EXECUTION_IDS%", stepExecutionIds); + return getJdbcTemplate().query(getQuery(sql), + new StepExecutionRowMapper(jobExecution), + Stream.concat(Stream.of(jobExecution.getId()), stepExecutionIds.stream()).toArray()); + } + @Override public void addStepExecutions(JobExecution jobExecution) { getJdbcTemplate().query(getQuery(GET_STEP_EXECUTIONS), new StepExecutionRowMapper(jobExecution), jobExecution.getId()); } + @Override + public int countStepExecutions(Collection stepExecutionIds, Collection matchingBatchStatuses) { + String sql = createParameterizedQuery(COUNT_STEP_EXECUTIONS_MATCHING_IDS_AND_STATUSES, "%STEP_EXECUTION_IDS%", stepExecutionIds); + sql = createParameterizedQuery(sql, "%STEP_STATUSES%", matchingBatchStatuses); + Object[] args = Stream.concat(stepExecutionIds.stream(), + matchingBatchStatuses.stream().map(BatchStatus::name)).toArray(); + return getJdbcTemplate().queryForObject(getQuery(sql), + Integer.class, + args); + } + @Override public int countStepExecutions(JobInstance jobInstance, String stepName) { return getJdbcTemplate().queryForObject(getQuery(COUNT_STEP_EXECUTIONS), new Object[] { jobInstance.getInstanceId(), stepName }, Integer.class); @@ -391,4 +408,17 @@ public StepExecution mapRow(ResultSet rs, int rowNum) throws SQLException { } + /** + * Replaces a given placeholder with a number of parameters (i.e. "?"). + * + * @param sqlTemplate given sql template + * @param placeholder placeholder that is being used for parameters + * @param parameters collection of parameters with variable size + * + * @return sql query replaced with a number of parameters + */ + private static String createParameterizedQuery(String sqlTemplate, String placeholder, Collection parameters) { + String params = parameters.stream().map(p -> "?").collect(Collectors.joining(", ")); + return sqlTemplate.replace(placeholder, params); + } } diff --git a/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/MapStepExecutionDao.java b/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/MapStepExecutionDao.java index 2e3bed2466..c74865861d 100644 --- a/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/MapStepExecutionDao.java +++ b/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/MapStepExecutionDao.java @@ -15,26 +15,19 @@ */ package org.springframework.batch.core.repository.dao; -import java.lang.reflect.Field; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.Comparator; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicLong; - -import org.springframework.batch.core.Entity; -import org.springframework.batch.core.JobExecution; -import org.springframework.batch.core.JobInstance; -import org.springframework.batch.core.StepExecution; +import org.springframework.batch.core.*; import org.springframework.dao.OptimisticLockingFailureException; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ReflectionUtils; import org.springframework.util.SerializationUtils; +import java.lang.reflect.Field; +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; +import java.util.stream.Collectors; + /** * In-memory implementation of {@link StepExecutionDao}. * @@ -189,4 +182,23 @@ public int countStepExecutions(JobInstance jobInstance, String stepName) { } return count; } + + @Override + public int countStepExecutions(Collection stepExecutionIds, Collection matchingBatchStatuses) { + int count = 0; + + for (Long id: stepExecutionIds) { + if (executionsByStepExecutionId.containsKey(id) && matchingBatchStatuses.contains(executionsByStepExecutionId.get(id).getStatus())) { + count++; + } + } + return count; + } + + @Override + public Collection getStepExecutions(JobExecution jobExecution, Collection stepExecutionIds) { + return executionsByStepExecutionId.values().stream() + .filter(se -> stepExecutionIds.contains(se.getId()) && se.getJobExecutionId().equals(jobExecution.getId())) + .collect(Collectors.toList()); + } } diff --git a/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/StepExecutionDao.java b/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/StepExecutionDao.java index 107b44e717..3b451ba601 100644 --- a/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/StepExecutionDao.java +++ b/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/StepExecutionDao.java @@ -16,13 +16,14 @@ package org.springframework.batch.core.repository.dao; -import java.util.Collection; - +import org.springframework.batch.core.BatchStatus; import org.springframework.batch.core.JobExecution; import org.springframework.batch.core.JobInstance; import org.springframework.batch.core.StepExecution; import org.springframework.lang.Nullable; +import java.util.Collection; + public interface StepExecutionDao { /** @@ -86,6 +87,22 @@ default StepExecution getLastStepExecution(JobInstance jobInstance, String stepN */ void addStepExecutions(JobExecution jobExecution); + /** + * Count {@link StepExecution} that match the ids and statuses of them - avoid loading them into memory + * @param stepExecutionIds given step execution ids + * @param matchingBatchStatuses + * @return + */ + int countStepExecutions(Collection stepExecutionIds, Collection matchingBatchStatuses); + + /** + * Get a collection of {@link StepExecution} matching job execution and step execution ids. + * @param jobExecution the parent job execution + * @param stepExecutionIds the step execution ids + * @return collection of {@link StepExecution} + */ + @Nullable + Collection getStepExecutions(JobExecution jobExecution, Collection stepExecutionIds); /** * Counts all the {@link StepExecution} for a given step name. * diff --git a/spring-batch-core/src/test/java/org/springframework/batch/core/launch/support/CommandLineJobRunnerTests.java b/spring-batch-core/src/test/java/org/springframework/batch/core/launch/support/CommandLineJobRunnerTests.java index 8b068fc2bf..34e3a0bbb4 100644 --- a/spring-batch-core/src/test/java/org/springframework/batch/core/launch/support/CommandLineJobRunnerTests.java +++ b/spring-batch-core/src/test/java/org/springframework/batch/core/launch/support/CommandLineJobRunnerTests.java @@ -15,28 +15,10 @@ */ package org.springframework.batch.core.launch.support; -import java.io.IOException; -import java.io.InputStream; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Date; -import java.util.HashSet; -import java.util.List; -import java.util.Properties; -import java.util.Set; - import org.junit.After; import org.junit.Before; import org.junit.Test; - -import org.springframework.batch.core.BatchStatus; -import org.springframework.batch.core.ExitStatus; -import org.springframework.batch.core.Job; -import org.springframework.batch.core.JobExecution; -import org.springframework.batch.core.JobInstance; -import org.springframework.batch.core.JobParameters; -import org.springframework.batch.core.JobParametersBuilder; -import org.springframework.batch.core.StepExecution; +import org.springframework.batch.core.*; import org.springframework.batch.core.converter.DefaultJobParametersConverter; import org.springframework.batch.core.converter.JobParametersConverter; import org.springframework.batch.core.explore.JobExplorer; @@ -49,6 +31,10 @@ import org.springframework.lang.Nullable; import org.springframework.util.ClassUtils; +import java.io.IOException; +import java.io.InputStream; +import java.util.*; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -538,6 +524,11 @@ public StepExecution getStepExecution(@Nullable Long jobExecutionId, @Nullable L throw new UnsupportedOperationException(); } + @Override + public int getStepExecutionCount(Collection stepExecutionIds, Collection matchingBatchStatuses) { + throw new UnsupportedOperationException(); + } + @Override public List getJobNames() { throw new UnsupportedOperationException(); @@ -566,6 +557,10 @@ public int getJobInstanceCount(@Nullable String jobName) } } + @Override + public Collection getStepExecutions(Long jobExecutionId, Collection stepExecutionIds) { + throw new UnsupportedOperationException(); + } } public static class StubJobParametersConverter implements JobParametersConverter { diff --git a/spring-batch-integration/src/main/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandler.java b/spring-batch-integration/src/main/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandler.java index bac0462b61..58dc69a228 100644 --- a/spring-batch-integration/src/main/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandler.java +++ b/spring-batch-integration/src/main/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandler.java @@ -1,19 +1,9 @@ package org.springframework.batch.integration.partition; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Iterator; -import java.util.List; -import java.util.Set; -import java.util.concurrent.Callable; -import java.util.concurrent.Future; -import java.util.concurrent.TimeUnit; - -import javax.sql.DataSource; - import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; - +import org.springframework.batch.core.BatchStatus; +import org.springframework.batch.core.Entity; import org.springframework.batch.core.Step; import org.springframework.batch.core.StepExecution; import org.springframework.batch.core.explore.JobExplorer; @@ -37,6 +27,15 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; +import javax.sql.DataSource; +import java.util.Collection; +import java.util.List; +import java.util.Set; +import java.util.concurrent.Callable; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + /** * A {@link PartitionHandler} that uses {@link MessageChannel} instances to send instructions to remote workers and * receive their responses. The {@link MessageChannel} provides a nice abstraction so that the location of the workers @@ -236,49 +235,31 @@ public Collection handle(StepExecutionSplitter stepExecutionSplit } } - private Collection pollReplies(final StepExecution masterStepExecution, final Set split) throws Exception { - final Collection result = new ArrayList<>(split.size()); - - Callable> callback = new Callable>() { - @Override - public Collection call() throws Exception { + private Collection pollReplies(final StepExecution masterStepExecution, final Set split) throws Exception { + Collection ids = split.stream().map(Entity::getId).collect(Collectors.toList()); - for(Iterator stepExecutionIterator = split.iterator(); stepExecutionIterator.hasNext(); ) { - StepExecution curStepExecution = stepExecutionIterator.next(); - - if(!result.contains(curStepExecution)) { - StepExecution partitionStepExecution = - jobExplorer.getStepExecution(masterStepExecution.getJobExecutionId(), curStepExecution.getId()); - - if(!partitionStepExecution.getStatus().isRunning()) { - result.add(partitionStepExecution); - } - } - } + Callable> callback = () -> { + int runningStepExecutions = jobExplorer.getStepExecutionCount(ids, BatchStatus.RUNNING_STATUSES); + if(runningStepExecutions > 0 && split.size() > 0) { if(logger.isDebugEnabled()) { - logger.debug(String.format("Currently waiting on %s partitions to finish", split.size())); - } - - if(result.size() == split.size()) { - return result; - } - else { - return null; + logger.debug(String.format("Currently waiting on %s out of %s partitions to finish", runningStepExecutions, split.size())); } + return null; + } else { + return jobExplorer.getStepExecutions(masterStepExecution.getJobExecutionId(), ids); } }; - Poller> poller = new DirectPoller<>(pollInterval); - Future> resultsFuture = poller.poll(callback); + Poller> poller = new DirectPoller<>(pollInterval); + Future> resultsFuture = poller.poll(callback); - if(timeout >= 0) { - return resultsFuture.get(timeout, TimeUnit.MILLISECONDS); - } - else { - return resultsFuture.get(); - } - } + if(timeout >= 0) { + return resultsFuture.get(timeout, TimeUnit.MILLISECONDS); + } else { + return resultsFuture.get(); + } + } private Collection receiveReplies(PollableChannel currentReplyChannel) { @SuppressWarnings("unchecked") diff --git a/spring-batch-integration/src/test/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandlerTests.java b/spring-batch-integration/src/test/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandlerTests.java index fd9170412f..fa0d83817c 100644 --- a/spring-batch-integration/src/test/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandlerTests.java +++ b/spring-batch-integration/src/test/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandlerTests.java @@ -1,12 +1,6 @@ package org.springframework.batch.integration.partition; -import java.util.Collection; -import java.util.Collections; -import java.util.HashSet; -import java.util.concurrent.TimeoutException; - import org.junit.Test; - import org.springframework.batch.core.BatchStatus; import org.springframework.batch.core.JobExecution; import org.springframework.batch.core.JobParameters; @@ -18,15 +12,16 @@ import org.springframework.messaging.Message; import org.springframework.messaging.PollableChannel; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.concurrent.TimeoutException; + +import static org.junit.Assert.*; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; /** * @@ -154,8 +149,8 @@ public void testHandleWithJobRepositoryPolling() throws Exception { stepExecutions.add(partition2); stepExecutions.add(partition3); when(stepExecutionSplitter.split(any(StepExecution.class), eq(1))).thenReturn(stepExecutions); - when(jobExplorer.getStepExecution(eq(5L), any(Long.class))).thenReturn(partition2, partition1, partition3, partition3, partition3, partition3, partition4); - + when(jobExplorer.getStepExecutionCount(any(), any())).thenReturn(3, 2, 0); + when(jobExplorer.getStepExecutions(eq(5L), any())).thenReturn(Arrays.asList(partition1, partition2, partition4)); //set messageChannelPartitionHandler.setMessagingOperations(operations); messageChannelPartitionHandler.setJobExplorer(jobExplorer); @@ -198,7 +193,8 @@ public void testHandleWithJobRepositoryPollingTimeout() throws Exception { stepExecutions.add(partition2); stepExecutions.add(partition3); when(stepExecutionSplitter.split(any(StepExecution.class), eq(1))).thenReturn(stepExecutions); - when(jobExplorer.getStepExecution(eq(5L), any(Long.class))).thenReturn(partition2, partition1, partition3); + when(jobExplorer.getStepExecutionCount(any(), any())).thenReturn(2); + when(jobExplorer.getStepExecutions(eq(5L), any())).thenReturn(Arrays.asList(partition1, partition2, partition3)); //set messageChannelPartitionHandler.setMessagingOperations(operations);