Skip to content

Commit 084d112

Browse files
committed
add typehints and ignore types to test to satisfy mypy
1 parent 615fb4d commit 084d112

File tree

1 file changed

+27
-25
lines changed

1 file changed

+27
-25
lines changed

test/test_compatible_arithmetrics.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
from typing import Any
2+
13
import numpy as np
24
import pandas as pd
35
import pytest
46
import xarray as xr
57

6-
from linopy import Model
8+
from linopy import LESS_EQUAL, Model, Variable
79
from linopy.testing import assert_linequal
810

911

@@ -26,40 +28,40 @@ def activate(self, active: int) -> None:
2628
def active_data(self) -> xr.DataArray:
2729
return self.data1 if self.active == 1 else self.data2
2830

29-
def __add__(self, other):
31+
def __add__(self, other: Any) -> xr.DataArray:
3032
return self.active_data + other
3133

32-
def __sub__(self, other):
34+
def __sub__(self, other: Any) -> xr.DataArray:
3335
return self.active_data - other
3436

35-
def __mul__(self, other):
37+
def __mul__(self, other: Any) -> xr.DataArray:
3638
return self.active_data * other
3739

38-
def __truediv__(self, other):
40+
def __truediv__(self, other: Any) -> xr.DataArray:
3941
return self.active_data / other
4042

41-
def __radd__(self, other):
43+
def __radd__(self, other: Any) -> Any:
4244
return other + self.active_data
4345

44-
def __rsub__(self, other):
46+
def __rsub__(self, other: Any) -> Any:
4547
return other - self.active_data
4648

47-
def __rmul__(self, other):
49+
def __rmul__(self, other: Any) -> Any:
4850
return other * self.active_data
4951

50-
def __rtruediv__(self, other):
52+
def __rtruediv__(self, other: Any) -> Any:
5153
return other / self.active_data
5254

53-
def __neg__(self):
55+
def __neg__(self) -> xr.DataArray:
5456
return -self.active_data
5557

56-
def __pos__(self):
58+
def __pos__(self) -> xr.DataArray:
5759
return +self.active_data
5860

59-
def __abs__(self):
61+
def __abs__(self) -> xr.DataArray:
6062
return abs(self.active_data)
6163

62-
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
64+
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): # type: ignore
6365
# Ensure we always use the active_data when interacting with numpy/xarray operations
6466
new_inputs = [
6567
inp.active_data if isinstance(inp, SomeOtherDatatype) else inp
@@ -79,23 +81,23 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
7981
],
8082
ids=["single_dim", "multi_dim"],
8183
)
82-
def m(request) -> Model:
84+
def m(request) -> Model: # type: ignore
8385
m = Model()
84-
m.add_variables(coords=request.param, name="x")
86+
x = m.add_variables(coords=request.param, name="x")
8587
m.add_variables(0, 10, name="z")
86-
m.add_constraints(m.variables["x"] >= 0, name="c")
88+
m.add_constraints(x, LESS_EQUAL, 0, name="c")
8789
return m
8890

8991

9092
def test_arithmetric_operations_variable(m: Model) -> None:
91-
x = m.variables["x"]
93+
x: Variable = m.variables["x"]
9294
rng = np.random.default_rng()
9395
data = xr.DataArray(rng.random(x.shape), coords=x.coords)
9496
other_datatype = SomeOtherDatatype(data.copy())
95-
assert_linequal(x + data, x + other_datatype)
96-
assert_linequal(x - data, x - other_datatype)
97-
assert_linequal(x * data, x * other_datatype)
98-
assert_linequal(x / data, x / other_datatype)
97+
assert_linequal(x + data, x + other_datatype) # type: ignore
98+
assert_linequal(x - data, x - other_datatype) # type: ignore
99+
assert_linequal(x * data, x * other_datatype) # type: ignore
100+
assert_linequal(x / data, x / other_datatype) # type: ignore
99101

100102

101103
def test_arithmetric_operations_con(m: Model) -> None:
@@ -108,7 +110,7 @@ def test_arithmetric_operations_con(m: Model) -> None:
108110
assert_linequal(c.lhs - data, c.lhs - other_datatype)
109111
assert_linequal(c.lhs * data, c.lhs * other_datatype)
110112
assert_linequal(c.lhs / data, c.lhs / other_datatype)
111-
assert_linequal(c.rhs + data, c.rhs + other_datatype)
112-
assert_linequal(c.rhs - data, c.rhs - other_datatype)
113-
assert_linequal(c.rhs * data, c.rhs * other_datatype)
114-
assert_linequal(c.rhs / data, c.rhs / other_datatype)
113+
assert_linequal(c.rhs + data, c.rhs + other_datatype) # type: ignore
114+
assert_linequal(c.rhs - data, c.rhs - other_datatype) # type: ignore
115+
assert_linequal(c.rhs * data, c.rhs * other_datatype) # type: ignore
116+
assert_linequal(c.rhs / data, c.rhs / other_datatype) # type: ignore

0 commit comments

Comments
 (0)