Skip to content

Commit

Permalink
Merge pull request #344 from Serrof/issue-335
Browse files Browse the repository at this point in the history
add propagation direction in (Field)AdaptableInterval
  • Loading branch information
Serrof authored Sep 21, 2024
2 parents 9bad1c2 + 0a18c06 commit c3e5d3a
Show file tree
Hide file tree
Showing 27 changed files with 240 additions and 59 deletions.
5 changes: 5 additions & 0 deletions hipparchus-ode/src/changes/changes.xml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ If the output is not quite correct, check for invisible trailing spaces!
</action>
</release>
<body>
<release version="4.0" date="TBD" description="TBD">
<action dev="serrof" type="add" issue="issues/335">
Add boolean for propagation direction in (Field)AdaptableInterval.
</action>
</release>
<release version="3.1" date="2024-04-05" description="This is a maintenance release. It includes one
bugfixes and adds two features on existing integrators.">
<action dev="serrof" type="add" issue="issues/269">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ public BracketedRealFieldUnivariateSolver<E> getSolver() {
* @return a new detector with updated configuration (the instance is not changed)
*/
public T withMaxCheck(final E newMaxCheck) {
return withMaxCheck(s -> newMaxCheck.getReal());
return withMaxCheck((s, isForward) -> newMaxCheck.getReal());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ public BracketedUnivariateSolver<UnivariateFunction> getSolver() {
* @return a new detector with updated configuration (the instance is not changed)
*/
public T withMaxCheck(final double newMaxCheck) {
return withMaxCheck(s -> newMaxCheck);
return withMaxCheck((s, isForward) -> newMaxCheck);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@
@FunctionalInterface
public interface AdaptableInterval {

/** Get the current value of maximal time interval between events handler checks.
* @param state current state
/**
* Get the current value of maximal time interval between events handler checks.
*
* @param state current state
* @param isForward true if propagation is forward in independent variable, false otherwise
* @return current value of maximal time interval between events handler checks
*/
double currentInterval(ODEStateAndDerivative state);
double currentInterval(ODEStateAndDerivative state, boolean isForward);

}
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ private ODEStateAndDerivative nextCheck(final ODEStateAndDerivative done, final
// we have to select some intermediate state
// attempting to split the remaining time in an integer number of checks
final double dt = target.getTime() - done.getTime();
final double maxCheck = detector.getMaxCheckInterval().currentInterval(done);
final double maxCheck = detector.getMaxCheckInterval().currentInterval(done, dt >= 0.);
final int n = FastMath.max(1, (int) FastMath.ceil(FastMath.abs(dt) / maxCheck));
return n == 1 ? target : interpolator.getInterpolatedState(done.getTime() + dt / n);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,13 @@
@FunctionalInterface
public interface FieldAdaptableInterval<T extends CalculusFieldElement<T>> {

/** Get the current value of maximal time interval between events handler checks.
* @param state current state
/**
* Get the current value of maximal time interval between events handler checks.
*
* @param state current state
* @param isForward true if propagation is forward in independent variable, false otherwise
* @return current value of maximal time interval between events handler checks (only as a double)
*/
double currentInterval(FieldODEStateAndDerivative<T> state);
double currentInterval(FieldODEStateAndDerivative<T> state, boolean isForward);

}
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ private FieldODEStateAndDerivative<T> nextCheck(final FieldODEStateAndDerivative
// we have to select some intermediate state
// attempting to split the remaining time in an integer number of checks
final T dt = target.getTime().subtract(done.getTime());
final double maxCheck = detector.getMaxCheckInterval().currentInterval(done);
final double maxCheck = detector.getMaxCheckInterval().currentInterval(done, dt.getReal() >= 0.);
final int n = FastMath.max(1, (int) FastMath.ceil(dt.abs().divide(maxCheck).getReal()));
return n == 1 ? target : interpolator.getInterpolatedState(done.getTime().add(dt.divide(n)));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ protected BaseDetector(final double maxCheck, final T threshold, final int maxIt
}

public FieldAdaptableInterval<T> getMaxCheckInterval() {
return s -> maxCheck;
return (s, isForward) -> maxCheck;
}

public int getMaxIterationCount() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ protected BaseDetector(final double maxCheck, final double threshold, final int
}

public AdaptableInterval getMaxCheckInterval() {
return s -> maxCheck;
return (s, isForward) -> maxCheck;
}

public int getMaxIterationCount() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2098,7 +2098,7 @@ private static abstract class BaseDetector implements ODEEventDetector {

public BaseDetector(final double maxCheck, final double threshold, final int maxIter,
Action action, List<Event> events) {
this.maxCheck = s -> maxCheck;
this.maxCheck = (s, isForward) -> maxCheck;
this.maxIter = maxIter;
this.solver = new BracketingNthOrderBrentSolver(0, threshold, 0, 5);
this.action = action;
Expand Down Expand Up @@ -2382,7 +2382,7 @@ private class ResetChangesSignGenerator implements ODEEventDetector {

public ResetChangesSignGenerator(final double y1, final double y2, final double change,
final double maxCheck, final double threshold, final int maxIter) {
this.maxCheck = s -> maxCheck;
this.maxCheck = (s, isForward) -> maxCheck;
this.maxIter = maxIter;
this.solver = new BracketingNthOrderBrentSolver(0, threshold, 0, 5);
this.y1 = y1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@
import org.hipparchus.ode.nonstiff.LutherIntegrator;
import org.hipparchus.ode.sampling.DummyStepInterpolator;
import org.hipparchus.ode.sampling.ODEStateInterpolator;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mockito;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
Expand Down Expand Up @@ -134,7 +138,7 @@ private static class ResettingEvent implements ODEEventDetector {

public ResettingEvent(final double tEvent,
final double maxCheck, final double threshold, final int maxIter) {
this.maxCheck = s -> maxCheck;
this.maxCheck = (s, isForward) -> maxCheck;
this.maxIter = maxIter;
this.solver = new BracketingNthOrderBrentSolver(0, threshold, 0, 5);
this.tEvent = tEvent;
Expand Down Expand Up @@ -228,7 +232,7 @@ private static class SecondaryStateEvent implements ODEEventDetector {

public SecondaryStateEvent(final int index, final double target,
final double maxCheck, final double threshold, final int maxIter) {
this.maxCheck = s -> maxCheck;
this.maxCheck = (s, isForward) -> maxCheck;
this.maxIter = maxIter;
this.solver = new BracketingNthOrderBrentSolver(0, threshold, 0, 5);
this.index = index;
Expand Down Expand Up @@ -294,7 +298,7 @@ private class CloseEventsGenerator implements ODEEventDetector {

public CloseEventsGenerator(final double r1, final double r2,
final double maxCheck, final double threshold, final int maxIter) {
this.maxCheck = s -> maxCheck;
this.maxCheck = (s, isForward) -> maxCheck;
this.maxIter = maxIter;
this.solver = new BracketingNthOrderBrentSolver(0, threshold, 0, 5);
this.r1 = r1;
Expand Down Expand Up @@ -328,4 +332,75 @@ public int getCount() {

}

@ParameterizedTest
@ValueSource(booleans = {true, false})
void testAdaptableInterval(final boolean isForward) {
// GIVEN
final TestDetector detector = new TestDetector();
final DetectorBasedEventState eventState = new DetectorBasedEventState(detector);
final ODEStateInterpolator mockedInterpolator = Mockito.mock(ODEStateInterpolator.class);
final ODEStateAndDerivative stateAndDerivative1 = getStateAndDerivative(1);
final ODEStateAndDerivative stateAndDerivative2 = getStateAndDerivative(-1);
if (isForward) {
Mockito.when(mockedInterpolator.getCurrentState()).thenReturn(stateAndDerivative1);
Mockito.when(mockedInterpolator.getPreviousState()).thenReturn(stateAndDerivative2);
} else {
Mockito.when(mockedInterpolator.getCurrentState()).thenReturn(stateAndDerivative2);
Mockito.when(mockedInterpolator.getPreviousState()).thenReturn(stateAndDerivative1);
}
Mockito.when(mockedInterpolator.isForward()).thenReturn(isForward);
// WHEN
eventState.evaluateStep(mockedInterpolator);
// THEN
if (isForward) {
Assertions.assertEquals(1, detector.triggeredForward);
Assertions.assertEquals(0, detector.triggeredBackward);
} else {
Assertions.assertEquals(0, detector.triggeredForward);
Assertions.assertEquals(1, detector.triggeredBackward);
}
}

private static ODEStateAndDerivative getStateAndDerivative(final double time) {
return new ODEStateAndDerivative(time, new double[] {time}, new double[1]);
}

private static class TestDetector implements ODEEventDetector {

int triggeredForward = 0;
int triggeredBackward = 0;

@Override
public AdaptableInterval getMaxCheckInterval() {
return (state, isForward) -> {
if (isForward) {
triggeredForward++;
} else {
triggeredBackward++;
}
return 1.;
};
}

@Override
public int getMaxIterationCount() {
return 10;
}

@Override
public BracketedUnivariateSolver<UnivariateFunction> getSolver() {
return new BracketingNthOrderBrentSolver();
}

@Override
public ODEEventHandler getHandler() {
return null;
}

@Override
public double g(ODEStateAndDerivative state) {
return 0.;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ protected static class Event implements ODEEventDetector {

public Event(final double maxCheck, final double threshold, final int maxIter,
boolean expectDecreasing, boolean expectIncreasing) {
this.maxCheck = s -> maxCheck;
this.maxCheck = (s, isForward) -> maxCheck;
this.maxIter = maxIter;
this.solver = new BracketingNthOrderBrentSolver(0, threshold, 0, 5);
this.expectDecreasing = expectDecreasing;
Expand Down Expand Up @@ -505,7 +505,7 @@ protected static class FieldEvent<T extends CalculusFieldElement<T>> implements

public FieldEvent(final double maxCheck, final T threshold, final int maxIter,
boolean expectDecreasing, boolean expectIncreasing) {
this.maxCheck = s -> maxCheck;
this.maxCheck = (s, isForward) -> maxCheck;
this.maxIter = maxIter;
this.solver = new FieldBracketingNthOrderBrentSolver<>(threshold.getField().getZero(),
threshold,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ private static class SimpleDetector implements ODEEventDetector {
private final ScheduleChecker checker;
SimpleDetector(final double tEvent, final ScheduleChecker checker,
final double maxCheck, final double threshold, final int maxIter) {
this.maxCheck = s -> maxCheck;
this.maxCheck = (s, isForward) -> maxCheck;
this.maxIter = maxIter;
this.solver = new BracketingNthOrderBrentSolver(0, threshold, 0, 5);
this.tEvent = tEvent;
Expand Down Expand Up @@ -226,7 +226,7 @@ private static class SimpleFieldDetector implements FieldODEEventDetector<Binary

SimpleFieldDetector(final double tEvent, final ScheduleChecker checker,
final double maxCheck, final double threshold, final int maxIter) {
this.maxCheck = s -> maxCheck;
this.maxCheck = (s, isForward) -> maxCheck;
this.maxIter = maxIter;
this.solver = new FieldBracketingNthOrderBrentSolver<>(new Binary64(0),
new Binary64(threshold),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2132,7 +2132,7 @@ private static abstract class BaseDetector implements FieldODEEventDetector<Bina

public BaseDetector(final double maxCheck, final double threshold, final int maxIter,
Action action, List<Event> events) {
this.maxCheck = s -> maxCheck;
this.maxCheck = (s, isForward) -> maxCheck;
this.maxIter = maxIter;
this.solver = new FieldBracketingNthOrderBrentSolver<>(new Binary64(0),
new Binary64(threshold),
Expand Down Expand Up @@ -2427,7 +2427,7 @@ private class ResetChangesSignGenerator implements FieldODEEventDetector<Binary6

public ResetChangesSignGenerator(final double y1, final double y2, final double change,
final double maxCheck, final double threshold, final int maxIter) {
this.maxCheck = s -> maxCheck;
this.maxCheck = (s, isForward) -> maxCheck;
this.maxIter = maxIter;
this.solver = new FieldBracketingNthOrderBrentSolver<>(new Binary64(0),
new Binary64(threshold),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package org.hipparchus.ode.events;

import org.hipparchus.analysis.solvers.BracketedRealFieldUnivariateSolver;
import org.hipparchus.analysis.solvers.FieldBracketingNthOrderBrentSolver;
import org.hipparchus.complex.Complex;
import org.hipparchus.complex.ComplexField;
import org.hipparchus.ode.FieldODEStateAndDerivative;
import org.hipparchus.ode.sampling.FieldODEStateInterpolator;
import org.hipparchus.util.MathArrays;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mockito;

class FieldDetectorBasedEventStateTest {

@ParameterizedTest
@ValueSource(booleans = {true, false})
void testNextCheck(final boolean isForward) {
// GIVEN
final TestFieldDetector detector = new TestFieldDetector(isForward);
final FieldDetectorBasedEventState<Complex> eventState = new FieldDetectorBasedEventState<>(detector);
final FieldODEStateInterpolator<Complex> mockedInterpolator = Mockito.mock(FieldODEStateInterpolator.class);
final FieldODEStateAndDerivative<Complex> stateAndDerivative1 = getStateAndDerivative(1);
final FieldODEStateAndDerivative<Complex> stateAndDerivative2 = getStateAndDerivative(-1);
if (isForward) {
Mockito.when(mockedInterpolator.getCurrentState()).thenReturn(stateAndDerivative1);
Mockito.when(mockedInterpolator.getPreviousState()).thenReturn(stateAndDerivative2);
} else {
Mockito.when(mockedInterpolator.getCurrentState()).thenReturn(stateAndDerivative2);
Mockito.when(mockedInterpolator.getPreviousState()).thenReturn(stateAndDerivative1);
}
Mockito.when(mockedInterpolator.isForward()).thenReturn(isForward);
Mockito.when(mockedInterpolator.getInterpolatedState(new Complex(0.))).thenReturn(getStateAndDerivative(0.));
eventState.init(mockedInterpolator.getPreviousState(), mockedInterpolator.getPreviousState().getTime());
eventState.reinitializeBegin(mockedInterpolator);
// WHEN & THEN
final AssertionError error = Assertions.assertThrows(AssertionError.class, () ->
eventState.evaluateStep(mockedInterpolator));
Assertions.assertEquals(isForward ? "forward" : "backward", error.getMessage());
}

private static FieldODEStateAndDerivative<Complex> getStateAndDerivative(final double time) {
final Complex[] state = MathArrays.buildArray(ComplexField.getInstance(), 1);
state[0] = new Complex(time);
final Complex[] derivative = MathArrays.buildArray(ComplexField.getInstance(), 1);
derivative[0] = Complex.ONE;
return new FieldODEStateAndDerivative<>(state[0], state, derivative);
}

private static class TestFieldDetector implements FieldODEEventDetector<Complex> {

private final boolean failOnForward;

TestFieldDetector(final boolean failOnForward) {
this.failOnForward = failOnForward;
}

@Override
public FieldAdaptableInterval<Complex> getMaxCheckInterval() {
return (state, isForward) -> {
if (isForward && failOnForward) {
throw new AssertionError("forward");
} else if (!isForward && !failOnForward) {
throw new AssertionError("backward");
}
return 1.;
};
}

@Override
public int getMaxIterationCount() {
return 10;
}

@Override
public BracketedRealFieldUnivariateSolver<Complex> getSolver() {
return new FieldBracketingNthOrderBrentSolver<>(Complex.ONE, Complex.ONE, Complex.ONE, 2);
}

@Override
public FieldODEEventHandler<Complex> getHandler() {
return null;
}

@Override
public Complex g(FieldODEStateAndDerivative<Complex> state) {
return state.getTime();
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ public class FieldVariableCheckInterval implements FieldOrdinaryDifferentialEqua
void testFixedInterval() {
double tZero = 7.0;
double width = 0.25;
doTest(tZero, width, s -> width / 25, 710);
doTest(tZero, width, (s, isForward) -> width / 25, 710);
}

@Test
void testWidthAwareInterval() {
double tZero = 7.0;
double width = 0.25;
doTest(tZero, width,
s -> {
(s, isForward) -> {
if (s.getTime().getReal() < tZero - 0.5 * width) {
return tZero - 0.25 * width - s.getTime().getReal();
} else if (s.getTime().getReal() > tZero + 0.5 * width) {
Expand Down
Loading

0 comments on commit c3e5d3a

Please sign in to comment.