Skip to content

Commit 34dda74

Browse files
committed
installed ruff and ran it
1 parent ed3645b commit 34dda74

15 files changed

+786
-565
lines changed

Makefile

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,6 @@ install-develop: clean-build clean-pyc ## install the package in editable mode a
8181

8282
.PHONY: lint-deepecho
8383
lint-deepecho: ## check style with flake8 and isort
84-
flake8 deepecho
85-
isort -c --recursive deepecho
86-
pylint deepecho --rcfile=setup.cfg
8784

8885
.PHONY: lint-tests
8986
lint-tests: ## check style with flake8 and isort
@@ -92,17 +89,15 @@ lint-tests: ## check style with flake8 and isort
9289

9390
.PHONY: lint
9491
lint: ## Run all code style checks
95-
invoke lint
92+
ruff check .
93+
ruff format . --check
9694

9795
.PHONY: fix-lint
98-
fix-lint: ## fix lint issues using autoflake, autopep8, and isort
99-
find deepecho tests -name '*.py' | xargs autoflake --in-place --remove-all-unused-imports --remove-unused-variables
100-
autopep8 --in-place --recursive --aggressive deepecho tests
101-
isort --apply --atomic --recursive deepecho tests
102-
96+
fix-lint: ## fix lint issues using ruff
97+
ruff check --fix .
98+
ruff format .
10399

104100
# TEST TARGETS
105-
106101
.PHONY: test-unit
107102
test-unit: ## run unit tests using pytest
108103
invoke unit

deepecho/__init__.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
"""Top-level package for DeepEcho."""
22

3-
__author__ = 'DataCebo, Inc.'
4-
__email__ = '[email protected]'
5-
__version__ = '0.5.1.dev0'
6-
__path__ = __import__('pkgutil').extend_path(__path__, __name__)
3+
__author__ = "DataCebo, Inc."
4+
__email__ = "[email protected]"
5+
__version__ = "0.5.1.dev0"
6+
__path__ = __import__("pkgutil").extend_path(__path__, __name__)
77

88
from deepecho.demo import load_demo
99
from deepecho.models.basic_gan import BasicGANModel
1010
from deepecho.models.par import PARModel
1111

1212
__all__ = [
13-
'load_demo',
14-
'BasicGANModel',
15-
'PARModel',
13+
"load_demo",
14+
"BasicGANModel",
15+
"PARModel",
1616
]

deepecho/demo.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44

55
import pandas as pd
66

7-
_DATA_PATH = os.path.join(os.path.dirname(__file__), 'data')
7+
_DATA_PATH = os.path.join(os.path.dirname(__file__), "data")
88

99

1010
def load_demo():
1111
"""Load the demo DataFrame."""
12-
return pd.read_csv(os.path.join(_DATA_PATH, 'demo.csv'), parse_dates=['date'])
12+
return pd.read_csv(
13+
os.path.join(_DATA_PATH, "demo.csv"), parse_dates=["date"]
14+
)

deepecho/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
from deepecho.models.basic_gan import BasicGANModel
44
from deepecho.models.par import PARModel
55

6-
__all__ = ['PARModel', 'BasicGANModel']
6+
__all__ = ["PARModel", "BasicGANModel"]

deepecho/models/base.py

Lines changed: 50 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from deepecho.sequences import assemble_sequences
77

88

9-
class DeepEcho():
9+
class DeepEcho:
1010
"""The base class for DeepEcho models."""
1111

1212
_verbose = True
@@ -28,14 +28,20 @@ def _validate(sequences, context_types, data_types):
2828
data_types:
2929
See `fit`.
3030
"""
31-
dtypes = set(['continuous', 'categorical', 'ordinal', 'count', 'datetime'])
31+
dtypes = set([
32+
"continuous",
33+
"categorical",
34+
"ordinal",
35+
"count",
36+
"datetime",
37+
])
3238
assert all(dtype in dtypes for dtype in context_types)
3339
assert all(dtype in dtypes for dtype in data_types)
3440

3541
for sequence in sequences:
36-
assert len(sequence['context']) == len(context_types)
37-
assert len(sequence['data']) == len(data_types)
38-
lengths = [len(x) for x in sequence['data']]
42+
assert len(sequence["context"]) == len(context_types)
43+
assert len(sequence["data"]) == len(data_types)
44+
lengths = [len(x) for x in sequence["data"]]
3945
assert len(set(lengths)) == 1
4046

4147
def fit_sequences(self, sequences, context_types, data_types):
@@ -87,20 +93,29 @@ def _get_data_types(data, data_types, columns):
8793
else:
8894
dtype = data[column].dtype
8995
kind = dtype.kind
90-
if kind in 'fiud':
91-
dtypes_list.append('continuous')
92-
elif kind in 'OSUb':
93-
dtypes_list.append('categorical')
94-
elif kind == 'M':
95-
dtypes_list.append('datetime')
96+
if kind in "fiud":
97+
dtypes_list.append("continuous")
98+
elif kind in "OSUb":
99+
dtypes_list.append("categorical")
100+
elif kind == "M":
101+
dtypes_list.append("datetime")
96102
else:
97-
error = f'Unsupported data_type for column {column}: {dtype}'
103+
error = (
104+
f"Unsupported data_type for column {column}: {dtype}"
105+
)
98106
raise ValueError(error)
99107

100108
return dtypes_list
101109

102-
def fit(self, data, entity_columns=None, context_columns=None,
103-
data_types=None, segment_size=None, sequence_index=None):
110+
def fit(
111+
self,
112+
data,
113+
entity_columns=None,
114+
context_columns=None,
115+
data_types=None,
116+
segment_size=None,
117+
sequence_index=None,
118+
):
104119
"""Fit the model to a dataframe containing time series data.
105120
106121
Args:
@@ -131,17 +146,19 @@ def fit(self, data, entity_columns=None, context_columns=None,
131146
such as integer values or datetimes.
132147
"""
133148
if not entity_columns and segment_size is None:
134-
raise TypeError('If the data has no `entity_columns`, `segment_size` must be given.')
149+
raise TypeError(
150+
"If the data has no `entity_columns`, `segment_size` must be given."
151+
)
135152
if segment_size is not None and not isinstance(segment_size, int):
136153
if sequence_index is None:
137154
raise TypeError(
138-
'`segment_size` must be of type `int` if '
139-
'no `sequence_index` is given.'
155+
"`segment_size` must be of type `int` if "
156+
"no `sequence_index` is given."
140157
)
141-
if data[sequence_index].dtype.kind != 'M':
158+
if data[sequence_index].dtype.kind != "M":
142159
raise TypeError(
143-
'`segment_size` must be of type `int` if '
144-
'`sequence_index` is not a `datetime` column.'
160+
"`segment_size` must be of type `int` if "
161+
"`sequence_index` is not a `datetime` column."
145162
)
146163

147164
segment_size = pd.to_timedelta(segment_size)
@@ -159,9 +176,16 @@ def fit(self, data, entity_columns=None, context_columns=None,
159176
self._data_columns.remove(sequence_index)
160177

161178
data_types = self._get_data_types(data, data_types, self._data_columns)
162-
context_types = self._get_data_types(data, data_types, self._context_columns)
179+
context_types = self._get_data_types(
180+
data, data_types, self._context_columns
181+
)
163182
sequences = assemble_sequences(
164-
data, self._entity_columns, self._context_columns, segment_size, sequence_index)
183+
data,
184+
self._entity_columns,
185+
self._context_columns,
186+
segment_size,
187+
sequence_index,
188+
)
165189

166190
# Validate and fit
167191
self._validate(sequences, context_types, data_types)
@@ -212,7 +236,9 @@ def sample(self, num_entities=None, context=None, sequence_length=None):
212236
"""
213237
if context is None:
214238
if num_entities is None:
215-
raise TypeError('Either context or num_entities must be not None')
239+
raise TypeError(
240+
"Either context or num_entities must be not None"
241+
)
216242

217243
context = self._context_values.sample(num_entities, replace=True)
218244
context = context.reset_index(drop=True)
@@ -242,7 +268,7 @@ def sample(self, num_entities=None, context=None, sequence_length=None):
242268
# Reformat as a DataFrame
243269
group = pd.DataFrame(
244270
dict(zip(self._data_columns, sequence)),
245-
columns=self._data_columns
271+
columns=self._data_columns,
246272
)
247273
group[self._entity_columns] = entity_values
248274
for column, value in zip(self._context_columns, context_values):

0 commit comments

Comments
 (0)