Skip to content

Commit 4ad94da

Browse files
authored
Add ruff for linting, remove flake8, remove isort, remove pylint (#94)
1 parent ed3645b commit 4ad94da

File tree

12 files changed

+318
-243
lines changed

12 files changed

+318
-243
lines changed

Makefile

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -79,30 +79,17 @@ install-develop: clean-build clean-pyc ## install the package in editable mode a
7979

8080
# LINT TARGETS
8181

82-
.PHONY: lint-deepecho
83-
lint-deepecho: ## check style with flake8 and isort
84-
flake8 deepecho
85-
isort -c --recursive deepecho
86-
pylint deepecho --rcfile=setup.cfg
87-
88-
.PHONY: lint-tests
89-
lint-tests: ## check style with flake8 and isort
90-
flake8 --ignore=D tests
91-
isort -c --recursive tests
92-
9382
.PHONY: lint
94-
lint: ## Run all code style checks
95-
invoke lint
83+
lint:
84+
ruff check .
85+
ruff format . --check
9686

9787
.PHONY: fix-lint
98-
fix-lint: ## fix lint issues using autoflake, autopep8, and isort
99-
find deepecho tests -name '*.py' | xargs autoflake --in-place --remove-all-unused-imports --remove-unused-variables
100-
autopep8 --in-place --recursive --aggressive deepecho tests
101-
isort --apply --atomic --recursive deepecho tests
102-
88+
fix-lint:
89+
ruff check --fix .
90+
ruff format .
10391

10492
# TEST TARGETS
105-
10693
.PHONY: test-unit
10794
test-unit: ## run unit tests using pytest
10895
invoke unit

deepecho/models/base.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from deepecho.sequences import assemble_sequences
77

88

9-
class DeepEcho():
9+
class DeepEcho:
1010
"""The base class for DeepEcho models."""
1111

1212
_verbose = True
@@ -28,7 +28,13 @@ def _validate(sequences, context_types, data_types):
2828
data_types:
2929
See `fit`.
3030
"""
31-
dtypes = set(['continuous', 'categorical', 'ordinal', 'count', 'datetime'])
31+
dtypes = set([
32+
'continuous',
33+
'categorical',
34+
'ordinal',
35+
'count',
36+
'datetime',
37+
])
3238
assert all(dtype in dtypes for dtype in context_types)
3339
assert all(dtype in dtypes for dtype in data_types)
3440

@@ -99,8 +105,15 @@ def _get_data_types(data, data_types, columns):
99105

100106
return dtypes_list
101107

102-
def fit(self, data, entity_columns=None, context_columns=None,
103-
data_types=None, segment_size=None, sequence_index=None):
108+
def fit(
109+
self,
110+
data,
111+
entity_columns=None,
112+
context_columns=None,
113+
data_types=None,
114+
segment_size=None,
115+
sequence_index=None,
116+
):
104117
"""Fit the model to a dataframe containing time series data.
105118
106119
Args:
@@ -135,8 +148,7 @@ def fit(self, data, entity_columns=None, context_columns=None,
135148
if segment_size is not None and not isinstance(segment_size, int):
136149
if sequence_index is None:
137150
raise TypeError(
138-
'`segment_size` must be of type `int` if '
139-
'no `sequence_index` is given.'
151+
'`segment_size` must be of type `int` if ' 'no `sequence_index` is given.'
140152
)
141153
if data[sequence_index].dtype.kind != 'M':
142154
raise TypeError(
@@ -161,7 +173,12 @@ def fit(self, data, entity_columns=None, context_columns=None,
161173
data_types = self._get_data_types(data, data_types, self._data_columns)
162174
context_types = self._get_data_types(data, data_types, self._context_columns)
163175
sequences = assemble_sequences(
164-
data, self._entity_columns, self._context_columns, segment_size, sequence_index)
176+
data,
177+
self._entity_columns,
178+
self._context_columns,
179+
segment_size,
180+
sequence_index,
181+
)
165182

166183
# Validate and fit
167184
self._validate(sequences, context_types, data_types)
@@ -242,7 +259,7 @@ def sample(self, num_entities=None, context=None, sequence_length=None):
242259
# Reformat as a DataFrame
243260
group = pd.DataFrame(
244261
dict(zip(self._data_columns, sequence)),
245-
columns=self._data_columns
262+
columns=self._data_columns,
246263
)
247264
group[self._entity_columns] = entity_values
248265
for column, value in zip(self._context_columns, context_values):

deepecho/models/basic_gan.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@
1313

1414

1515
def _expand_context(data, context):
16-
return torch.cat([
17-
data,
18-
context.unsqueeze(0).expand(data.shape[0], context.shape[0], context.shape[1])
19-
], dim=2)
16+
return torch.cat(
17+
[
18+
data,
19+
context.unsqueeze(0).expand(data.shape[0], context.shape[0], context.shape[1]),
20+
],
21+
dim=2,
22+
)
2023

2124

2225
class BasicGenerator(torch.nn.Module):
@@ -65,7 +68,7 @@ def forward(self, context=None, sequence_length=None):
6568
"""
6669
latent = torch.randn(
6770
size=(sequence_length, context.size(0), self.latent_size),
68-
device=self.device
71+
device=self.device,
6972
)
7073
latent = _expand_context(latent, context)
7174

@@ -150,8 +153,16 @@ class BasicGANModel(DeepEcho):
150153
_model_data_size = None
151154
_generator = None
152155

153-
def __init__(self, epochs=1024, latent_size=32, hidden_size=16,
154-
gen_lr=1e-3, dis_lr=1e-3, cuda=True, verbose=True):
156+
def __init__(
157+
self,
158+
epochs=1024,
159+
latent_size=32,
160+
hidden_size=16,
161+
gen_lr=1e-3,
162+
dis_lr=1e-3,
163+
cuda=True,
164+
verbose=True,
165+
):
155166
self._epochs = epochs
156167
self._gen_lr = gen_lr
157168
self._dis_lr = dis_lr
@@ -211,7 +222,7 @@ def _index_map(columns, types):
211222
'type': column_type,
212223
'min': np.min(values),
213224
'max': np.max(values),
214-
'indices': (dimensions, dimensions + 1)
225+
'indices': (dimensions, dimensions + 1),
215226
}
216227
dimensions += 2
217228

@@ -221,10 +232,7 @@ def _index_map(columns, types):
221232
indices[value] = dimensions
222233
dimensions += 1
223234

224-
mapping[column] = {
225-
'type': column_type,
226-
'indices': indices
227-
}
235+
mapping[column] = {'type': column_type, 'indices': indices}
228236

229237
else:
230238
raise ValueError(f'Unsupported type: {column_type}')
@@ -317,7 +325,7 @@ def _value_to_tensor(self, tensor, value, properties):
317325
self._one_hot_encode(tensor, value, properties)
318326

319327
else:
320-
raise ValueError() # Theoretically unreachable
328+
raise ValueError() # Theoretically unreachable
321329

322330
def _data_to_tensor(self, data):
323331
"""Convert the input data to the corresponding tensor.
@@ -370,7 +378,7 @@ def _tensor_to_data(self, tensor):
370378
elif column_type in ('categorical', 'ordinal'):
371379
value = self._one_hot_decode(tensor, row, properties)
372380
else:
373-
raise ValueError() # Theoretically unreachable
381+
raise ValueError() # Theoretically unreachable
374382

375383
column_data.append(value)
376384

@@ -412,7 +420,7 @@ def _truncate(self, generated):
412420
end_flag = sequence[:, self._data_size]
413421
if (end_flag == 1.0).any():
414422
cut_idx = end_flag.detach().cpu().numpy().argmax()
415-
sequence[cut_idx + 1:] = 0.0
423+
sequence[cut_idx + 1 :] = 0.0
416424

417425
def _generate(self, context, sequence_length=None):
418426
generated = self._generator(

0 commit comments

Comments
 (0)