Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

style: upgrade to ruff v0.8.4 and fix type hints #264

Merged
merged 5 commits into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ repos:
pass_filenames: false
# ruff check (w/autofix)
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.3 # should match version in pyproject.toml
rev: v0.8.4 # should match version in pyproject.toml
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
# ruff format
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.3 # should match version in pyproject.toml
rev: v0.8.4 # should match version in pyproject.toml
hooks:
- id: ruff-format
# # pydoclint - docstring formatting
Expand Down
34 changes: 18 additions & 16 deletions benchmarks/plot.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,13 @@
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"plt.style.use(\"ggplot\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c92bf960-ddb5-409f-bd3c-5bce0a03ccd0",
"metadata": {},
"outputs": [],
"source": [
"from sequentia import"
]
},
{
"cell_type": "code",
"execution_count": 79,
"id": "6649bf2d-7430-401d-8113-f3c1e1cf4779",
"metadata": {},
"outputs": [
Expand All @@ -48,23 +37,36 @@
"\n",
"bars = ax.bar(labels, runtimes, width=0.5, color=\"C1\")\n",
"ax.set(xlabel=\"Package\", ylabel=\"Runtime (s)\")\n",
"ax.set_title(\"Univariate DTW-kNN performance (1,500 FSDD train/test sequences, 16 workers)\", fontsize=11)\n",
"ax.set_title(\n",
" (\n",
" \"Univariate DTW-kNN performance \"\n",
" \"(1,500 FSDD train/test sequences, 16 workers)\"\n",
" ),\n",
" fontsize=11,\n",
")\n",
"\n",
"\n",
"def fmt(s: float) -> str:\n",
" \"\"\"Formats the runtime.\"\"\"\n",
" if s < 60:\n",
" return f\"{round(s)}s\"\n",
" m, s = divmod(s, 60)\n",
" return f\"{round(m)}m {round(s)}s\"\n",
"\n",
"\n",
"for bar in bars:\n",
" plt.text(\n",
" bar.get_x() + bar.get_width() / 2, bar.get_height(),\n",
" fmt(bar.get_height()), ha='center', va='bottom', fontsize=9,\n",
" bar.get_x() + bar.get_width() / 2,\n",
" bar.get_height(),\n",
" fmt(bar.get_height()),\n",
" ha=\"center\",\n",
" va=\"bottom\",\n",
" fontsize=9,\n",
" )\n",
"\n",
"for lab in ax.get_xticklabels():\n",
" if lab.get_text() == \"sequentia\":\n",
" lab.set_fontweight('bold')\n",
" if lab.get_text() == \"sequentia\":\n",
" lab.set_fontweight(\"bold\")\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig(\"benchmark.svg\")\n",
Expand Down
2 changes: 1 addition & 1 deletion make/lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def check(c: Config) -> None:
def format_(c: Config) -> None:
"""Format Python files."""
commands: list[str] = [
"poetry run ruff --fix .",
"poetry run ruff check --fix .",
"poetry run ruff format .",
]
for command in commands:
Expand Down
15 changes: 7 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ tox = "4.11.3"
pre-commit = ">=3"

[tool.poetry.group.lint.dependencies]
ruff = "0.1.3"
ruff = "0.8.4"
pydoclint = "0.3.8"

[tool.poetry.group.docs.dependencies]
Expand All @@ -100,8 +100,8 @@ pytest = { version = "^7.4.0" }
pytest-cov = { version = "^4.1.0" }

[tool.ruff]
required-version = "0.1.3"
select = [
required-version = "0.8.4"
lint.select = [
"F", # pyflakes: https://pypi.org/project/pyflakes/
"E", # pycodestyle (error): https://pypi.org/project/pycodestyle/
"W", # pycodestyle (warning): https://pypi.org/project/pycodestyle/
Expand Down Expand Up @@ -144,7 +144,7 @@ select = [
"PERF", # perflint: https://pypi.org/project/perflint/
"RUF", # ruff
]
ignore = [
lint.ignore = [
"ANN401", # https://beta.ruff.rs/docs/rules/any-type/
"B905", # https://beta.ruff.rs/docs/rules/zip-without-explicit-strict/
"TD003", # https://beta.ruff.rs/docs/rules/missing-todo-link/
Expand All @@ -162,16 +162,15 @@ ignore = [
"C408", # Unnecessary `dict` call (rewrite as a literal)
"D401", # First line of docstring should be in imperative mood
]
ignore-init-module-imports = true # allow unused imports in __init__.py
line-length = 79

[tool.ruff.pydocstyle]
[tool.ruff.lint.pydocstyle]
convention = "numpy"

[tool.ruff.flake8-annotations]
[tool.ruff.lint.flake8-annotations]
allow-star-arg-any = true

[tool.ruff.extend-per-file-ignores]
[tool.ruff.lint.extend-per-file-ignores]
"__init__.py" = ["PLC0414", "F403", "F401", "F405"]
"sequentia/datasets/*.py" = ["B006"]
"sequentia/enums.py" = ["E501"]
Expand Down
44 changes: 18 additions & 26 deletions sequentia/_internal/_hmm/topologies.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
from sequentia.enums import TopologyMode

__all__ = [
"TOPOLOGY_MAP",
"ErgodicTopology",
"LeftRightTopology",
"LinearTopology",
"TOPOLOGY_MAP",
]


Expand All @@ -36,15 +36,15 @@ class BaseTopology:
mode: TopologyMode

def __init__(
self: BaseTopology,
self,
*,
n_states: int,
random_state: np.random.RandomState,
) -> BaseTopology:
) -> None:
self.n_states = n_states
self.random_state = random_state

def uniform_start_probs(self: BaseTopology) -> FloatArray:
def uniform_start_probs(self) -> FloatArray:
"""Set the initial state distribution as a discrete uniform
distribution.

Expand All @@ -55,7 +55,7 @@ def uniform_start_probs(self: BaseTopology) -> FloatArray:
"""
return np.ones(self.n_states) / self.n_states

def random_start_probs(self: BaseTopology) -> FloatArray:
def random_start_probs(self) -> FloatArray:
"""Set the initial state distribution by randomly sampling
probabilities generated by a Dirichlet distribution.

Expand All @@ -69,7 +69,7 @@ def random_start_probs(self: BaseTopology) -> FloatArray:
size=1,
).flatten()

def uniform_transition_probs(self: BaseTopology) -> FloatArray:
def uniform_transition_probs(self) -> FloatArray:
"""Set the transition matrix as uniform (equal probability of
transitioning to all other possible states from each state)
corresponding to the topology.
Expand All @@ -81,7 +81,7 @@ def uniform_transition_probs(self: BaseTopology) -> FloatArray:
"""
raise NotImplementedError

def random_transition_probs(self: BaseTopology) -> FloatArray:
def random_transition_probs(self) -> FloatArray:
"""Set the transition matrix as random (random probability of
transitioning to all other possible states from each state) by
sampling probabilitiesfrom a Dirichlet distribution - according
Expand All @@ -94,7 +94,7 @@ def random_transition_probs(self: BaseTopology) -> FloatArray:
"""
raise NotImplementedError

def check_start_probs(self: BaseTopology, initial: FloatArray, /) -> None:
def check_start_probs(self, initial: FloatArray, /) -> None:
"""Validate an initial state distribution according to the
topology's restrictions.

Expand All @@ -114,9 +114,7 @@ def check_start_probs(self: BaseTopology, initial: FloatArray, /) -> None:
raise ValueError(msg)
return initial

def check_transition_probs(
self: BaseTopology, transitions: FloatArray, /
) -> FloatArray:
def check_transition_probs(self, transitions: FloatArray, /) -> FloatArray:
"""Validate a transition matrix according to the topology's
restrictions.

Expand Down Expand Up @@ -152,7 +150,7 @@ class ErgodicTopology(BaseTopology):

mode: TopologyMode = TopologyMode.ERGODIC

def uniform_transition_probs(self: ErgodicTopology) -> FloatArray:
def uniform_transition_probs(self) -> FloatArray:
"""Set the transition matrix as uniform (equal probability of
transitioning to all other possible states from each state)
corresponding to the topology.
Expand All @@ -164,7 +162,7 @@ def uniform_transition_probs(self: ErgodicTopology) -> FloatArray:
"""
return np.ones((self.n_states, self.n_states)) / self.n_states

def random_transition_probs(self: ErgodicTopology) -> FloatArray:
def random_transition_probs(self) -> FloatArray:
"""Set the transition matrix as random (random probability of
transitioning to all other possible states from each state) by
sampling probabilities from a Dirichlet distribution - according
Expand All @@ -180,9 +178,7 @@ def random_transition_probs(self: ErgodicTopology) -> FloatArray:
size=self.n_states,
)

def check_transition_probs(
self: ErgodicTopology, transitions: FloatArray, /
) -> FloatArray:
def check_transition_probs(self, transitions: FloatArray, /) -> FloatArray:
"""Validate a transition matrix according to the topology's
restrictions.

Expand Down Expand Up @@ -216,7 +212,7 @@ class LeftRightTopology(BaseTopology):

mode: TopologyMode = TopologyMode.LEFT_RIGHT

def uniform_transition_probs(self: LeftRightTopology) -> FloatArray:
def uniform_transition_probs(self) -> FloatArray:
"""Set the transition matrix as uniform (equal probability of
transitioning to all other possible states from each state)
corresponding to the topology.
Expand All @@ -233,7 +229,7 @@ def uniform_transition_probs(self: LeftRightTopology) -> FloatArray:
lower_ones = np.tril(np.ones(self.n_states), k=-1)
return upper_ones / (upper_divisors + lower_ones)

def random_transition_probs(self: LeftRightTopology) -> FloatArray:
def random_transition_probs(self) -> FloatArray:
"""Set the transition matrix as random (random probability of
transitioning to all other possible states from each state) by
sampling probabilities from a Dirichlet distribution, according
Expand All @@ -249,9 +245,7 @@ def random_transition_probs(self: LeftRightTopology) -> FloatArray:
row[i:] = self.random_state.dirichlet(np.ones(self.n_states - i))
return transitions

def check_transition_probs(
self: LeftRightTopology, transitions: FloatArray, /
) -> FloatArray:
def check_transition_probs(self, transitions: FloatArray, /) -> FloatArray:
"""Validate a transition matrix according to the topology's
restrictions.

Expand Down Expand Up @@ -281,7 +275,7 @@ class LinearTopology(LeftRightTopology):

mode: TopologyMode = TopologyMode.LINEAR

def uniform_transition_probs(self: LinearTopology) -> FloatArray:
def uniform_transition_probs(self) -> FloatArray:
"""Set the transition matrix as uniform (equal probability of
transitioning to all other possible states from each state)
corresponding to the topology.
Expand All @@ -297,7 +291,7 @@ def uniform_transition_probs(self: LinearTopology) -> FloatArray:
row[i : (i + size)] = np.ones(size) / size
return transitions

def random_transition_probs(self: LinearTopology) -> FloatArray:
def random_transition_probs(self) -> FloatArray:
"""Set the transition matrix as random (random probability of
transitioning to all other possible states from each state) by
sampling probabilities from a Dirichlet distribution, according to the
Expand All @@ -314,9 +308,7 @@ def random_transition_probs(self: LinearTopology) -> FloatArray:
row[i : (i + size)] = self.random_state.dirichlet(np.ones(size))
return transitions

def check_transition_probs(
self: LinearTopology, transitions: FloatArray, /
) -> FloatArray:
def check_transition_probs(self, transitions: FloatArray, /) -> FloatArray:
"""Validate a transition matrix according to the topology's
restrictions.

Expand Down
2 changes: 1 addition & 1 deletion sequentia/_internal/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import numpy.typing as npt

__all__ = ["FloatArray", "IntArray", "Array"]
__all__ = ["Array", "FloatArray", "IntArray"]

FloatArray = npt.NDArray[np.float64]
IntArray = npt.NDArray[np.int64]
Expand Down
20 changes: 10 additions & 10 deletions sequentia/_internal/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@
from sequentia._internal._typing import Array, FloatArray, IntArray

__all__ = [
"check_random_state",
"check_is_fitted",
"requires_fit",
"check_classes",
"check_X",
"check_X_lengths",
"check_y",
"check_weighting",
"check_classes",
"check_is_fitted",
"check_random_state",
"check_use_c",
"check_weighting",
"check_y",
"requires_fit",
]


Expand Down Expand Up @@ -60,7 +60,7 @@ def check_is_fitted(

def requires_fit(function: t.Callable) -> t.Callable:
@functools.wraps(function)
def wrapper(self: t.Self, *args: t.Any, **kwargs: t.Any) -> t.Any:
def wrapper(self, *args: t.Any, **kwargs: t.Any) -> t.Any: # noqa: ANN001
check_is_fitted(self)
return function(self, *args, **kwargs)

Expand Down Expand Up @@ -106,14 +106,14 @@ def check_X(
if not isinstance(X, np.ndarray):
try:
X = np.array(X).astype(dtype)
except Exception as e: # noqa: BLE001
except Exception as e:
type_ = type(X).__name__
msg = f"Expected value to be a numpy.ndarray, got {type_!r}"
raise TypeError(msg) from e
if (dtype_ := X.dtype) != dtype:
try:
X = X.astype(dtype)
except Exception as e: # noqa: BLE001
except Exception as e:
msg = f"Expected array to have dtype {dtype}, got {dtype_}"
raise TypeError(msg) from e
if (ndim_ := X.ndim) != 2:
Expand Down Expand Up @@ -214,7 +214,7 @@ def check_weighting(
if x.shape != weights.shape:
msg = "Weights should have the same shape as inputs"
raise ValueError(msg) # noqa: TRY301
except Exception as e: # noqa: BLE001
except Exception as e:
msg = "Invalid weighting function"
raise ValueError(msg) from e

Expand Down
2 changes: 1 addition & 1 deletion sequentia/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
from sequentia.datasets.digits import load_digits
from sequentia.datasets.gene_families import load_gene_families

__all__ = ["data", "load_digits", "load_gene_families", "SequentialDataset"]
__all__ = ["SequentialDataset", "data", "load_digits", "load_gene_families"]
Loading
Loading