Skip to content

Commit e9795bb

Browse files
committed
TestcontainersBeanRegistrationAotProcessor that replaces InstanceSupplier of Container by either direct field usage or a reflection equivalent.
If the field is private, the reflection will be used; otherwise, direct access to the field will be used
1 parent 4718485 commit e9795bb

File tree

5 files changed

+234
-3
lines changed

5 files changed

+234
-3
lines changed

spring-boot-project/spring-boot-testcontainers/src/dockerTest/java/org/springframework/boot/testcontainers/ImportTestcontainersTests.java

+108
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,26 @@
1818

1919
import java.lang.annotation.Retention;
2020
import java.lang.annotation.RetentionPolicy;
21+
import java.util.function.BiConsumer;
2122

2223
import org.junit.jupiter.api.AfterEach;
2324
import org.junit.jupiter.api.Test;
2425
import org.testcontainers.containers.Container;
2526
import org.testcontainers.containers.PostgreSQLContainer;
2627

28+
import org.springframework.aot.test.generate.TestGenerationContext;
2729
import org.springframework.boot.testcontainers.beans.TestcontainerBeanDefinition;
2830
import org.springframework.boot.testcontainers.context.ImportTestcontainers;
2931
import org.springframework.boot.testsupport.container.DisabledIfDockerUnavailable;
3032
import org.springframework.boot.testsupport.container.TestImage;
33+
import org.springframework.context.ApplicationContextInitializer;
3134
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
35+
import org.springframework.context.aot.ApplicationContextAotGenerator;
36+
import org.springframework.context.support.GenericApplicationContext;
37+
import org.springframework.core.test.tools.CompileWithForkedClassLoader;
38+
import org.springframework.core.test.tools.Compiled;
39+
import org.springframework.core.test.tools.TestCompiler;
40+
import org.springframework.javapoet.ClassName;
3241
import org.springframework.test.context.DynamicPropertyRegistry;
3342
import org.springframework.test.context.DynamicPropertySource;
3443

@@ -43,6 +52,8 @@
4352
@DisabledIfDockerUnavailable
4453
class ImportTestcontainersTests {
4554

55+
private final TestGenerationContext generationContext = new TestGenerationContext();
56+
4657
private AnnotationConfigApplicationContext applicationContext;
4758

4859
@AfterEach
@@ -122,6 +133,81 @@ void importWhenHasBadArgsDynamicPropertySourceMethod() {
122133
.withMessage("@DynamicPropertySource method 'containerProperties' must be static");
123134
}
124135

136+
@Test
137+
@CompileWithForkedClassLoader
138+
void importTestcontainersImportWithoutValueAotContribution() {
139+
this.applicationContext = new AnnotationConfigApplicationContext();
140+
this.applicationContext.register(ImportWithoutValue.class);
141+
compile((freshContext, compiled) -> {
142+
PostgreSQLContainer<?> container = freshContext.getBean(PostgreSQLContainer.class);
143+
assertThat(container).isSameAs(ImportWithoutValue.container);
144+
});
145+
}
146+
147+
@Test
148+
@CompileWithForkedClassLoader
149+
void importTestcontainersImportWithValueAotContribution() {
150+
this.applicationContext = new AnnotationConfigApplicationContext();
151+
this.applicationContext.register(ImportWithValue.class);
152+
compile((freshContext, compiled) -> {
153+
PostgreSQLContainer<?> container = freshContext.getBean(PostgreSQLContainer.class);
154+
assertThat(container).isSameAs(ContainerDefinitions.container);
155+
});
156+
}
157+
158+
@Test
159+
@CompileWithForkedClassLoader
160+
void importTestcontainersWithDynamicPropertySourceAotContribution() {
161+
this.applicationContext = new AnnotationConfigApplicationContext();
162+
this.applicationContext.register(ContainerDefinitionsWithDynamicPropertySource.class);
163+
compile((freshContext, compiled) -> {
164+
PostgreSQLContainer<?> container = freshContext.getBean(PostgreSQLContainer.class);
165+
assertThat(container).isSameAs(ContainerDefinitionsWithDynamicPropertySource.container);
166+
});
167+
}
168+
169+
@Test
170+
@CompileWithForkedClassLoader
171+
void importTestcontainersWithCustomPostgreSQLContainerAotContribution() {
172+
this.applicationContext = new AnnotationConfigApplicationContext();
173+
this.applicationContext.register(CustomPostgreSQLContainerDefinitions.class);
174+
compile((freshContext, compiled) -> {
175+
CustomPostgreSQLContainer container = freshContext.getBean(CustomPostgreSQLContainer.class);
176+
assertThat(container).isSameAs(CustomPostgreSQLContainerDefinitions.container);
177+
});
178+
}
179+
180+
@Test
181+
@CompileWithForkedClassLoader
182+
void importTestcontainersWithNotAccessibleContainerAotContribution() {
183+
this.applicationContext = new AnnotationConfigApplicationContext();
184+
this.applicationContext.register(ImportNotAccessibleContainer.class);
185+
compile((freshContext, compiled) -> {
186+
PostgreSQLContainer<?> container = freshContext.getBean(PostgreSQLContainer.class);
187+
assertThat(container).isSameAs(ImportNotAccessibleContainer.container);
188+
});
189+
}
190+
191+
@SuppressWarnings("unchecked")
192+
private void compile(BiConsumer<GenericApplicationContext, Compiled> result) {
193+
ClassName className = processAheadOfTime();
194+
TestCompiler.forSystem().with(this.generationContext).compile((compiled) -> {
195+
GenericApplicationContext freshApplicationContext = new GenericApplicationContext();
196+
ApplicationContextInitializer<GenericApplicationContext> initializer = compiled
197+
.getInstance(ApplicationContextInitializer.class, className.toString());
198+
initializer.initialize(freshApplicationContext);
199+
freshApplicationContext.refresh();
200+
result.accept(freshApplicationContext, compiled);
201+
});
202+
}
203+
204+
private ClassName processAheadOfTime() {
205+
ClassName className = new ApplicationContextAotGenerator().processAheadOfTime(this.applicationContext,
206+
this.generationContext);
207+
this.generationContext.writeGeneratedContent();
208+
return className;
209+
}
210+
125211
@ImportTestcontainers
126212
static class ImportWithoutValue {
127213

@@ -196,4 +282,26 @@ void containerProperties() {
196282

197283
}
198284

285+
@ImportTestcontainers
286+
static class CustomPostgreSQLContainerDefinitions {
287+
288+
static CustomPostgreSQLContainer container = new CustomPostgreSQLContainer();
289+
290+
}
291+
292+
static class CustomPostgreSQLContainer extends PostgreSQLContainer<CustomPostgreSQLContainer> {
293+
294+
CustomPostgreSQLContainer() {
295+
super("postgres:14");
296+
}
297+
298+
}
299+
300+
@ImportTestcontainers
301+
static class ImportNotAccessibleContainer {
302+
303+
private static final PostgreSQLContainer<?> container = TestImage.container(PostgreSQLContainer.class);
304+
305+
}
306+
199307
}

spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/context/TestcontainerFieldBeanDefinition.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2012-2023 the original author or authors.
2+
* Copyright 2012-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -38,9 +38,10 @@ class TestcontainerFieldBeanDefinition extends RootBeanDefinition implements Tes
3838
TestcontainerFieldBeanDefinition(Field field, Container<?> container) {
3939
this.container = container;
4040
this.annotations = MergedAnnotations.from(field);
41-
this.setBeanClass(container.getClass());
41+
setBeanClass(container.getClass());
4242
setInstanceSupplier(() -> container);
4343
setRole(ROLE_INFRASTRUCTURE);
44+
setAttribute(TestcontainerFieldBeanDefinition.class.getName(), field);
4445
}
4546

4647
@Override
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Copyright 2012-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.boot.testcontainers.context;
18+
19+
import java.lang.reflect.Field;
20+
21+
import javax.lang.model.element.Modifier;
22+
23+
import org.testcontainers.containers.Container;
24+
25+
import org.springframework.aot.generate.AccessControl;
26+
import org.springframework.aot.generate.GeneratedMethod;
27+
import org.springframework.aot.generate.GenerationContext;
28+
import org.springframework.beans.factory.aot.BeanRegistrationAotContribution;
29+
import org.springframework.beans.factory.aot.BeanRegistrationAotProcessor;
30+
import org.springframework.beans.factory.aot.BeanRegistrationCode;
31+
import org.springframework.beans.factory.aot.BeanRegistrationCodeFragments;
32+
import org.springframework.beans.factory.aot.BeanRegistrationCodeFragmentsDecorator;
33+
import org.springframework.beans.factory.support.RegisteredBean;
34+
import org.springframework.beans.factory.support.RootBeanDefinition;
35+
import org.springframework.javapoet.ClassName;
36+
import org.springframework.javapoet.CodeBlock;
37+
import org.springframework.util.Assert;
38+
import org.springframework.util.ReflectionUtils;
39+
40+
/**
41+
* {@link BeanRegistrationAotProcessor} that replaces InstanceSupplier of
42+
* {@link Container} by either direct field usage or a reflection equivalent.
43+
* <p>
44+
* If the field is private, the reflection will be used; otherwise, direct access to the
45+
* field will be used.
46+
*
47+
* @author Dmytro Nosan
48+
*/
49+
class TestcontainersBeanRegistrationAotProcessor implements BeanRegistrationAotProcessor {
50+
51+
@Override
52+
public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registeredBean) {
53+
RootBeanDefinition bd = registeredBean.getMergedBeanDefinition();
54+
String attributeName = TestcontainerFieldBeanDefinition.class.getName();
55+
Object field = bd.getAttribute(attributeName);
56+
if (field != null) {
57+
Assert.isInstanceOf(Field.class, field,
58+
"BeanDefinition attribute '" + attributeName + "' value must be a type of '" + Field.class + "'");
59+
return BeanRegistrationAotContribution.withCustomCodeFragments(
60+
(codeFragments) -> new AotContribution(codeFragments, registeredBean, ((Field) field)));
61+
}
62+
return null;
63+
}
64+
65+
static class AotContribution extends BeanRegistrationCodeFragmentsDecorator {
66+
67+
private final RegisteredBean registeredBean;
68+
69+
private final Field field;
70+
71+
AotContribution(BeanRegistrationCodeFragments delegate, RegisteredBean registeredBean, Field field) {
72+
super(delegate);
73+
this.registeredBean = registeredBean;
74+
this.field = field;
75+
}
76+
77+
@Override
78+
public ClassName getTarget(RegisteredBean registeredBean) {
79+
return ClassName.get(this.field.getDeclaringClass());
80+
}
81+
82+
@Override
83+
public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext,
84+
BeanRegistrationCode beanRegistrationCode, boolean allowDirectSupplierShortcut) {
85+
if (AccessControl.forMember(this.field).isAccessibleFrom(beanRegistrationCode.getClassName())) {
86+
return CodeBlock.of("() -> $T.$L", this.field.getDeclaringClass(), this.field.getName());
87+
}
88+
generationContext.getRuntimeHints().reflection().registerField(this.field);
89+
GeneratedMethod generatedMethod = beanRegistrationCode.getMethods()
90+
.add("getInstance", (method) -> method.addModifiers(Modifier.PRIVATE, Modifier.STATIC)
91+
.addJavadoc("Get the bean instance for '$L'.", this.registeredBean.getBeanName())
92+
.returns(this.registeredBean.getBeanClass())
93+
.addStatement("$T field = $T.findField($T.class, $S)", Field.class, ReflectionUtils.class,
94+
this.field.getDeclaringClass(), this.field.getName())
95+
.addStatement("$T.notNull(field, $S)", Assert.class,
96+
"Field '" + this.field.getName() + "' is not found")
97+
.addStatement("$T.makeAccessible(field)", ReflectionUtils.class)
98+
.addStatement("$T container = $T.getField(field, null)", Object.class, ReflectionUtils.class)
99+
.addStatement("$T.notNull(container, $S)", Assert.class,
100+
"Container field '" + this.field.getName() + "' must not have a null value")
101+
.addStatement("return ($T) container", this.registeredBean.getBeanClass()));
102+
return generatedMethod.toMethodReference().toCodeBlock();
103+
}
104+
105+
}
106+
107+
}

spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/properties/TestcontainersPropertySource.java

+11
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@
2626
import org.testcontainers.containers.Container;
2727

2828
import org.springframework.beans.BeansException;
29+
import org.springframework.beans.factory.aot.BeanRegistrationExcludeFilter;
2930
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
3031
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
3132
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
33+
import org.springframework.beans.factory.support.RegisteredBean;
3234
import org.springframework.beans.factory.support.RootBeanDefinition;
3335
import org.springframework.context.ApplicationEventPublisher;
3436
import org.springframework.context.ApplicationEventPublisherAware;
@@ -166,4 +168,13 @@ public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory)
166168

167169
}
168170

171+
static class TestcontainersEventPublisherBeanRegistrationExcludeFilter implements BeanRegistrationExcludeFilter {
172+
173+
@Override
174+
public boolean isExcludedFromAotProcessing(RegisteredBean registeredBean) {
175+
return EventPublisherRegistrar.NAME.equals(registeredBean.getBeanName());
176+
}
177+
178+
}
179+
169180
}
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
org.springframework.beans.factory.aot.BeanRegistrationExcludeFilter=\
2-
org.springframework.boot.testcontainers.service.connection.ConnectionDetailsRegistrar.ServiceConnectionBeanRegistrationExcludeFilter
2+
org.springframework.boot.testcontainers.service.connection.ConnectionDetailsRegistrar.ServiceConnectionBeanRegistrationExcludeFilter,\
3+
org.springframework.boot.testcontainers.properties.TestcontainersPropertySource.TestcontainersEventPublisherBeanRegistrationExcludeFilter
34

45
org.springframework.aot.hint.RuntimeHintsRegistrar=\
56
org.springframework.boot.testcontainers.service.connection.ContainerConnectionDetailsFactory.ContainerConnectionDetailsFactoriesRuntimeHints
7+
8+
org.springframework.beans.factory.aot.BeanRegistrationAotProcessor=\
9+
org.springframework.boot.testcontainers.context.TestcontainersBeanRegistrationAotProcessor

0 commit comments

Comments
 (0)