Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Maximize compatability with Datatypes by returning NotImplemented if __add__, __mul__ ... fail #417

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions doc/release_notes.rst
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
Release Notes
=============

.. Upcoming Version
.. ----------------
Upcoming Version
----------------

* Added support for arithmetic operations with custom classes.

Version 0.5.0
--------------
Expand Down
103 changes: 61 additions & 42 deletions linopy/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,11 +488,14 @@
Note: If other is a numpy array or pandas object without axes names,
dimension names of self will be filled in other
"""
if np.isscalar(other):
return self.assign(const=self.const + other)
try:
if np.isscalar(other):
return self.assign(const=self.const + other)

other = as_expression(other, model=self.model, dims=self.coord_dims)
return merge([self, other], cls=self.__class__)
other = as_expression(other, model=self.model, dims=self.coord_dims)
return merge([self, other], cls=self.__class__)
except TypeError:
return NotImplemented

def __radd__(self, other: int) -> LinearExpression | NotImplementedType:
# This is needed for using python's sum function
Expand All @@ -505,11 +508,14 @@
Note: If other is a numpy array or pandas object without axes names,
dimension names of self will be filled in other
"""
if np.isscalar(other):
return self.assign_multiindex_safe(const=self.const - other)
try:
if np.isscalar(other):
return self.assign_multiindex_safe(const=self.const - other)

other = as_expression(other, model=self.model, dims=self.coord_dims)
return merge([self, -other], cls=self.__class__)
other = as_expression(other, model=self.model, dims=self.coord_dims)
return merge([self, -other], cls=self.__class__)
except TypeError:
return NotImplemented

def __neg__(self) -> LinearExpression | QuadraticExpression:
"""
Expand All @@ -524,19 +530,22 @@
"""
Multiply the expr by a factor.
"""
if isinstance(other, QuadraticExpression):
raise TypeError(
"unsupported operand type(s) for *: "
f"{type(self)} and {type(other)}. "
"Higher order non-linear expressions are not yet supported."
)
elif isinstance(other, (variables.Variable, variables.ScalarVariable)):
other = other.to_linexpr()
try:
if isinstance(other, QuadraticExpression):
raise TypeError(

Check warning on line 535 in linopy/expressions.py

View check run for this annotation

Codecov / codecov/patch

linopy/expressions.py#L535

Added line #L535 was not covered by tests
"unsupported operand type(s) for *: "
f"{type(self)} and {type(other)}. "
"Higher order non-linear expressions are not yet supported."
)
elif isinstance(other, (variables.Variable, variables.ScalarVariable)):
other = other.to_linexpr()

if isinstance(other, (LinearExpression, ScalarLinearExpression)):
return self._multiply_by_linear_expression(other)
else:
return self._multiply_by_constant(other)
if isinstance(other, (LinearExpression, ScalarLinearExpression)):
return self._multiply_by_linear_expression(other)
else:
return self._multiply_by_constant(other)
except TypeError:
return NotImplemented

def _multiply_by_linear_expression(
self, other: LinearExpression | ScalarLinearExpression
Expand Down Expand Up @@ -599,15 +608,18 @@
def __div__(
self, other: Variable | ConstantLike
) -> LinearExpression | QuadraticExpression:
if isinstance(
other, (LinearExpression, variables.Variable, variables.ScalarVariable)
):
raise TypeError(
"unsupported operand type(s) for /: "
f"{type(self)} and {type(other)}"
"Non-linear expressions are not yet supported."
)
return self.__mul__(1 / other)
try:
if isinstance(
other, (LinearExpression, variables.Variable, variables.ScalarVariable)
):
raise TypeError(
"unsupported operand type(s) for /: "
f"{type(self)} and {type(other)}"
"Non-linear expressions are not yet supported."
)
return self.__mul__(1 / other)
except TypeError:
return NotImplemented

def __truediv__(
self, other: Variable | ConstantLike
Expand Down Expand Up @@ -1557,13 +1569,17 @@
Note: If other is a numpy array or pandas object without axes names,
dimension names of self will be filled in other
"""
if np.isscalar(other):
return self.assign(const=self.const + other)
try:
if np.isscalar(other):
return self.assign(const=self.const + other)

other = as_expression(other, model=self.model, dims=self.coord_dims)
if type(other) is LinearExpression:
other = other.to_quadexpr()
return merge([self, other], cls=self.__class__) # type: ignore
other = as_expression(other, model=self.model, dims=self.coord_dims)

if type(other) is LinearExpression:
other = other.to_quadexpr()
return merge([self, other], cls=self.__class__) # type: ignore
except TypeError:
return NotImplemented

Check warning on line 1582 in linopy/expressions.py

View check run for this annotation

Codecov / codecov/patch

linopy/expressions.py#L1581-L1582

Added lines #L1581 - L1582 were not covered by tests

def __radd__(
self, other: LinearExpression | int
Expand All @@ -1586,13 +1602,16 @@
Note: If other is a numpy array or pandas object without axes names,
dimension names of self will be filled in other
"""
if np.isscalar(other):
return self.assign(const=self.const - other)

other = as_expression(other, model=self.model, dims=self.coord_dims)
if type(other) is LinearExpression:
other = other.to_quadexpr()
return merge([self, -other], cls=self.__class__) # type: ignore
try:
if np.isscalar(other):
return self.assign(const=self.const - other)

other = as_expression(other, model=self.model, dims=self.coord_dims)
if type(other) is LinearExpression:
other = other.to_quadexpr()
return merge([self, -other], cls=self.__class__) # type: ignore
except TypeError:
return NotImplemented

Check warning on line 1614 in linopy/expressions.py

View check run for this annotation

Codecov / codecov/patch

linopy/expressions.py#L1613-L1614

Added lines #L1613 - L1614 were not covered by tests

def __rsub__(self, other: LinearExpression) -> QuadraticExpression:
"""
Expand Down
39 changes: 28 additions & 11 deletions linopy/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,25 +388,33 @@
"""
Multiply variables with a coefficient.
"""
if isinstance(other, (expressions.LinearExpression, Variable, ScalarVariable)):
return self.to_linexpr() * other
else:
try:
if isinstance(
other, (expressions.LinearExpression, Variable, ScalarVariable)
):
return self.to_linexpr() * other

return self.to_linexpr(other)
except TypeError:
return NotImplemented

def __pow__(self, other: int) -> QuadraticExpression:
"""
Power of the variables with a coefficient. The only coefficient allowed is 2.
"""
if not other == 2:
raise ValueError("Power must be 2.")
expr = self.to_linexpr()
return expr._multiply_by_linear_expression(expr)
if isinstance(other, int) and other == 2:
expr = self.to_linexpr()
return expr._multiply_by_linear_expression(expr)
return NotImplemented

def __rmul__(self, other: float | DataArray | int | ndarray) -> LinearExpression:
"""
Right-multiply variables with a coefficient.
"""
return self.to_linexpr(other)
try:
return self.to_linexpr(other)
except TypeError:
return NotImplemented

Check warning on line 417 in linopy/variables.py

View check run for this annotation

Codecov / codecov/patch

linopy/variables.py#L416-L417

Added lines #L416 - L417 were not covered by tests

def __matmul__(
self, other: LinearExpression | ndarray | Variable
Expand Down Expand Up @@ -436,15 +444,21 @@
"""
True divide variables with a coefficient.
"""
return self.__div__(coefficient)
try:
return self.__div__(coefficient)
except TypeError:
return NotImplemented

def __add__(
self, other: int | QuadraticExpression | LinearExpression | Variable
) -> QuadraticExpression | LinearExpression:
"""
Add variables to linear expressions or other variables.
"""
return self.to_linexpr() + other
try:
return self.to_linexpr() + other
except TypeError:
return NotImplemented

def __radd__(self, other: int) -> Variable | NotImplementedType:
# This is needed for using python's sum function
Expand All @@ -456,7 +470,10 @@
"""
Subtract linear expressions or other variables from the variables.
"""
return self.to_linexpr() - other
try:
return self.to_linexpr() - other
except TypeError:
return NotImplemented

def __le__(self, other: SideLike) -> Constraint:
return self.to_linexpr().__le__(other)
Expand Down
139 changes: 139 additions & 0 deletions test/test_compatible_arithmetrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from typing import Any

import numpy as np
import pandas as pd
import pytest
import xarray as xr

from linopy import LESS_EQUAL, Model, Variable
from linopy.testing import assert_linequal


class SomeOtherDatatype:
"""
A class that is not a subclass of xarray.DataArray, but stores data in a compatible way.
It defines all necessary arithmetrics AND __array_ufunc__ to ensure that operations are
performed on the active_data.
"""

def __init__(self, data: xr.DataArray) -> None:
self.data1 = data
self.data2 = data.copy()
self.active = 1

def activate(self, active: int) -> None:
self.active = active

@property
def active_data(self) -> xr.DataArray:
return self.data1 if self.active == 1 else self.data2

def __add__(self, other: Any) -> xr.DataArray:
return self.active_data + other

def __sub__(self, other: Any) -> xr.DataArray:
return self.active_data - other

def __mul__(self, other: Any) -> xr.DataArray:
return self.active_data * other

def __truediv__(self, other: Any) -> xr.DataArray:
return self.active_data / other

def __radd__(self, other: Any) -> Any:
return other + self.active_data

def __rsub__(self, other: Any) -> Any:
return other - self.active_data

def __rmul__(self, other: Any) -> Any:
return other * self.active_data

def __rtruediv__(self, other: Any) -> Any:
return other / self.active_data

def __neg__(self) -> xr.DataArray:
return -self.active_data

def __pos__(self) -> xr.DataArray:
return +self.active_data

def __abs__(self) -> xr.DataArray:
return abs(self.active_data)

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): # type: ignore
# Ensure we always use the active_data when interacting with numpy/xarray operations
new_inputs = [
inp.active_data if isinstance(inp, SomeOtherDatatype) else inp
for inp in inputs
]
return getattr(ufunc, method)(*new_inputs, **kwargs)


@pytest.fixture(
params=[
(pd.RangeIndex(10, name="first"),),
(
pd.Index(range(5), name="first"),
pd.Index(range(3), name="second"),
pd.Index(range(2), name="third"),
),
],
ids=["single_dim", "multi_dim"],
)
def m(request) -> Model: # type: ignore
m = Model()
x = m.add_variables(coords=request.param, name="x")
m.add_variables(0, 10, name="z")
m.add_constraints(x, LESS_EQUAL, 0, name="c")
return m


def test_arithmetric_operations_variable(m: Model) -> None:
x: Variable = m.variables["x"]
rng = np.random.default_rng()
data = xr.DataArray(rng.random(x.shape), coords=x.coords)
other_datatype = SomeOtherDatatype(data.copy())
assert_linequal(x + data, x + other_datatype) # type: ignore
assert_linequal(x - data, x - other_datatype) # type: ignore
assert_linequal(x * data, x * other_datatype) # type: ignore
assert_linequal(x / data, x / other_datatype) # type: ignore
assert_linequal(data * x, other_datatype * x) # type: ignore
assert x.__add__(object()) is NotImplemented # type: ignore
assert x.__sub__(object()) is NotImplemented # type: ignore
assert x.__mul__(object()) is NotImplemented # type: ignore
assert x.__truediv__(object()) is NotImplemented # type: ignore
assert x.__pow__(object()) is NotImplemented # type: ignore
assert x.__pow__(3) is NotImplemented


def test_arithmetric_operations_expr(m: Model) -> None:
x = m.variables["x"]
expr = x + 3
rng = np.random.default_rng()
data = xr.DataArray(rng.random(x.shape), coords=x.coords)
other_datatype = SomeOtherDatatype(data.copy())
assert_linequal(expr + data, expr + other_datatype)
assert_linequal(expr - data, expr - other_datatype)
assert_linequal(expr * data, expr * other_datatype)
assert_linequal(expr / data, expr / other_datatype)
assert expr.__add__(object()) is NotImplemented
assert expr.__sub__(object()) is NotImplemented
assert expr.__mul__(object()) is NotImplemented
assert expr.__truediv__(object()) is NotImplemented


def test_arithmetric_operations_con(m: Model) -> None:
c = m.constraints["c"]
x = m.variables["x"]
rng = np.random.default_rng()
data = xr.DataArray(rng.random(x.shape), coords=x.coords)
other_datatype = SomeOtherDatatype(data.copy())
assert_linequal(c.lhs + data, c.lhs + other_datatype)
assert_linequal(c.lhs - data, c.lhs - other_datatype)
assert_linequal(c.lhs * data, c.lhs * other_datatype)
assert_linequal(c.lhs / data, c.lhs / other_datatype)
assert_linequal(c.rhs + data, c.rhs + other_datatype) # type: ignore
assert_linequal(c.rhs - data, c.rhs - other_datatype) # type: ignore
assert_linequal(c.rhs * data, c.rhs * other_datatype) # type: ignore
assert_linequal(c.rhs / data, c.rhs / other_datatype) # type: ignore
Loading