diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java index 58c1bfc251..928a18fcd6 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java @@ -557,7 +557,6 @@ private RootAggregateChange createInsertChange(T instance) { } private RootAggregateChange createUpdateChange(EntityAndPreviousVersion entityAndVersion) { - RootAggregateChange aggregateChange = MutableAggregateChange.forSave(entityAndVersion.entity, entityAndVersion.version); new RelationalEntityUpdateWriter(context).write(entityAndVersion.entity, aggregateChange); diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/mapping/IdGeneratingBeforeSaveCallback.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/mapping/IdGeneratingBeforeSaveCallback.java index f40a8abeb3..ce8aefd87e 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/mapping/IdGeneratingBeforeSaveCallback.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/mapping/IdGeneratingBeforeSaveCallback.java @@ -2,6 +2,7 @@ import java.util.Map; import java.util.Optional; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.data.jdbc.repository.config.AbstractJdbcConfiguration; @@ -16,59 +17,73 @@ import org.springframework.util.Assert; /** - * Callback for generating ID via the database sequence. By default, it is registered as a - * bean in {@link AbstractJdbcConfiguration} + * Callback for generating ID via the database sequence. By default, it is registered as a bean in + * {@link AbstractJdbcConfiguration} * * @author Mikhail Polivakha */ public class IdGeneratingBeforeSaveCallback implements BeforeSaveCallback { - private static final Log LOG = LogFactory.getLog(IdGeneratingBeforeSaveCallback.class); - - private final RelationalMappingContext relationalMappingContext; - private final Dialect dialect; - private final NamedParameterJdbcOperations operations; - - public IdGeneratingBeforeSaveCallback( - RelationalMappingContext relationalMappingContext, - Dialect dialect, - NamedParameterJdbcOperations namedParameterJdbcOperations - ) { - this.relationalMappingContext = relationalMappingContext; - this.dialect = dialect; - this.operations = namedParameterJdbcOperations; - } - - @Override - public Object onBeforeSave(Object aggregate, MutableAggregateChange aggregateChange) { - - Assert.notNull(aggregate, "The aggregate cannot be null at this point"); - - RelationalPersistentEntity persistentEntity = relationalMappingContext.getPersistentEntity(aggregate.getClass()); - Optional idSequence = persistentEntity.getIdSequence(); - - if (dialect.getIdGeneration().sequencesSupported()) { - - if (persistentEntity.getIdProperty() != null) { - idSequence - .map(s -> dialect.getIdGeneration().createSequenceQuery(s)) - .ifPresent(sql -> { - Long idValue = operations.queryForObject(sql, Map.of(), (rs, rowNum) -> rs.getLong(1)); - PersistentPropertyAccessor propertyAccessor = persistentEntity.getPropertyAccessor(aggregate); - propertyAccessor.setProperty(persistentEntity.getRequiredIdProperty(), idValue); - }); - } - } else { - if (idSequence.isPresent()) { - LOG.warn(""" - It seems you're trying to insert an aggregate of type '%s' annotated with @TargetSequence, but the problem is RDBMS you're - working with does not support sequences as such. Falling back to identity columns - """ - .formatted(aggregate.getClass().getName()) - ); - } - } - - return aggregate; - } + private static final Log LOG = LogFactory.getLog(IdGeneratingBeforeSaveCallback.class); + + private final RelationalMappingContext relationalMappingContext; + private final Dialect dialect; + private final NamedParameterJdbcOperations operations; + + public IdGeneratingBeforeSaveCallback(RelationalMappingContext relationalMappingContext, Dialect dialect, + NamedParameterJdbcOperations namedParameterJdbcOperations) { + this.relationalMappingContext = relationalMappingContext; + this.dialect = dialect; + this.operations = namedParameterJdbcOperations; + } + + @Override + public Object onBeforeSave(Object aggregate, MutableAggregateChange aggregateChange) { + + Assert.notNull(aggregate, "The aggregate cannot be null at this point"); + + RelationalPersistentEntity persistentEntity = relationalMappingContext.getPersistentEntity(aggregate.getClass()); + + if (!persistentEntity.hasIdProperty()) { + return aggregate; + } + + // we're doing INSERT and ID property value is not set explicitly by client + if (persistentEntity.isNew(aggregate) && !hasIdentifierValue(aggregate, persistentEntity)) { + return potentiallyFetchIdFromSequence(aggregate, persistentEntity); + } else { + return aggregate; + } + } + + private boolean hasIdentifierValue(Object aggregate, RelationalPersistentEntity persistentEntity) { + Object identifier = persistentEntity.getIdentifierAccessor(aggregate).getIdentifier(); + + if (persistentEntity.getIdProperty().getType().isPrimitive()) { + return identifier instanceof Number num && num.longValue() != 0L; + } else { + return identifier != null; + } + } + + private Object potentiallyFetchIdFromSequence(Object aggregate, RelationalPersistentEntity persistentEntity) { + Optional idSequence = persistentEntity.getIdSequence(); + + if (dialect.getIdGeneration().sequencesSupported()) { + idSequence.map(s -> dialect.getIdGeneration().createSequenceQuery(s)).ifPresent(sql -> { + Long idValue = operations.queryForObject(sql, Map.of(), (rs, rowNum) -> rs.getLong(1)); + PersistentPropertyAccessor propertyAccessor = persistentEntity.getPropertyAccessor(aggregate); + propertyAccessor.setProperty(persistentEntity.getRequiredIdProperty(), idValue); + }); + } else { + if (idSequence.isPresent()) { + LOG.warn(""" + It seems you're trying to insert an aggregate of type '%s' annotated with @TargetSequence, but the problem is RDBMS you're + working with does not support sequences as such. Falling back to identity columns + """.formatted(aggregate.getClass().getName())); + } + } + + return aggregate; + } } diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/AbstractJdbcRepositoryLookUpStrategyTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/AbstractJdbcRepositoryLookUpStrategyTests.java index 26e225a46a..2e01720094 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/AbstractJdbcRepositoryLookUpStrategyTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/AbstractJdbcRepositoryLookUpStrategyTests.java @@ -28,6 +28,7 @@ import org.springframework.data.jdbc.testing.EnabledOnDatabase; import org.springframework.data.relational.core.mapping.RelationalMappingContext; import org.springframework.data.repository.CrudRepository; +import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations; import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate; /** @@ -40,7 +41,7 @@ abstract class AbstractJdbcRepositoryLookUpStrategyTests { @Autowired protected OnesRepository onesRepository; - @Autowired NamedParameterJdbcTemplate template; + @Autowired NamedParameterJdbcOperations template; @Autowired RelationalMappingContext context; void insertTestInstances() { diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIdGenerationIntegrationTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIdGenerationIntegrationTests.java index bb2b8a56c3..c3cc88b753 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIdGenerationIntegrationTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIdGenerationIntegrationTests.java @@ -15,13 +15,14 @@ */ package org.springframework.data.jdbc.repository; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; import java.util.List; -import java.util.Objects; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; import org.junit.jupiter.api.Test; +import org.mockito.Mockito; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.ComponentScan; @@ -29,31 +30,53 @@ import org.springframework.context.annotation.FilterType; import org.springframework.context.annotation.Import; import org.springframework.data.annotation.Id; +import org.springframework.data.annotation.PersistenceCreator; +import org.springframework.data.annotation.Transient; +import org.springframework.data.domain.Persistable; +import org.springframework.data.jdbc.core.mapping.IdGeneratingBeforeSaveCallback; import org.springframework.data.jdbc.repository.config.EnableJdbcRepositories; import org.springframework.data.jdbc.repository.support.SimpleJdbcRepository; import org.springframework.data.jdbc.testing.IntegrationTest; import org.springframework.data.jdbc.testing.TestConfiguration; +import org.springframework.data.relational.core.conversion.MutableAggregateChange; import org.springframework.data.relational.core.mapping.NamingStrategy; +import org.springframework.data.relational.core.mapping.Sequence; import org.springframework.data.relational.core.mapping.event.BeforeConvertCallback; import org.springframework.data.repository.CrudRepository; import org.springframework.data.repository.ListCrudRepository; -import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate; +import org.springframework.test.context.jdbc.Sql; /** * Testing special cases for id generation with {@link SimpleJdbcRepository}. * * @author Jens Schauder * @author Greg Turnquist + * @author Mikhail Polivakha */ @IntegrationTest class JdbcRepositoryIdGenerationIntegrationTests { - @Autowired NamedParameterJdbcTemplate template; - @Autowired ReadOnlyIdEntityRepository readOnlyIdRepository; - @Autowired PrimitiveIdEntityRepository primitiveIdRepository; - @Autowired ImmutableWithManualIdEntityRepository immutableWithManualIdEntityRepository; + @Autowired + ReadOnlyIdEntityRepository readOnlyIdRepository; + @Autowired + PrimitiveIdEntityRepository primitiveIdRepository; + @Autowired + ImmutableWithManualIdEntityRepository immutableWithManualIdEntityRepository; - @Test // DATAJDBC-98 + @Autowired + SimpleSeqRepository simpleSeqRepository; + + @Autowired + PersistableSeqRepository persistableSeqRepository; + + @Autowired + PrimitiveIdSeqRepository primitiveIdSeqRepository; + + @Autowired + IdGeneratingBeforeSaveCallback idGeneratingCallback; + + @Test + // DATAJDBC-98 void idWithoutSetterGetsSet() { ReadOnlyIdEntity entity = readOnlyIdRepository.save(new ReadOnlyIdEntity(null, "Entity Name")); @@ -67,7 +90,8 @@ void idWithoutSetterGetsSet() { }); } - @Test // DATAJDBC-98 + @Test + // DATAJDBC-98 void primitiveIdGetsSet() { PrimitiveIdEntity entity = new PrimitiveIdEntity(); @@ -84,7 +108,8 @@ void primitiveIdGetsSet() { }); } - @Test // DATAJDBC-393 + @Test + // DATAJDBC-393 void manuallyGeneratedId() { ImmutableWithManualIdEntity entity = new ImmutableWithManualIdEntity(null, "immutable"); @@ -95,7 +120,8 @@ void manuallyGeneratedId() { assertThat(immutableWithManualIdEntityRepository.findAll()).hasSize(1); } - @Test // DATAJDBC-393 + @Test + // DATAJDBC-393 void manuallyGeneratedIdForSaveAll() { ImmutableWithManualIdEntity one = new ImmutableWithManualIdEntity(null, "one"); @@ -107,18 +133,146 @@ void manuallyGeneratedIdForSaveAll() { assertThat(immutableWithManualIdEntityRepository.findAll()).hasSize(2); } - private interface PrimitiveIdEntityRepository extends ListCrudRepository {} + @Test // DATAJDBC-2003 + @Sql(statements = "INSERT INTO SimpleSeq(id, name) VALUES(1, 'initial value');") + void testUpdateAggregateWithSequence() { + + SimpleSeq entity = new SimpleSeq(); + entity.id = 1L; + entity.name = "New name"; + AtomicReference afterCallback = mockIdGeneratingCallback(entity); + + SimpleSeq updated = simpleSeqRepository.save(entity); + + assertThat(updated.id).isEqualTo(1L); + assertThat(afterCallback.get()).isSameAs(entity); + assertThat(afterCallback.get().id).isEqualTo(1L); + } + + @Test + // DATAJDBC-2003 + void testInsertPersistableAggregateWithSequenceClientIdIsFavored() { + + long initialId = 1L; + PersistableSeq entityWithSeq = PersistableSeq.createNew(initialId, "name"); + AtomicReference afterCallback = mockIdGeneratingCallback(entityWithSeq); + + PersistableSeq saved = persistableSeqRepository.save(entityWithSeq); + + // We do not expect the SELECT next value from sequence in case we're doing an INSERT with ID provided by the client + assertThat(saved.getId()).isEqualTo(initialId); + assertThat(afterCallback.get()).isSameAs(entityWithSeq); + } + + @Test + // DATAJDBC-2003 + void testInsertAggregateWithSequenceAndUnsetPrimitiveId() { - private interface ReadOnlyIdEntityRepository extends ListCrudRepository {} + PrimitiveIdSeq entity = new PrimitiveIdSeq(); + entity.name = "some name"; + AtomicReference afterCallback = mockIdGeneratingCallback(entity); - private interface ImmutableWithManualIdEntityRepository extends ListCrudRepository {} + PrimitiveIdSeq saved = primitiveIdSeqRepository.save(entity); + + // 1. Select from sequence + // 2. Actual INSERT + assertThat(afterCallback.get().id).isEqualTo(1L); + assertThat(saved.id).isEqualTo(1L); // sequence starts with 1 + } + + @SuppressWarnings("unchecked") + private AtomicReference mockIdGeneratingCallback(T entity) { + AtomicReference afterCallback = new AtomicReference<>(); + Mockito + .doAnswer(invocationOnMock -> { + afterCallback.set((T) invocationOnMock.callRealMethod()); + return afterCallback.get(); + }) + .when(idGeneratingCallback) + .onBeforeSave(Mockito.eq(entity), Mockito.any(MutableAggregateChange.class)); + return afterCallback; + } + + interface PrimitiveIdEntityRepository extends ListCrudRepository { + } + + interface ReadOnlyIdEntityRepository extends ListCrudRepository { + } + + interface ImmutableWithManualIdEntityRepository extends ListCrudRepository { + } + + interface SimpleSeqRepository extends ListCrudRepository { + } + + interface PersistableSeqRepository extends ListCrudRepository { + } + + interface PrimitiveIdSeqRepository extends ListCrudRepository { + } record ReadOnlyIdEntity(@Id Long id, String name) { } + static class SimpleSeq { + + @Id + @Sequence(value = "simple_seq_seq") + private Long id; + + private String name; + } + + static class PersistableSeq implements Persistable { + + @Id + @Sequence(value = "persistable_seq_seq") + private Long id; + + private String name; + + @Transient + private boolean isNew; + + @PersistenceCreator + public PersistableSeq() { + } + + public PersistableSeq(Long id, String name, boolean isNew) { + this.id = id; + this.name = name; + this.isNew = isNew; + } + + static PersistableSeq createNew(Long id, String name) { + return new PersistableSeq(id, name, true); + } + + @Override + public Long getId() { + return id; + } + + @Override + public boolean isNew() { + return isNew; + } + } + + static class PrimitiveIdSeq { + + @Id + @Sequence(value = "primitive_seq_seq") + private long id; + + private String name; + + } + static class PrimitiveIdEntity { - @Id private long id; + @Id + private long id; String name; public long getId() { @@ -142,17 +296,17 @@ record ImmutableWithManualIdEntity(@Id Long id, String name) { @Override public Long id() { - return this.id; - } + return this.id; + } - public ImmutableWithManualIdEntity withId(Long id) { - return this.id == id ? this : new ImmutableWithManualIdEntity(id, this.name); - } + public ImmutableWithManualIdEntity withId(Long id) { + return this.id == id ? this : new ImmutableWithManualIdEntity(id, this.name); + } - public ImmutableWithManualIdEntity withName(String name) { - return this.name == name ? this : new ImmutableWithManualIdEntity(this.id, name); - } + public ImmutableWithManualIdEntity withName(String name) { + return this.name == name ? this : new ImmutableWithManualIdEntity(this.id, name); } + } @Configuration @EnableJdbcRepositories(considerNestedRepositories = true, diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/testing/TestConfiguration.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/testing/TestConfiguration.java index 63db08a0cc..e6ffebf370 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/testing/TestConfiguration.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/testing/TestConfiguration.java @@ -23,6 +23,7 @@ import javax.sql.DataSource; import org.apache.ibatis.session.SqlSessionFactory; +import org.mockito.Mockito; import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; @@ -68,6 +69,7 @@ * @author Christoph Strobl * @author Chirag Tailor * @author Christopher Klein + * @author Mikhail Polivakha */ @Configuration @ComponentScan // To pick up configuration classes (per activated profile) @@ -76,24 +78,25 @@ public class TestConfiguration { public static final String PROFILE_SINGLE_QUERY_LOADING = "singleQueryLoading"; public static final String PROFILE_NO_SINGLE_QUERY_LOADING = "!" + PROFILE_SINGLE_QUERY_LOADING; - @Autowired DataSource dataSource; - @Autowired BeanFactory beanFactory; - @Autowired ApplicationEventPublisher publisher; - @Autowired(required = false) SqlSessionFactory sqlSessionFactory; + @Autowired + DataSource dataSource; + @Autowired + BeanFactory beanFactory; + @Autowired + ApplicationEventPublisher publisher; + @Autowired(required = false) + SqlSessionFactory sqlSessionFactory; @Bean JdbcRepositoryFactory jdbcRepositoryFactory( @Qualifier("defaultDataAccessStrategy") DataAccessStrategy dataAccessStrategy, RelationalMappingContext context, Dialect dialect, JdbcConverter converter, Optional> namedQueries, - List> callbacks, - List evaulationContextExtensions) { + List> callbacks, List evaulationContextExtensions) { JdbcRepositoryFactory factory = new JdbcRepositoryFactory(dataAccessStrategy, context, converter, dialect, publisher, namedParameterJdbcTemplate()); - factory.setEntityCallbacks( - EntityCallbacks.create(callbacks.toArray(new EntityCallback[0])) - ); + factory.setEntityCallbacks(EntityCallbacks.create(callbacks.toArray(new EntityCallback[0]))); namedQueries.map(it -> it.iterator().next()).ifPresent(factory::setNamedQueries); @@ -118,22 +121,24 @@ DataAccessStrategy defaultDataAccessStrategy( @Qualifier("namedParameterJdbcTemplate") NamedParameterJdbcOperations template, RelationalMappingContext context, JdbcConverter converter, Dialect dialect) { - return new DataAccessStrategyFactory(new SqlGeneratorSource(context, converter, dialect), converter, - template, new SqlParametersFactory(context, converter), - new InsertStrategyFactory(template, dialect)).create(); + return new DataAccessStrategyFactory(new SqlGeneratorSource(context, converter, dialect), converter, template, + new SqlParametersFactory(context, converter), new InsertStrategyFactory(template, dialect)).create(); } @Bean("jdbcMappingContext") @Profile(PROFILE_NO_SINGLE_QUERY_LOADING) - JdbcMappingContext jdbcMappingContextWithOutSingleQueryLoading(Optional namingStrategy, CustomConversions conversions) { + JdbcMappingContext jdbcMappingContextWithOutSingleQueryLoading(Optional namingStrategy, + CustomConversions conversions) { JdbcMappingContext mappingContext = new JdbcMappingContext(namingStrategy.orElse(DefaultNamingStrategy.INSTANCE)); mappingContext.setSimpleTypeHolder(conversions.getSimpleTypeHolder()); return mappingContext; } + @Bean("jdbcMappingContext") @Profile(PROFILE_SINGLE_QUERY_LOADING) - JdbcMappingContext jdbcMappingContextWithSingleQueryLoading(Optional namingStrategy, CustomConversions conversions) { + JdbcMappingContext jdbcMappingContextWithSingleQueryLoading(Optional namingStrategy, + CustomConversions conversions) { JdbcMappingContext mappingContext = new JdbcMappingContext(namingStrategy.orElse(DefaultNamingStrategy.INSTANCE)); mappingContext.setSimpleTypeHolder(conversions.getSimpleTypeHolder()); @@ -144,8 +149,9 @@ JdbcMappingContext jdbcMappingContextWithSingleQueryLoading(Optional