|
1 | 1 | import numpy as np
|
2 | 2 | import pytest
|
3 | 3 | from numpy.polynomial import Polynomial
|
| 4 | +from scipy.interpolate import interp1d |
4 | 5 |
|
5 | 6 | from diffpy.morph.morphs.morphsqueeze import MorphSqueeze
|
6 | 7 |
|
|
21 | 22 | [0.1, 0.3],
|
22 | 23 | # 4th order squeeze coefficients
|
23 | 24 | [0.2, -0.01, 0.001, -0.001, 0.0004],
|
24 |
| - # Zeros and non-zeros, expect 0 + a1x + 0 + a3x**3 |
| 25 | + # Zeros and non-zeros, the full polynomial is applied |
25 | 26 | [0, 0.03, 0, -0.001],
|
26 | 27 | # Testing zeros, expect no squeezing
|
27 |
| - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 28 | + [0, 0, 0, 0, 0, 0], |
28 | 29 | ],
|
29 | 30 | )
|
30 | 31 | def test_morphsqueeze(squeeze_coeffs):
|
31 |
| - x_target = np.linspace(0, 10, 1001) |
32 |
| - y_target = np.sin(x_target) |
33 |
| - |
34 |
| - x_make = np.linspace(-3, 13, 1601) |
35 |
| - lower_idx = np.where(x_make == 0.0)[0][0] |
36 |
| - upper_idx = np.where(x_make == 10.0)[0][0] |
37 |
| - |
| 32 | + x_expected = np.linspace(0, 10, 1001) |
| 33 | + y_expected = np.sin(x_expected) |
| 34 | + x_make = np.linspace(-3, 13, 3250) |
38 | 35 | squeeze_polynomial = Polynomial(squeeze_coeffs)
|
39 | 36 | x_squeezed = x_make + squeeze_polynomial(x_make)
|
40 |
| - |
41 |
| - x_morph = x_make.copy() |
42 | 37 | y_morph = np.sin(x_squeezed)
|
43 |
| - |
44 | 38 | morph = MorphSqueeze()
|
45 | 39 | morph.squeeze = squeeze_coeffs
|
46 |
| - |
47 |
| - x_actual, y_actual, x_expected, y_expected = morph( |
48 |
| - x_morph, y_morph, x_target, y_target |
| 40 | + x_actual, y_actual, x_target, y_target = morph( |
| 41 | + x_make, y_morph, x_expected, y_expected |
49 | 42 | )
|
50 |
| - y_actual = y_actual[lower_idx : upper_idx + 1] |
| 43 | + y_actual = interp1d(x_actual, y_actual)(x_target) |
| 44 | + x_actual = x_target |
51 | 45 | assert np.allclose(y_actual, y_expected)
|
| 46 | + assert np.allclose(x_actual, x_expected) |
| 47 | + assert np.allclose(x_target, x_expected) |
| 48 | + assert np.allclose(y_target, y_expected) |
52 | 49 |
|
53 | 50 | # Plotting code used for figures in PR comments
|
54 | 51 | # https://github.com/diffpy/diffpy.morph/pull/180
|
55 | 52 | # plt.figure()
|
56 | 53 | # plt.scatter(x_expected, y_expected, color='black', label='Expected')
|
57 |
| - # plt.plot(x_morph, y_morph, color='purple', label='morph') |
| 54 | + # plt.plot(x_make, y_morph, color='purple', label='morph') |
58 | 55 | # plt.plot(x_actual, y_actual, '--', color='gold', label='Actual')
|
59 | 56 | # plt.legend()
|
60 | 57 | # plt.show()
|
0 commit comments