Skip to content

Commit 3ace2fb

Browse files
max-sixtypre-commit-ci[bot]headtr1ck
authored
Use Self rather than concrete types, remove casts (#8216)
* Use `Self` rather than concrete types, remove `cast`s This should also allow for subtyping * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Undo one `Self` * Unused ignore * Add check for redundant self annotations * And `DataWithCoords` * And `DataArray` & `Dataset` * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * And `Variable` * Update xarray/core/dataarray.py Co-authored-by: Michael Niklas <[email protected]> * Update xarray/core/dataarray.py Co-authored-by: Michael Niklas <[email protected]> * Update xarray/core/dataarray.py Co-authored-by: Michael Niklas <[email protected]> * Clean-ups — `other`, casts, obsolete comments * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * another one --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Michael Niklas <[email protected]>
1 parent 96cf77a commit 3ace2fb

File tree

9 files changed

+400
-433
lines changed

9 files changed

+400
-433
lines changed

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ source = ["xarray"]
7272
exclude_lines = ["pragma: no cover", "if TYPE_CHECKING"]
7373

7474
[tool.mypy]
75+
enable_error_code = "redundant-self"
7576
exclude = 'xarray/util/generate_.*\.py'
7677
files = "xarray"
7778
show_error_codes = true

xarray/core/accessor_str.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2386,7 +2386,7 @@ def _partitioner(
23862386

23872387
# _apply breaks on an empty array in this case
23882388
if not self._obj.size:
2389-
return self._obj.copy().expand_dims({dim: 0}, axis=-1) # type: ignore[return-value]
2389+
return self._obj.copy().expand_dims({dim: 0}, axis=-1)
23902390

23912391
arrfunc = lambda x, isep: np.array(func(x, isep), dtype=self._obj.dtype)
23922392

xarray/core/common.py

+15-22
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
DatetimeLike,
4646
DTypeLikeSave,
4747
ScalarOrArray,
48+
Self,
4849
SideOptions,
4950
T_Chunks,
5051
T_DataWithCoords,
@@ -381,11 +382,11 @@ class DataWithCoords(AttrAccessMixin):
381382
__slots__ = ("_close",)
382383

383384
def squeeze(
384-
self: T_DataWithCoords,
385+
self,
385386
dim: Hashable | Iterable[Hashable] | None = None,
386387
drop: bool = False,
387388
axis: int | Iterable[int] | None = None,
388-
) -> T_DataWithCoords:
389+
) -> Self:
389390
"""Return a new object with squeezed data.
390391
391392
Parameters
@@ -414,12 +415,12 @@ def squeeze(
414415
return self.isel(drop=drop, **{d: 0 for d in dims})
415416

416417
def clip(
417-
self: T_DataWithCoords,
418+
self,
418419
min: ScalarOrArray | None = None,
419420
max: ScalarOrArray | None = None,
420421
*,
421422
keep_attrs: bool | None = None,
422-
) -> T_DataWithCoords:
423+
) -> Self:
423424
"""
424425
Return an array whose values are limited to ``[min, max]``.
425426
At least one of max or min must be given.
@@ -472,10 +473,10 @@ def _calc_assign_results(
472473
return {k: v(self) if callable(v) else v for k, v in kwargs.items()}
473474

474475
def assign_coords(
475-
self: T_DataWithCoords,
476+
self,
476477
coords: Mapping[Any, Any] | None = None,
477478
**coords_kwargs: Any,
478-
) -> T_DataWithCoords:
479+
) -> Self:
479480
"""Assign new coordinates to this object.
480481
481482
Returns a new object with all the original data in addition to the new
@@ -620,9 +621,7 @@ def assign_coords(
620621
data.coords.update(results)
621622
return data
622623

623-
def assign_attrs(
624-
self: T_DataWithCoords, *args: Any, **kwargs: Any
625-
) -> T_DataWithCoords:
624+
def assign_attrs(self, *args: Any, **kwargs: Any) -> Self:
626625
"""Assign new attrs to this object.
627626
628627
Returns a new object equivalent to ``self.attrs.update(*args, **kwargs)``.
@@ -1061,9 +1060,7 @@ def _resample(
10611060
restore_coord_dims=restore_coord_dims,
10621061
)
10631062

1064-
def where(
1065-
self: T_DataWithCoords, cond: Any, other: Any = dtypes.NA, drop: bool = False
1066-
) -> T_DataWithCoords:
1063+
def where(self, cond: Any, other: Any = dtypes.NA, drop: bool = False) -> Self:
10671064
"""Filter elements from this object according to a condition.
10681065
10691066
Returns elements from 'DataArray', where 'cond' is True,
@@ -1208,9 +1205,7 @@ def close(self) -> None:
12081205
self._close()
12091206
self._close = None
12101207

1211-
def isnull(
1212-
self: T_DataWithCoords, keep_attrs: bool | None = None
1213-
) -> T_DataWithCoords:
1208+
def isnull(self, keep_attrs: bool | None = None) -> Self:
12141209
"""Test each value in the array for whether it is a missing value.
12151210
12161211
Parameters
@@ -1253,9 +1248,7 @@ def isnull(
12531248
keep_attrs=keep_attrs,
12541249
)
12551250

1256-
def notnull(
1257-
self: T_DataWithCoords, keep_attrs: bool | None = None
1258-
) -> T_DataWithCoords:
1251+
def notnull(self, keep_attrs: bool | None = None) -> Self:
12591252
"""Test each value in the array for whether it is not a missing value.
12601253
12611254
Parameters
@@ -1298,7 +1291,7 @@ def notnull(
12981291
keep_attrs=keep_attrs,
12991292
)
13001293

1301-
def isin(self: T_DataWithCoords, test_elements: Any) -> T_DataWithCoords:
1294+
def isin(self, test_elements: Any) -> Self:
13021295
"""Tests each value in the array for whether it is in test elements.
13031296
13041297
Parameters
@@ -1347,15 +1340,15 @@ def isin(self: T_DataWithCoords, test_elements: Any) -> T_DataWithCoords:
13471340
)
13481341

13491342
def astype(
1350-
self: T_DataWithCoords,
1343+
self,
13511344
dtype,
13521345
*,
13531346
order=None,
13541347
casting=None,
13551348
subok=None,
13561349
copy=None,
13571350
keep_attrs=True,
1358-
) -> T_DataWithCoords:
1351+
) -> Self:
13591352
"""
13601353
Copy of the xarray object, with data cast to a specified type.
13611354
Leaves coordinate dtype unchanged.
@@ -1422,7 +1415,7 @@ def astype(
14221415
dask="allowed",
14231416
)
14241417

1425-
def __enter__(self: T_DataWithCoords) -> T_DataWithCoords:
1418+
def __enter__(self) -> Self:
14261419
return self
14271420

14281421
def __exit__(self, exc_type, exc_value, traceback) -> None:

xarray/core/concat.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from collections.abc import Hashable, Iterable
4-
from typing import TYPE_CHECKING, Any, Union, cast, overload
4+
from typing import TYPE_CHECKING, Any, Union, overload
55

66
import numpy as np
77
import pandas as pd
@@ -504,8 +504,7 @@ def _dataset_concat(
504504

505505
# case where concat dimension is a coordinate or data_var but not a dimension
506506
if (dim in coord_names or dim in data_names) and dim not in dim_names:
507-
# TODO: Overriding type because .expand_dims has incorrect typing:
508-
datasets = [cast(T_Dataset, ds.expand_dims(dim)) for ds in datasets]
507+
datasets = [ds.expand_dims(dim) for ds in datasets]
509508

510509
# determine which variables to concatenate
511510
concat_over, equals, concat_dim_lengths = _calc_concat_over(
@@ -708,8 +707,7 @@ def _dataarray_concat(
708707
if compat == "identical":
709708
raise ValueError("array names not identical")
710709
else:
711-
# TODO: Overriding type because .rename has incorrect typing:
712-
arr = cast(T_DataArray, arr.rename(name))
710+
arr = arr.rename(name)
713711
datasets.append(arr._to_temp_dataset())
714712

715713
ds = _dataset_concat(

xarray/core/coordinates.py

+16-12
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
create_default_index_implicit,
2424
)
2525
from xarray.core.merge import merge_coordinates_without_align, merge_coords
26-
from xarray.core.types import Self, T_DataArray
26+
from xarray.core.types import Self, T_DataArray, T_Xarray
2727
from xarray.core.utils import (
2828
Frozen,
2929
ReprObject,
@@ -425,7 +425,7 @@ def __delitem__(self, key: Hashable) -> None:
425425
# redirect to DatasetCoordinates.__delitem__
426426
del self._data.coords[key]
427427

428-
def equals(self, other: Coordinates) -> bool:
428+
def equals(self, other: Self) -> bool:
429429
"""Two Coordinates objects are equal if they have matching variables,
430430
all of which are equal.
431431
@@ -437,7 +437,7 @@ def equals(self, other: Coordinates) -> bool:
437437
return False
438438
return self.to_dataset().equals(other.to_dataset())
439439

440-
def identical(self, other: Coordinates) -> bool:
440+
def identical(self, other: Self) -> bool:
441441
"""Like equals, but also checks all variable attributes.
442442
443443
See Also
@@ -565,9 +565,7 @@ def update(self, other: Mapping[Any, Any]) -> None:
565565

566566
self._update_coords(coords, indexes)
567567

568-
def assign(
569-
self, coords: Mapping | None = None, **coords_kwargs: Any
570-
) -> Coordinates:
568+
def assign(self, coords: Mapping | None = None, **coords_kwargs: Any) -> Self:
571569
"""Assign new coordinates (and indexes) to a Coordinates object, returning
572570
a new object with all the original coordinates in addition to the new ones.
573571
@@ -656,16 +654,24 @@ def copy(
656654
self,
657655
deep: bool = False,
658656
memo: dict[int, Any] | None = None,
659-
) -> Coordinates:
657+
) -> Self:
660658
"""Return a copy of this Coordinates object."""
661659
# do not copy indexes (may corrupt multi-coordinate indexes)
662660
# TODO: disable variables deepcopy? it may also be problematic when they
663661
# encapsulate index objects like pd.Index
664662
variables = {
665663
k: v._copy(deep=deep, memo=memo) for k, v in self.variables.items()
666664
}
667-
return Coordinates._construct_direct(
668-
coords=variables, indexes=dict(self.xindexes), dims=dict(self.sizes)
665+
666+
# TODO: getting an error with `self._construct_direct`, possibly because of how
667+
# a subclass implements `_construct_direct`. (This was originally the same
668+
# runtime code, but we switched the type definitions in #8216, which
669+
# necessitates the cast.)
670+
return cast(
671+
Self,
672+
Coordinates._construct_direct(
673+
coords=variables, indexes=dict(self.xindexes), dims=dict(self.sizes)
674+
),
669675
)
670676

671677

@@ -915,9 +921,7 @@ def drop_indexed_coords(
915921
return Coordinates._construct_direct(coords=new_variables, indexes=new_indexes)
916922

917923

918-
def assert_coordinate_consistent(
919-
obj: T_DataArray | Dataset, coords: Mapping[Any, Variable]
920-
) -> None:
924+
def assert_coordinate_consistent(obj: T_Xarray, coords: Mapping[Any, Variable]) -> None:
921925
"""Make sure the dimension coordinate of obj is consistent with coords.
922926
923927
obj: DataArray or Dataset

0 commit comments

Comments
 (0)