Skip to content

Commit a1b6425

Browse files
committed
test/func: added squeeze morph class and call it in the test
1 parent ca69295 commit a1b6425

File tree

2 files changed

+69
-33
lines changed

2 files changed

+69
-33
lines changed

Diff for: src/diffpy/morph/morphs/morphsqueeze.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import numpy as np
2+
from numpy.polynomial import Polynomial
3+
from scipy.interpolate import interp1d
4+
5+
from diffpy.morph.morphs.morph import LABEL_GR, LABEL_RA, Morph
6+
7+
8+
class MorphSqueeze(Morph):
9+
"""Squeeze the morph function.
10+
11+
This applies a polynomial to squeeze the morph non-linearly.
12+
13+
Configuration Variables
14+
-----------------------
15+
squeeze
16+
list or array-like
17+
Polynomial coefficients [a0, a1, ..., an] for the squeeze function.
18+
"""
19+
20+
# Define input output types
21+
summary = "Squeeze morph by polynomial shift"
22+
xinlabel = LABEL_RA
23+
yinlabel = LABEL_GR
24+
xoutlabel = LABEL_RA
25+
youtlabel = LABEL_GR
26+
parnames = ["squeeze"]
27+
28+
def morph(self, x_morph, y_morph, x_target, y_target):
29+
30+
Morph.morph(self, x_morph, y_morph, x_target, y_target)
31+
if self.squeeze is None or np.allclose(self.squeeze, 0):
32+
self.x_morph_out = self.x_morph_in
33+
self.y_morph_out = self.y_morph_in
34+
return self.xyallout
35+
36+
squeeze_polynomial = Polynomial(self.squeeze)
37+
x_squeezed = self.x_morph_in + squeeze_polynomial(self.x_morph_in)
38+
39+
self.y_morph_out = interp1d(
40+
x_squeezed,
41+
self.y_morph_in,
42+
kind="cubic",
43+
bounds_error=False,
44+
fill_value="extrapolate",
45+
)(self.x_morph_in)
46+
self.x_morph_out = self.x_morph_in
47+
48+
return self.xyallout

Diff for: tests/test_morphsqueeze.py

+21-33
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import numpy as np
22
import pytest
33
from numpy.polynomial import Polynomial
4-
from scipy.interpolate import interp1d
4+
5+
from diffpy.morph.morphs.morphsqueeze import MorphSqueeze
56

67

78
@pytest.mark.parametrize(
@@ -20,41 +21,28 @@
2021
[0.1, 0.3],
2122
# 4th order squeeze coefficients
2223
[0.2, -0.01, 0.001, -0.001, 0.0001],
24+
# Testing zeros
25+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
2326
],
2427
)
2528
def test_morphsqueeze(squeeze_coeffs):
26-
# Uniform x-axis grid. This is the same x-axis for all data.
27-
x = np.linspace(0, 10, 1000)
28-
# Expected uniform target
29-
y_expected = np.sin(x)
3029

31-
# Create polynomial based on a list of values for polynomial coefficients
30+
x_target = np.linspace(0, 10, 1000)
31+
y_target = np.sin(x_target)
32+
3233
squeeze_polynomial = Polynomial(squeeze_coeffs)
33-
# Apply squeeze parameters to uniform data to get the squeezed data
34-
x_squeezed = x + squeeze_polynomial(x)
35-
y_squeezed = np.sin(x_squeezed)
36-
37-
# Unsqueeze the data by interpolating back to uniform grid
38-
y_unsqueezed = interp1d(
39-
x_squeezed,
40-
y_squeezed,
41-
kind="cubic",
42-
bounds_error=False,
43-
fill_value="extrapolate",
44-
)(x)
45-
y_actual = y_unsqueezed
46-
47-
# Check that the unsqueezed (actual) data matches the expected data
34+
x_squeezed = x_target + squeeze_polynomial(x_target)
35+
36+
x_morph = x_target.copy()
37+
y_morph = np.sin(x_squeezed)
38+
39+
morph = MorphSqueeze()
40+
morph.squeeze = squeeze_coeffs
41+
42+
x_actual, y_actual, x_expected, y_expected = morph(
43+
x_morph, y_morph, x_target, y_target
44+
)
45+
46+
# Check that the morphed (actual) data matches the expected data
4847
# Including tolerance error because of extrapolation error
49-
assert np.allclose(y_actual, y_expected, atol=1)
50-
51-
# This plotting code was used for the comments in the github
52-
# PR https://github.com/diffpy/diffpy.morph/pull/180
53-
# plt.figure(figsize=(7, 4))
54-
# plt.plot(x, y_expected, color="black", label="Expected uniform data")
55-
# plt.plot(x, y_squeezed, "--", color="purple", label="Squeezed data")
56-
# plt.plot(x, y_unsqueezed, "--", color="gold", label="Unsqueezed data")
57-
# plt.xlabel("x")
58-
# plt.ylabel("y")
59-
# plt.legend()
60-
# plt.show()
48+
assert np.allclose(y_actual, y_expected, atol=0.1)

0 commit comments

Comments
 (0)