1
+ from typing import Any
2
+
1
3
import numpy as np
2
4
import pandas as pd
3
5
import pytest
4
6
import xarray as xr
5
7
6
- from linopy import Model
8
+ from linopy import LESS_EQUAL , Model , Variable
7
9
from linopy .testing import assert_linequal
8
10
9
11
@@ -26,40 +28,40 @@ def activate(self, active: int) -> None:
26
28
def active_data (self ) -> xr .DataArray :
27
29
return self .data1 if self .active == 1 else self .data2
28
30
29
- def __add__ (self , other ) :
31
+ def __add__ (self , other : Any ) -> xr . DataArray :
30
32
return self .active_data + other
31
33
32
- def __sub__ (self , other ) :
34
+ def __sub__ (self , other : Any ) -> xr . DataArray :
33
35
return self .active_data - other
34
36
35
- def __mul__ (self , other ) :
37
+ def __mul__ (self , other : Any ) -> xr . DataArray :
36
38
return self .active_data * other
37
39
38
- def __truediv__ (self , other ) :
40
+ def __truediv__ (self , other : Any ) -> xr . DataArray :
39
41
return self .active_data / other
40
42
41
- def __radd__ (self , other ) :
43
+ def __radd__ (self , other : Any ) -> Any :
42
44
return other + self .active_data
43
45
44
- def __rsub__ (self , other ) :
46
+ def __rsub__ (self , other : Any ) -> Any :
45
47
return other - self .active_data
46
48
47
- def __rmul__ (self , other ) :
49
+ def __rmul__ (self , other : Any ) -> Any :
48
50
return other * self .active_data
49
51
50
- def __rtruediv__ (self , other ) :
52
+ def __rtruediv__ (self , other : Any ) -> Any :
51
53
return other / self .active_data
52
54
53
- def __neg__ (self ):
55
+ def __neg__ (self ) -> xr . DataArray :
54
56
return - self .active_data
55
57
56
- def __pos__ (self ):
58
+ def __pos__ (self ) -> xr . DataArray :
57
59
return + self .active_data
58
60
59
- def __abs__ (self ):
61
+ def __abs__ (self ) -> xr . DataArray :
60
62
return abs (self .active_data )
61
63
62
- def __array_ufunc__ (self , ufunc , method , * inputs , ** kwargs ):
64
+ def __array_ufunc__ (self , ufunc , method , * inputs , ** kwargs ): # type: ignore
63
65
# Ensure we always use the active_data when interacting with numpy/xarray operations
64
66
new_inputs = [
65
67
inp .active_data if isinstance (inp , SomeOtherDatatype ) else inp
@@ -79,23 +81,23 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
79
81
],
80
82
ids = ["single_dim" , "multi_dim" ],
81
83
)
82
- def m (request ) -> Model :
84
+ def m (request ) -> Model : # type: ignore
83
85
m = Model ()
84
- m .add_variables (coords = request .param , name = "x" )
86
+ x = m .add_variables (coords = request .param , name = "x" )
85
87
m .add_variables (0 , 10 , name = "z" )
86
- m .add_constraints (m . variables [ "x" ] >= 0 , name = "c" )
88
+ m .add_constraints (x , LESS_EQUAL , 0 , name = "c" )
87
89
return m
88
90
89
91
90
92
def test_arithmetric_operations_variable (m : Model ) -> None :
91
- x = m .variables ["x" ]
93
+ x : Variable = m .variables ["x" ]
92
94
rng = np .random .default_rng ()
93
95
data = xr .DataArray (rng .random (x .shape ), coords = x .coords )
94
96
other_datatype = SomeOtherDatatype (data .copy ())
95
- assert_linequal (x + data , x + other_datatype )
96
- assert_linequal (x - data , x - other_datatype )
97
- assert_linequal (x * data , x * other_datatype )
98
- assert_linequal (x / data , x / other_datatype )
97
+ assert_linequal (x + data , x + other_datatype ) # type: ignore
98
+ assert_linequal (x - data , x - other_datatype ) # type: ignore
99
+ assert_linequal (x * data , x * other_datatype ) # type: ignore
100
+ assert_linequal (x / data , x / other_datatype ) # type: ignore
99
101
100
102
101
103
def test_arithmetric_operations_con (m : Model ) -> None :
@@ -108,7 +110,7 @@ def test_arithmetric_operations_con(m: Model) -> None:
108
110
assert_linequal (c .lhs - data , c .lhs - other_datatype )
109
111
assert_linequal (c .lhs * data , c .lhs * other_datatype )
110
112
assert_linequal (c .lhs / data , c .lhs / other_datatype )
111
- assert_linequal (c .rhs + data , c .rhs + other_datatype )
112
- assert_linequal (c .rhs - data , c .rhs - other_datatype )
113
- assert_linequal (c .rhs * data , c .rhs * other_datatype )
114
- assert_linequal (c .rhs / data , c .rhs / other_datatype )
113
+ assert_linequal (c .rhs + data , c .rhs + other_datatype ) # type: ignore
114
+ assert_linequal (c .rhs - data , c .rhs - other_datatype ) # type: ignore
115
+ assert_linequal (c .rhs * data , c .rhs * other_datatype ) # type: ignore
116
+ assert_linequal (c .rhs / data , c .rhs / other_datatype ) # type: ignore
0 commit comments