Add ruff for linting, remove flake8, remove isort, remove pylint #93

Merged
merged 5 commits on Mar 25, 2024
Changes from 2 commits
15 changes: 5 additions & 10 deletions Makefile
@@ -81,9 +81,6 @@ install-develop: clean-build clean-pyc ## install the package in editable mode a

.PHONY: lint-deepecho
lint-deepecho: ## check style with flake8 and isort
flake8 deepecho
isort -c --recursive deepecho
pylint deepecho --rcfile=setup.cfg

.PHONY: lint-tests
lint-tests: ## check style with flake8 and isort
@@ -92,17 +89,15 @@ lint-tests: ## check style with flake8 and isort

.PHONY: lint
lint: ## Run all code style checks
invoke lint
ruff check .
ruff format . --check

.PHONY: fix-lint
fix-lint: ## fix lint issues using autoflake, autopep8, and isort
find deepecho tests -name '*.py' | xargs autoflake --in-place --remove-all-unused-imports --remove-unused-variables
autopep8 --in-place --recursive --aggressive deepecho tests
isort --apply --atomic --recursive deepecho tests

fix-lint: ## fix lint issues using ruff
ruff check --fix .
ruff format .

# TEST TARGETS

.PHONY: test-unit
test-unit: ## run unit tests using pytest
invoke unit
14 changes: 7 additions & 7 deletions deepecho/__init__.py
@@ -1,16 +1,16 @@
"""Top-level package for DeepEcho."""

__author__ = 'DataCebo, Inc.'
__email__ = '[email protected]'
__version__ = '0.5.1.dev0'
__path__ = __import__('pkgutil').extend_path(__path__, __name__)
__author__ = "DataCebo, Inc."
__email__ = "[email protected]"
__version__ = "0.5.1.dev0"
__path__ = __import__("pkgutil").extend_path(__path__, __name__)

from deepecho.demo import load_demo
from deepecho.models.basic_gan import BasicGANModel
from deepecho.models.par import PARModel

__all__ = [
'load_demo',
'BasicGANModel',
'PARModel',
"load_demo",
"BasicGANModel",
"PARModel",
]
6 changes: 4 additions & 2 deletions deepecho/demo.py
@@ -4,9 +4,11 @@

import pandas as pd

_DATA_PATH = os.path.join(os.path.dirname(__file__), 'data')
_DATA_PATH = os.path.join(os.path.dirname(__file__), "data")


def load_demo():
"""Load the demo DataFrame."""
return pd.read_csv(os.path.join(_DATA_PATH, 'demo.csv'), parse_dates=['date'])
return pd.read_csv(
os.path.join(_DATA_PATH, "demo.csv"), parse_dates=["date"]
)
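
As a quick sanity check (not part of this PR), the reformatted load_demo can be exercised like the sketch below; behavior is unchanged, with the "date" column still parsed as datetimes because parse_dates=["date"] is passed above.

from deepecho import load_demo

# load_demo() returns the bundled demo DataFrame with 'date' parsed by pandas.
demo = load_demo()
print(demo["date"].dtype)  # expected: datetime64[ns]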
2 changes: 1 addition & 1 deletion deepecho/models/__init__.py
@@ -3,4 +3,4 @@
from deepecho.models.basic_gan import BasicGANModel
from deepecho.models.par import PARModel

__all__ = ['PARModel', 'BasicGANModel']
__all__ = ["PARModel", "BasicGANModel"]
74 changes: 50 additions & 24 deletions deepecho/models/base.py
@@ -6,7 +6,7 @@
from deepecho.sequences import assemble_sequences


class DeepEcho():
class DeepEcho:
"""The base class for DeepEcho models."""

_verbose = True
@@ -28,14 +28,20 @@ def _validate(sequences, context_types, data_types):
data_types:
See `fit`.
"""
dtypes = set(['continuous', 'categorical', 'ordinal', 'count', 'datetime'])
dtypes = set([
"continuous",
"categorical",
"ordinal",
"count",
"datetime",
])
assert all(dtype in dtypes for dtype in context_types)
assert all(dtype in dtypes for dtype in data_types)

for sequence in sequences:
assert len(sequence['context']) == len(context_types)
assert len(sequence['data']) == len(data_types)
lengths = [len(x) for x in sequence['data']]
assert len(sequence["context"]) == len(context_types)
assert len(sequence["data"]) == len(data_types)
lengths = [len(x) for x in sequence["data"]]
assert len(set(lengths)) == 1

def fit_sequences(self, sequences, context_types, data_types):
@@ -87,20 +93,29 @@ def _get_data_types(data, data_types, columns):
else:
dtype = data[column].dtype
kind = dtype.kind
if kind in 'fiud':
dtypes_list.append('continuous')
elif kind in 'OSUb':
dtypes_list.append('categorical')
elif kind == 'M':
dtypes_list.append('datetime')
if kind in "fiud":
dtypes_list.append("continuous")
elif kind in "OSUb":
dtypes_list.append("categorical")
elif kind == "M":
dtypes_list.append("datetime")
else:
error = f'Unsupported data_type for column {column}: {dtype}'
error = (
f"Unsupported data_type for column {column}: {dtype}"
)
raise ValueError(error)

return dtypes_list

def fit(self, data, entity_columns=None, context_columns=None,
data_types=None, segment_size=None, sequence_index=None):
def fit(
self,
data,
entity_columns=None,
context_columns=None,
data_types=None,
segment_size=None,
sequence_index=None,
):
"""Fit the model to a dataframe containing time series data.

Args:
Expand Down Expand Up @@ -131,17 +146,19 @@ def fit(self, data, entity_columns=None, context_columns=None,
such as integer values or datetimes.
"""
if not entity_columns and segment_size is None:
raise TypeError('If the data has no `entity_columns`, `segment_size` must be given.')
raise TypeError(
"If the data has no `entity_columns`, `segment_size` must be given."
)
if segment_size is not None and not isinstance(segment_size, int):
if sequence_index is None:
raise TypeError(
'`segment_size` must be of type `int` if '
'no `sequence_index` is given.'
"`segment_size` must be of type `int` if "
"no `sequence_index` is given."
)
if data[sequence_index].dtype.kind != 'M':
if data[sequence_index].dtype.kind != "M":
raise TypeError(
'`segment_size` must be of type `int` if '
'`sequence_index` is not a `datetime` column.'
"`segment_size` must be of type `int` if "
"`sequence_index` is not a `datetime` column."
)

segment_size = pd.to_timedelta(segment_size)
@@ -159,9 +176,16 @@ def fit(self, data, entity_columns=None, context_columns=None,
self._data_columns.remove(sequence_index)

data_types = self._get_data_types(data, data_types, self._data_columns)
context_types = self._get_data_types(data, data_types, self._context_columns)
context_types = self._get_data_types(
data, data_types, self._context_columns
)
sequences = assemble_sequences(
data, self._entity_columns, self._context_columns, segment_size, sequence_index)
data,
self._entity_columns,
self._context_columns,
segment_size,
sequence_index,
)

# Validate and fit
self._validate(sequences, context_types, data_types)
@@ -212,7 +236,9 @@ def sample(self, num_entities=None, context=None, sequence_length=None):
"""
if context is None:
if num_entities is None:
raise TypeError('Either context or num_entities must be not None')
raise TypeError(
"Either context or num_entities must be not None"
)

context = self._context_values.sample(num_entities, replace=True)
context = context.reset_index(drop=True)
@@ -242,7 +268,7 @@ def sample(self, num_entities=None, context=None, sequence_length=None):
# Reformat as a DataFrame
group = pd.DataFrame(
dict(zip(self._data_columns, sequence)),
columns=self._data_columns
columns=self._data_columns,
)
group[self._entity_columns] = entity_values
for column, value in zip(self._context_columns, context_values):
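
For context, a minimal usage sketch of the fit/sample API whose formatting changes above (not part of this PR; the entity and context column names are assumptions about the bundled demo data, not facts shown in this diff):

from deepecho import PARModel, load_demo

data = load_demo()

# PARModel is one of the two concrete DeepEcho models exported in __init__.py.
model = PARModel(epochs=128, cuda=False)
model.fit(
    data=data,
    entity_columns=["store_id"],   # assumed entity identifier column
    context_columns=["region"],    # assumed per-entity context column
    sequence_index="date",
)

# Sample five new synthetic entities, each with a generated sequence.
synthetic = model.sample(num_entities=5)
print(synthetic.head())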