Skip to content

Commit d1c9960

Browse files
mipo256schauder
authored andcommitted
Support for ID generation by sequence.
Ids can be annotated with @sequence to specify a sequence to pull id values from. Closes #1923 Original pull request #1955 Signed-off-by: mipo256 <[email protected]> Some accidential changes removed. Signed-off-by: schauder <[email protected]>
1 parent b51c77b commit d1c9960

33 files changed

+716
-39
lines changed

Diff for: spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java

+11-5
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,13 @@ public <T> Object[] insert(List<InsertSubject<T>> insertSubjects, Class<T> domai
118118

119119
Assert.notEmpty(insertSubjects, "Batch insert must contain at least one InsertSubject");
120120
SqlIdentifierParameterSource[] sqlParameterSources = insertSubjects.stream()
121-
.map(insertSubject -> sqlParametersFactory.forInsert(insertSubject.getInstance(), domainType,
122-
insertSubject.getIdentifier(), idValueSource))
121+
.map(insertSubject -> sqlParametersFactory.forInsert( //
122+
insertSubject.getInstance(), //
123+
domainType, //
124+
insertSubject.getIdentifier(), //
125+
idValueSource //
126+
) //
127+
) //
123128
.toArray(SqlIdentifierParameterSource[]::new);
124129

125130
String insertSql = sql(domainType).getInsert(sqlParameterSources[0].getIdentifiers());
@@ -280,7 +285,8 @@ public <T> List<T> findAll(Class<T> domainType) {
280285

281286
@Override
282287
public <T> Stream<T> streamAll(Class<T> domainType) {
283-
return operations.queryForStream(sql(domainType).getFindAll(), new MapSqlParameterSource(), getEntityRowMapper(domainType));
288+
return operations.queryForStream(sql(domainType).getFindAll(), new MapSqlParameterSource(),
289+
getEntityRowMapper(domainType));
284290
}
285291

286292
@Override
@@ -364,7 +370,8 @@ public <T> List<T> findAll(Class<T> domainType, Sort sort) {
364370

365371
@Override
366372
public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {
367-
return operations.queryForStream(sql(domainType).getFindAll(sort), new MapSqlParameterSource(), getEntityRowMapper(domainType));
373+
return operations.queryForStream(sql(domainType).getFindAll(sort), new MapSqlParameterSource(),
374+
getEntityRowMapper(domainType));
368375
}
369376

370377
@Override
@@ -479,5 +486,4 @@ private Class<?> getBaseType(PersistentPropertyPath<RelationalPersistentProperty
479486

480487
return baseProperty.getOwner().getType();
481488
}
482-
483489
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package org.springframework.data.jdbc.core.mapping;
2+
3+
import java.util.Map;
4+
import java.util.Optional;
5+
import org.apache.commons.logging.Log;
6+
import org.apache.commons.logging.LogFactory;
7+
import org.springframework.data.jdbc.repository.config.AbstractJdbcConfiguration;
8+
import org.springframework.data.mapping.PersistentPropertyAccessor;
9+
import org.springframework.data.relational.core.conversion.MutableAggregateChange;
10+
import org.springframework.data.relational.core.dialect.Dialect;
11+
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
12+
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity;
13+
import org.springframework.data.relational.core.mapping.event.BeforeSaveCallback;
14+
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations;
15+
import org.springframework.util.Assert;
16+
17+
/**
18+
* Callback for generating ID via the database sequence. By default, it is registered as a
19+
* bean in {@link AbstractJdbcConfiguration}
20+
*
21+
* @author Mikhail Polivakha
22+
*/
23+
public class IdGeneratingBeforeSaveCallback implements BeforeSaveCallback<Object> {
24+
25+
private static final Log LOG = LogFactory.getLog(IdGeneratingBeforeSaveCallback.class);
26+
27+
private final RelationalMappingContext relationalMappingContext;
28+
private final Dialect dialect;
29+
private final NamedParameterJdbcOperations operations;
30+
31+
public IdGeneratingBeforeSaveCallback(
32+
RelationalMappingContext relationalMappingContext,
33+
Dialect dialect,
34+
NamedParameterJdbcOperations namedParameterJdbcOperations
35+
) {
36+
this.relationalMappingContext = relationalMappingContext;
37+
this.dialect = dialect;
38+
this.operations = namedParameterJdbcOperations;
39+
}
40+
41+
@Override
42+
public Object onBeforeSave(Object aggregate, MutableAggregateChange<Object> aggregateChange) {
43+
Assert.notNull(aggregate, "The aggregate cannot be null at this point");
44+
RelationalPersistentEntity<?> persistentEntity = relationalMappingContext.getPersistentEntity(aggregate.getClass());
45+
Optional<String> idTargetSequence = persistentEntity.getIdTargetSequence();
46+
47+
if (dialect.getIdGeneration().sequencesSupported()) {
48+
49+
if (persistentEntity.getIdProperty() != null) {
50+
idTargetSequence
51+
.map(s -> dialect.getIdGeneration().nextValueFromSequenceSelect(s))
52+
.ifPresent(sql -> {
53+
Long idValue = operations.queryForObject(sql, Map.of(), (rs, rowNum) -> rs.getLong(1));
54+
PersistentPropertyAccessor<Object> propertyAccessor = persistentEntity.getPropertyAccessor(aggregate);
55+
propertyAccessor.setProperty(persistentEntity.getRequiredIdProperty(), idValue);
56+
});
57+
}
58+
} else {
59+
if (idTargetSequence.isPresent()) {
60+
LOG.warn("""
61+
It seems you're trying to insert an aggregate of type '%s' annotated with @TargetSequence, but the problem is RDBMS you're
62+
working with does not support sequences as such. Falling back to identity columns
63+
"""
64+
.formatted(aggregate.getClass().getName())
65+
);
66+
}
67+
}
68+
69+
return aggregate;
70+
}
71+
}

Diff for: spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/config/AbstractJdbcConfiguration.java

+17
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import org.springframework.data.jdbc.core.JdbcAggregateTemplate;
3939
import org.springframework.data.jdbc.core.convert.*;
4040
import org.springframework.data.jdbc.core.dialect.JdbcDialect;
41+
import org.springframework.data.jdbc.core.mapping.IdGeneratingBeforeSaveCallback;
4142
import org.springframework.data.jdbc.core.mapping.JdbcMappingContext;
4243
import org.springframework.data.jdbc.core.mapping.JdbcSimpleTypes;
4344
import org.springframework.data.mapping.model.SimpleTypeHolder;
@@ -119,6 +120,22 @@ public JdbcMappingContext jdbcMappingContext(Optional<NamingStrategy> namingStra
119120
return mappingContext;
120121
}
121122

123+
/**
124+
* Creates a {@link IdGeneratingBeforeSaveCallback} bean using the configured
125+
* {@link #jdbcMappingContext(Optional, JdbcCustomConversions, RelationalManagedTypes)} and
126+
* {@link #jdbcDialect(NamedParameterJdbcOperations)}.
127+
*
128+
* @return must not be {@literal null}.
129+
*/
130+
@Bean
131+
public IdGeneratingBeforeSaveCallback idGeneratingBeforeSaveCallback(
132+
JdbcMappingContext mappingContext,
133+
NamedParameterJdbcOperations operations,
134+
Dialect dialect
135+
) {
136+
return new IdGeneratingBeforeSaveCallback(mappingContext, dialect, operations);
137+
}
138+
122139
/**
123140
* Creates a {@link RelationalConverter} using the configured
124141
* {@link #jdbcMappingContext(Optional, JdbcCustomConversions, RelationalManagedTypes)}.

Diff for: spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/convert/SqlParametersFactoryTest.java

-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
import org.springframework.data.convert.WritingConverter;
3434
import org.springframework.data.jdbc.core.mapping.JdbcMappingContext;
3535
import org.springframework.data.relational.core.conversion.IdValueSource;
36-
import org.springframework.data.relational.core.dialect.AnsiDialect;
3736
import org.springframework.data.relational.core.mapping.Column;
3837
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
3938
import org.springframework.data.relational.core.sql.SqlIdentifier;
@@ -49,7 +48,6 @@ class SqlParametersFactoryTest {
4948
RelationalMappingContext context = new JdbcMappingContext();
5049
RelationResolver relationResolver = mock(RelationResolver.class);
5150
MappingJdbcConverter converter = new MappingJdbcConverter(context, relationResolver);
52-
AnsiDialect dialect = AnsiDialect.INSTANCE;
5351
SqlParametersFactory sqlParametersFactory = new SqlParametersFactory(context, converter);
5452

5553
@Test // DATAJDBC-412
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
package org.springframework.data.jdbc.core.mapping;
2+
3+
import static org.mockito.ArgumentMatchers.anyMap;
4+
import static org.mockito.ArgumentMatchers.anyString;
5+
import static org.mockito.Mockito.any;
6+
import static org.mockito.Mockito.mock;
7+
import static org.mockito.Mockito.when;
8+
9+
import org.assertj.core.api.Assertions;
10+
import org.junit.jupiter.api.Test;
11+
import org.springframework.data.annotation.Id;
12+
import org.springframework.data.relational.core.conversion.MutableAggregateChange;
13+
import org.springframework.data.relational.core.dialect.MySqlDialect;
14+
import org.springframework.data.relational.core.dialect.PostgresDialect;
15+
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
16+
import org.springframework.data.relational.core.mapping.Table;
17+
import org.springframework.data.relational.core.mapping.TargetSequence;
18+
import org.springframework.data.relational.core.sql.IdentifierProcessing;
19+
import org.springframework.jdbc.core.RowMapper;
20+
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations;
21+
22+
/**
23+
* Unit tests for {@link IdGeneratingBeforeSaveCallback}
24+
*
25+
* @author Mikhail Polivakha
26+
*/
27+
class IdGeneratingBeforeSaveCallbackTest {
28+
29+
@Test
30+
void test_mySqlDialect_sequenceGenerationIsNotSupported() {
31+
// given
32+
RelationalMappingContext relationalMappingContext = new RelationalMappingContext();
33+
MySqlDialect mySqlDialect = new MySqlDialect(IdentifierProcessing.NONE);
34+
NamedParameterJdbcOperations operations = mock(NamedParameterJdbcOperations.class);
35+
36+
// and
37+
IdGeneratingBeforeSaveCallback subject = new IdGeneratingBeforeSaveCallback(relationalMappingContext, mySqlDialect, operations);
38+
39+
NoSequenceEntity entity = new NoSequenceEntity();
40+
41+
// when
42+
Object processed = subject.onBeforeSave(entity, MutableAggregateChange.forSave(entity));
43+
44+
// then
45+
Assertions.assertThat(processed).isSameAs(entity);
46+
Assertions.assertThat(processed).usingRecursiveComparison().isEqualTo(entity);
47+
}
48+
49+
@Test
50+
void test_EntityIsNotMarkedWithTargetSequence() {
51+
// given
52+
RelationalMappingContext relationalMappingContext = new RelationalMappingContext();
53+
PostgresDialect mySqlDialect = PostgresDialect.INSTANCE;
54+
NamedParameterJdbcOperations operations = mock(NamedParameterJdbcOperations.class);
55+
56+
// and
57+
IdGeneratingBeforeSaveCallback subject = new IdGeneratingBeforeSaveCallback(relationalMappingContext, mySqlDialect, operations);
58+
59+
NoSequenceEntity entity = new NoSequenceEntity();
60+
61+
// when
62+
Object processed = subject.onBeforeSave(entity, MutableAggregateChange.forSave(entity));
63+
64+
// then
65+
Assertions.assertThat(processed).isSameAs(entity);
66+
Assertions.assertThat(processed).usingRecursiveComparison().isEqualTo(entity);
67+
}
68+
69+
@Test
70+
void test_EntityIdIsPopulatedFromSequence() {
71+
// given
72+
RelationalMappingContext relationalMappingContext = new RelationalMappingContext();
73+
relationalMappingContext.getRequiredPersistentEntity(EntityWithSequence.class);
74+
75+
PostgresDialect mySqlDialect = PostgresDialect.INSTANCE;
76+
NamedParameterJdbcOperations operations = mock(NamedParameterJdbcOperations.class);
77+
78+
// and
79+
long generatedId = 112L;
80+
when(operations.queryForObject(anyString(), anyMap(), any(RowMapper.class))).thenReturn(generatedId);
81+
82+
// and
83+
IdGeneratingBeforeSaveCallback subject = new IdGeneratingBeforeSaveCallback(relationalMappingContext, mySqlDialect, operations);
84+
85+
EntityWithSequence entity = new EntityWithSequence();
86+
87+
// when
88+
Object processed = subject.onBeforeSave(entity, MutableAggregateChange.forSave(entity));
89+
90+
// then
91+
Assertions.assertThat(processed).isSameAs(entity);
92+
Assertions
93+
.assertThat(processed)
94+
.usingRecursiveComparison()
95+
.ignoringFields("id")
96+
.isEqualTo(entity);
97+
Assertions.assertThat(entity.getId()).isEqualTo(generatedId);
98+
}
99+
100+
@Table
101+
static class NoSequenceEntity {
102+
103+
@Id
104+
private Long id;
105+
private Long name;
106+
}
107+
108+
@Table
109+
static class EntityWithSequence {
110+
111+
@Id
112+
@TargetSequence(value = "id_seq", schema = "public")
113+
private Long id;
114+
115+
private Long name;
116+
117+
public Long getId() {
118+
return id;
119+
}
120+
}
121+
}

0 commit comments

Comments
 (0)