Skip to content

Commit 3d34cd9

Browse files
spjuhelpeanutfun
andauthored
Implement equality methods for impf and impfset (#1027)
* Implement equality methods for impf and impfset * Update tests --------- Co-authored-by: Lukas Riedel <[email protected]>
1 parent af793bc commit 3d34cd9

File tree

5 files changed

+150
-7
lines changed

5 files changed

+150
-7
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ Removed:
1919
`plot_intensity`, `plot_fraction`, `_event_plot` to mask plotting when regions are too far from data points [#1047](https://github.com/CLIMADA-project/climada_python/pull/1047). To recreate previous plots (no masking), the parameter can be set to None.
2020
- Added instructions to install Climada petals on Euler cluster in `doc.guide.Guide_Euler.ipynb` [#1029](https://github.com/CLIMADA-project/climada_python/pull/1029)
2121

22+
- `ImpactFunc` and `ImpactFuncSet` now support equality comparisons via `==` [#1027](https://github.com/CLIMADA-project/climada_python/pull/1027)
23+
2224
### Changed
2325

2426
- `Hazard.local_exceedance_intensity`, `Hazard.local_return_period` and `Impact.local_exceedance_impact`, `Impact.local_return_period`, using the `climada.util.interpolation` module: New default (no binning), binning on decimals, and faster implementation [#1012](https://github.com/CLIMADA-project/climada_python/pull/1012)

climada/entity/impact_funcs/base.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,19 @@ def __init__(
9797
self.mdd = mdd if mdd is not None else np.array([])
9898
self.paa = paa if paa is not None else np.array([])
9999

100+
def __eq__(self, value: object, /) -> bool:
101+
if isinstance(value, ImpactFunc):
102+
return (
103+
self.haz_type == value.haz_type
104+
and self.id == value.id
105+
and self.name == value.name
106+
and self.intensity_unit == value.intensity_unit
107+
and np.array_equal(self.intensity, value.intensity)
108+
and np.array_equal(self.mdd, value.mdd)
109+
and np.array_equal(self.paa, value.paa)
110+
)
111+
return False
112+
100113
def calc_mdr(self, inten: Union[float, np.ndarray]) -> np.ndarray:
101114
"""Interpolate impact function to a given intensity.
102115
@@ -177,7 +190,7 @@ def from_step_impf(
177190
mdd: tuple[float, float] = (0, 1),
178191
paa: tuple[float, float] = (1, 1),
179192
impf_id: int = 1,
180-
**kwargs
193+
**kwargs,
181194
):
182195
"""Step function type impact function.
183196
@@ -218,7 +231,7 @@ def from_step_impf(
218231
intensity=intensity,
219232
mdd=mdd,
220233
paa=paa,
221-
**kwargs
234+
**kwargs,
222235
)
223236

224237
def set_step_impf(self, *args, **kwargs):
@@ -238,7 +251,7 @@ def from_sigmoid_impf(
238251
x0: float,
239252
haz_type: str,
240253
impf_id: int = 1,
241-
**kwargs
254+
**kwargs,
242255
):
243256
r"""Sigmoid type impact function hinging on three parameter.
244257
@@ -287,7 +300,7 @@ def from_sigmoid_impf(
287300
intensity=intensity,
288301
paa=paa,
289302
mdd=mdd,
290-
**kwargs
303+
**kwargs,
291304
)
292305

293306
def set_sigmoid_impf(self, *args, **kwargs):
@@ -308,7 +321,7 @@ def from_poly_s_shape(
308321
exponent: float,
309322
haz_type: str,
310323
impf_id: int = 1,
311-
**kwargs
324+
**kwargs,
312325
):
313326
r"""S-shape polynomial impact function hinging on four parameter.
314327

climada/entity/impact_funcs/impact_func_set.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,12 @@ def __init__(self, impact_funcs: Optional[Iterable[ImpactFunc]] = None):
109109
for impf in impact_funcs:
110110
self.append(impf)
111111

112+
def __eq__(self, value: object, /) -> bool:
113+
if isinstance(value, ImpactFuncSet):
114+
return self._data == value._data
115+
116+
return False
117+
112118
def clear(self):
113119
"""Reinitialize attributes."""
114120
self._data = dict() # {hazard_type : {id:ImpactFunc}}

climada/entity/impact_funcs/test/test_base.py

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,74 @@
2626
from climada.entity.impact_funcs.base import ImpactFunc
2727

2828

29+
class TestEquality(unittest.TestCase):
30+
"""Test equality method"""
31+
32+
def setUp(self):
33+
self.impf1 = ImpactFunc(
34+
haz_type="TC",
35+
id=1,
36+
intensity=np.array([1, 2, 3]),
37+
mdd=np.array([0.1, 0.2, 0.3]),
38+
paa=np.array([0.4, 0.5, 0.6]),
39+
intensity_unit="m/s",
40+
name="Test Impact",
41+
)
42+
self.impf2 = ImpactFunc(
43+
haz_type="TC",
44+
id=1,
45+
intensity=np.array([1, 2, 3]),
46+
mdd=np.array([0.1, 0.2, 0.3]),
47+
paa=np.array([0.4, 0.5, 0.6]),
48+
intensity_unit="m/s",
49+
name="Test Impact",
50+
)
51+
self.impf3 = ImpactFunc(
52+
haz_type="FL",
53+
id=2,
54+
intensity=np.array([4, 5, 6]),
55+
mdd=np.array([0.7, 0.8, 0.9]),
56+
paa=np.array([0.1, 0.2, 0.3]),
57+
intensity_unit="m",
58+
name="Another Impact",
59+
)
60+
61+
def test_reflexivity(self):
62+
self.assertEqual(self.impf1, self.impf1)
63+
64+
def test_symmetry(self):
65+
self.assertEqual(self.impf1, self.impf2)
66+
self.assertEqual(self.impf2, self.impf1)
67+
68+
def test_transitivity(self):
69+
impf4 = ImpactFunc(
70+
haz_type="TC",
71+
id=1,
72+
intensity=np.array([1, 2, 3]),
73+
mdd=np.array([0.1, 0.2, 0.3]),
74+
paa=np.array([0.4, 0.5, 0.6]),
75+
intensity_unit="m/s",
76+
name="Test Impact",
77+
)
78+
self.assertEqual(self.impf1, self.impf2)
79+
self.assertEqual(self.impf2, impf4)
80+
self.assertEqual(self.impf1, impf4)
81+
82+
def test_consistency(self):
83+
self.assertEqual(self.impf1, self.impf2)
84+
self.assertEqual(self.impf1, self.impf2)
85+
86+
def test_comparison_with_none(self):
87+
self.assertNotEqual(self.impf1, None)
88+
89+
def test_different_types(self):
90+
self.assertNotEqual(self.impf1, "Not an ImpactFunc")
91+
92+
def test_inequality(self):
93+
self.assertNotEqual(self.impf1, self.impf3)
94+
self.assertTrue(self.impf1 != self.impf3)
95+
96+
2997
class TestInterpolation(unittest.TestCase):
3098
"""Impact function interpolation test"""
3199

@@ -139,5 +207,8 @@ def test_aux_vars(impf):
139207

140208
# Execute Tests
141209
if __name__ == "__main__":
142-
TESTS = unittest.TestLoader().loadTestsFromTestCase(TestInterpolation)
143-
unittest.TextTestRunner(verbosity=2).run(TESTS)
210+
equality_tests = unittest.TestLoader().loadTestsFromTestCase(TestEquality)
211+
interpolation_tests = unittest.TestLoader().loadTestsFromTestCase(TestInterpolation)
212+
unittest.TextTestRunner(verbosity=2).run(
213+
unittest.TestSuite([equality_tests, interpolation_tests])
214+
)

climada/entity/impact_funcs/test/test_imp_fun_set.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
"""
2121

2222
import unittest
23+
from copy import deepcopy
2324

2425
import numpy as np
2526

@@ -288,6 +289,55 @@ def test_remove_add_pass(self):
288289
self.assertEqual([1], imp_fun.get_ids("TC"))
289290

290291

292+
class TestEquality(unittest.TestCase):
293+
"""Test equality method for ImpactFuncSet"""
294+
295+
def setUp(self):
296+
intensity = np.array([0, 20])
297+
paa = np.array([0, 1])
298+
mdd = np.array([0, 0.5])
299+
300+
fun_1 = ImpactFunc("TC", 3, intensity, mdd, paa)
301+
fun_2 = ImpactFunc("TC", 3, deepcopy(intensity), deepcopy(mdd), deepcopy(paa))
302+
fun_3 = ImpactFunc("TC", 4, intensity + 1, mdd, paa)
303+
304+
self.impact_set1 = ImpactFuncSet([fun_1])
305+
self.impact_set2 = ImpactFuncSet([fun_2])
306+
self.impact_set3 = ImpactFuncSet([fun_3])
307+
self.impact_set4 = ImpactFuncSet([fun_1, fun_3])
308+
309+
def test_reflexivity(self):
310+
self.assertEqual(self.impact_set1, self.impact_set1)
311+
312+
def test_symmetry(self):
313+
self.assertEqual(self.impact_set1, self.impact_set2)
314+
self.assertEqual(self.impact_set2, self.impact_set1)
315+
316+
def test_transitivity(self):
317+
impact_set5 = ImpactFuncSet([self.impact_set1._data["TC"][3]])
318+
self.assertEqual(self.impact_set1, self.impact_set2)
319+
self.assertEqual(self.impact_set2, impact_set5)
320+
self.assertEqual(self.impact_set1, impact_set5)
321+
322+
def test_consistency(self):
323+
self.assertEqual(self.impact_set1, self.impact_set2)
324+
self.assertEqual(self.impact_set1, self.impact_set2)
325+
326+
def test_comparison_with_none(self):
327+
self.assertNotEqual(self.impact_set1, None)
328+
329+
def test_different_types(self):
330+
self.assertNotEqual(self.impact_set1, "Not an ImpactFuncSet")
331+
332+
def test_field_comparison(self):
333+
self.assertNotEqual(self.impact_set1, self.impact_set3)
334+
self.assertNotEqual(self.impact_set1, self.impact_set4)
335+
336+
def test_inequality(self):
337+
self.assertNotEqual(self.impact_set1, self.impact_set3)
338+
self.assertTrue(self.impact_set1 != self.impact_set3)
339+
340+
291341
class TestChecker(unittest.TestCase):
292342
"""Test loading funcions from the ImpactFuncSet class"""
293343

@@ -592,6 +642,7 @@ def test_write_read_pass(self):
592642
# Execute Tests
593643
if __name__ == "__main__":
594644
TESTS = unittest.TestLoader().loadTestsFromTestCase(TestContainer)
645+
TESTS.addTests(unittest.TestLoader().loadTestsFromTestCase(TestEquality))
595646
TESTS.addTests(unittest.TestLoader().loadTestsFromTestCase(TestChecker))
596647
TESTS.addTests(unittest.TestLoader().loadTestsFromTestCase(TestExtend))
597648
TESTS.addTests(unittest.TestLoader().loadTestsFromTestCase(TestReaderExcel))

0 commit comments

Comments
 (0)