2626import java .util .List ;
2727import java .util .Map ;
2828
29+ import org .jspecify .annotations .Nullable ;
2930import org .junit .jupiter .api .AfterEach ;
3031import org .junit .jupiter .api .Test ;
3132
3233import org .springframework .beans .factory .annotation .Autowired ;
34+ import org .springframework .context .ApplicationEvent ;
3335import org .springframework .context .ApplicationEventPublisher ;
3436import org .springframework .context .ConfigurableApplicationContext ;
3537import org .springframework .context .annotation .AnnotationConfigApplicationContext ;
4345import org .springframework .transaction .annotation .EnableTransactionManagement ;
4446import org .springframework .transaction .annotation .Propagation ;
4547import org .springframework .transaction .annotation .Transactional ;
48+ import org .springframework .transaction .config .GlobalTransactionalEventErrorHandler ;
4649import org .springframework .transaction .support .TransactionSynchronization ;
4750import org .springframework .transaction .support .TransactionSynchronizationManager ;
4851import org .springframework .transaction .support .TransactionTemplate ;
@@ -99,12 +102,12 @@ void immediately() {
99102 void immediatelyImpactsCurrentTransaction () {
100103 load (ImmediateTestListener .class , BeforeCommitTestListener .class );
101104 assertThatIllegalStateException ().isThrownBy (() ->
102- this .transactionTemplate .execute (status -> {
103- getContext ().publishEvent ("FAIL" );
104- throw new AssertionError ("Should have thrown an exception at this point" );
105- }))
106- .withMessageContaining ("Test exception" )
107- .withMessageContaining (EventCollector .IMMEDIATELY );
105+ this .transactionTemplate .execute (status -> {
106+ getContext ().publishEvent ("FAIL" );
107+ throw new AssertionError ("Should have thrown an exception at this point" );
108+ }))
109+ .withMessageContaining ("Test exception" )
110+ .withMessageContaining (EventCollector .IMMEDIATELY );
108111
109112 getEventCollector ().assertEvents (EventCollector .IMMEDIATELY , "FAIL" );
110113 getEventCollector ().assertTotalEventsCount (1 );
@@ -369,6 +372,45 @@ void conditionFoundOnMetaAnnotation() {
369372 getEventCollector ().assertNoEventReceived ();
370373 }
371374
375+ @ Test
376+ void afterCommitThrowException () {
377+ doLoad (HandlerConfiguration .class , AfterCommitErrorHandlerTestListener .class );
378+ this .transactionTemplate .execute (status -> {
379+ getContext ().publishEvent ("test" );
380+ getEventCollector ().assertNoEventReceived ();
381+ return null ;
382+ });
383+ getEventCollector ().assertEvents (EventCollector .AFTER_COMMIT , "test" );
384+ getEventCollector ().assertEvents (EventCollector .HANDLE_ERROR , "HANDLE_ERROR" );
385+ getEventCollector ().assertTotalEventsCount (2 );
386+ }
387+
388+ @ Test
389+ void afterRollbackThrowException () {
390+ doLoad (HandlerConfiguration .class , AfterRollbackErrorHandlerTestListener .class );
391+ this .transactionTemplate .execute (status -> {
392+ getContext ().publishEvent ("test" );
393+ getEventCollector ().assertNoEventReceived ();
394+ status .setRollbackOnly ();
395+ return null ;
396+ });
397+ getEventCollector ().assertEvents (EventCollector .AFTER_ROLLBACK , "test" );
398+ getEventCollector ().assertEvents (EventCollector .HANDLE_ERROR , "HANDLE_ERROR" );
399+ getEventCollector ().assertTotalEventsCount (2 );
400+ }
401+
402+ @ Test
403+ void afterCompletionThrowException () {
404+ doLoad (HandlerConfiguration .class , AfterCompletionErrorHandlerTestListener .class );
405+ this .transactionTemplate .execute (status -> {
406+ getContext ().publishEvent ("test" );
407+ getEventCollector ().assertNoEventReceived ();
408+ return null ;
409+ });
410+ getEventCollector ().assertEvents (EventCollector .AFTER_COMPLETION , "test" );
411+ getEventCollector ().assertEvents (EventCollector .HANDLE_ERROR , "HANDLE_ERROR" );
412+ getEventCollector ().assertTotalEventsCount (2 );
413+ }
372414
373415 protected EventCollector getEventCollector () {
374416 return this .eventCollector ;
@@ -442,6 +484,36 @@ public TransactionTemplate transactionTemplate() {
442484 }
443485 }
444486
487+ @ Configuration
488+ @ EnableTransactionManagement
489+ static class HandlerConfiguration {
490+
491+ @ Bean
492+ public EventCollector eventCollector () {
493+ return new EventCollector ();
494+ }
495+
496+ @ Bean
497+ public TestBean testBean (ApplicationEventPublisher eventPublisher ) {
498+ return new TestBean (eventPublisher );
499+ }
500+
501+ @ Bean
502+ public CallCountingTransactionManager transactionManager () {
503+ return new CallCountingTransactionManager ();
504+ }
505+
506+ @ Bean
507+ public TransactionTemplate transactionTemplate () {
508+ return new TransactionTemplate (transactionManager ());
509+ }
510+
511+ @ Bean
512+ public AfterRollbackErrorHandler errorHandler (ApplicationEventPublisher eventPublisher ) {
513+ return new AfterRollbackErrorHandler (eventPublisher );
514+ }
515+ }
516+
445517
446518 @ Configuration
447519 static class MulticasterWithCustomExecutor {
@@ -467,7 +539,9 @@ static class EventCollector {
467539
468540 public static final String AFTER_ROLLBACK = "AFTER_ROLLBACK" ;
469541
470- public static final String [] ALL_PHASES = {IMMEDIATELY , BEFORE_COMMIT , AFTER_COMMIT , AFTER_ROLLBACK };
542+ public static final String HANDLE_ERROR = "HANDLE_ERROR" ;
543+
544+ public static final String [] ALL_PHASES = {IMMEDIATELY , BEFORE_COMMIT , AFTER_COMMIT , AFTER_ROLLBACK , HANDLE_ERROR };
471545
472546 private final MultiValueMap <String , Object > events = new LinkedMultiValueMap <>();
473547
@@ -486,7 +560,7 @@ public void assertNoEventReceived(String... phases) {
486560 for (String phase : phases ) {
487561 List <Object > eventsForPhase = getEvents (phase );
488562 assertThat (eventsForPhase .size ()).as ("Expected no events for phase '" + phase + "' " +
489- "but got " + eventsForPhase + ":" ).isEqualTo (0 );
563+ "but got " + eventsForPhase + ":" ).isEqualTo (0 );
490564 }
491565 }
492566
@@ -504,7 +578,7 @@ public void assertTotalEventsCount(int number) {
504578 size += entry .getValue ().size ();
505579 }
506580 assertThat (size ).as ("Wrong number of total events (" + this .events .size () + ") " +
507- "registered phase(s)" ).isEqualTo (number );
581+ "registered phase(s)" ).isEqualTo (number );
508582 }
509583 }
510584
@@ -677,6 +751,51 @@ public void handleAfterCommit(String data) {
677751 }
678752
679753
754+ @ Component
755+ static class AfterCommitErrorHandlerTestListener extends BaseTransactionalTestListener {
756+
757+ @ TransactionalEventListener (phase = AFTER_COMMIT , condition = "!'HANDLE_ERROR'.equals(#data)" )
758+ public void handleBeforeCommit (String data ) {
759+ handleEvent (EventCollector .AFTER_COMMIT , data );
760+ throw new IllegalStateException ("test" );
761+ }
762+
763+ @ EventListener (condition = "'HANDLE_ERROR'.equals(#data)" )
764+ public void handleImmediately (String data ) {
765+ handleEvent (EventCollector .HANDLE_ERROR , data );
766+ }
767+ }
768+
769+ @ Component
770+ static class AfterRollbackErrorHandlerTestListener extends BaseTransactionalTestListener {
771+
772+ @ TransactionalEventListener (phase = AFTER_ROLLBACK , condition = "!'HANDLE_ERROR'.equals(#data)" )
773+ public void handleBeforeCommit (String data ) {
774+ handleEvent (EventCollector .AFTER_ROLLBACK , data );
775+ throw new IllegalStateException ("test" );
776+ }
777+
778+ @ EventListener (condition = "'HANDLE_ERROR'.equals(#data)" )
779+ public void handleImmediately (String data ) {
780+ handleEvent (EventCollector .HANDLE_ERROR , data );
781+ }
782+ }
783+
784+ @ Component
785+ static class AfterCompletionErrorHandlerTestListener extends BaseTransactionalTestListener {
786+
787+ @ TransactionalEventListener (phase = AFTER_COMPLETION , condition = "!'HANDLE_ERROR'.equals(#data)" )
788+ public void handleBeforeCommit (String data ) {
789+ handleEvent (EventCollector .AFTER_COMPLETION , data );
790+ throw new IllegalStateException ("test" );
791+ }
792+
793+ @ EventListener (condition = "'HANDLE_ERROR'.equals(#data)" )
794+ public void handleImmediately (String data ) {
795+ handleEvent (EventCollector .HANDLE_ERROR , data );
796+ }
797+ }
798+
680799 static class EventTransactionSynchronization implements TransactionSynchronization {
681800
682801 private final int order ;
@@ -691,4 +810,18 @@ public int getOrder() {
691810 }
692811 }
693812
813+ static class AfterRollbackErrorHandler extends GlobalTransactionalEventErrorHandler {
814+
815+ private final ApplicationEventPublisher eventPublisher ;
816+
817+ AfterRollbackErrorHandler (ApplicationEventPublisher eventPublisher ) {
818+ this .eventPublisher = eventPublisher ;
819+ }
820+
821+ @ Override
822+ public void handle (ApplicationEvent event , @ Nullable Throwable ex ) {
823+ eventPublisher .publishEvent ("HANDLE_ERROR" );
824+ }
825+ }
826+
694827}
0 commit comments