6
6
from deepecho .sequences import assemble_sequences
7
7
8
8
9
- class DeepEcho () :
9
+ class DeepEcho :
10
10
"""The base class for DeepEcho models."""
11
11
12
12
_verbose = True
@@ -28,14 +28,20 @@ def _validate(sequences, context_types, data_types):
28
28
data_types:
29
29
See `fit`.
30
30
"""
31
- dtypes = set (['continuous' , 'categorical' , 'ordinal' , 'count' , 'datetime' ])
31
+ dtypes = set ([
32
+ "continuous" ,
33
+ "categorical" ,
34
+ "ordinal" ,
35
+ "count" ,
36
+ "datetime" ,
37
+ ])
32
38
assert all (dtype in dtypes for dtype in context_types )
33
39
assert all (dtype in dtypes for dtype in data_types )
34
40
35
41
for sequence in sequences :
36
- assert len (sequence [' context' ]) == len (context_types )
37
- assert len (sequence [' data' ]) == len (data_types )
38
- lengths = [len (x ) for x in sequence [' data' ]]
42
+ assert len (sequence [" context" ]) == len (context_types )
43
+ assert len (sequence [" data" ]) == len (data_types )
44
+ lengths = [len (x ) for x in sequence [" data" ]]
39
45
assert len (set (lengths )) == 1
40
46
41
47
def fit_sequences (self , sequences , context_types , data_types ):
@@ -87,20 +93,29 @@ def _get_data_types(data, data_types, columns):
87
93
else :
88
94
dtype = data [column ].dtype
89
95
kind = dtype .kind
90
- if kind in ' fiud' :
91
- dtypes_list .append (' continuous' )
92
- elif kind in ' OSUb' :
93
- dtypes_list .append (' categorical' )
94
- elif kind == 'M' :
95
- dtypes_list .append (' datetime' )
96
+ if kind in " fiud" :
97
+ dtypes_list .append (" continuous" )
98
+ elif kind in " OSUb" :
99
+ dtypes_list .append (" categorical" )
100
+ elif kind == "M" :
101
+ dtypes_list .append (" datetime" )
96
102
else :
97
- error = f'Unsupported data_type for column { column } : { dtype } '
103
+ error = (
104
+ f"Unsupported data_type for column { column } : { dtype } "
105
+ )
98
106
raise ValueError (error )
99
107
100
108
return dtypes_list
101
109
102
- def fit (self , data , entity_columns = None , context_columns = None ,
103
- data_types = None , segment_size = None , sequence_index = None ):
110
+ def fit (
111
+ self ,
112
+ data ,
113
+ entity_columns = None ,
114
+ context_columns = None ,
115
+ data_types = None ,
116
+ segment_size = None ,
117
+ sequence_index = None ,
118
+ ):
104
119
"""Fit the model to a dataframe containing time series data.
105
120
106
121
Args:
@@ -131,17 +146,19 @@ def fit(self, data, entity_columns=None, context_columns=None,
131
146
such as integer values or datetimes.
132
147
"""
133
148
if not entity_columns and segment_size is None :
134
- raise TypeError ('If the data has no `entity_columns`, `segment_size` must be given.' )
149
+ raise TypeError (
150
+ "If the data has no `entity_columns`, `segment_size` must be given."
151
+ )
135
152
if segment_size is not None and not isinstance (segment_size , int ):
136
153
if sequence_index is None :
137
154
raise TypeError (
138
- ' `segment_size` must be of type `int` if '
139
- ' no `sequence_index` is given.'
155
+ " `segment_size` must be of type `int` if "
156
+ " no `sequence_index` is given."
140
157
)
141
- if data [sequence_index ].dtype .kind != 'M' :
158
+ if data [sequence_index ].dtype .kind != "M" :
142
159
raise TypeError (
143
- ' `segment_size` must be of type `int` if '
144
- ' `sequence_index` is not a `datetime` column.'
160
+ " `segment_size` must be of type `int` if "
161
+ " `sequence_index` is not a `datetime` column."
145
162
)
146
163
147
164
segment_size = pd .to_timedelta (segment_size )
@@ -159,9 +176,16 @@ def fit(self, data, entity_columns=None, context_columns=None,
159
176
self ._data_columns .remove (sequence_index )
160
177
161
178
data_types = self ._get_data_types (data , data_types , self ._data_columns )
162
- context_types = self ._get_data_types (data , data_types , self ._context_columns )
179
+ context_types = self ._get_data_types (
180
+ data , data_types , self ._context_columns
181
+ )
163
182
sequences = assemble_sequences (
164
- data , self ._entity_columns , self ._context_columns , segment_size , sequence_index )
183
+ data ,
184
+ self ._entity_columns ,
185
+ self ._context_columns ,
186
+ segment_size ,
187
+ sequence_index ,
188
+ )
165
189
166
190
# Validate and fit
167
191
self ._validate (sequences , context_types , data_types )
@@ -212,7 +236,9 @@ def sample(self, num_entities=None, context=None, sequence_length=None):
212
236
"""
213
237
if context is None :
214
238
if num_entities is None :
215
- raise TypeError ('Either context or num_entities must be not None' )
239
+ raise TypeError (
240
+ "Either context or num_entities must be not None"
241
+ )
216
242
217
243
context = self ._context_values .sample (num_entities , replace = True )
218
244
context = context .reset_index (drop = True )
@@ -242,7 +268,7 @@ def sample(self, num_entities=None, context=None, sequence_length=None):
242
268
# Reformat as a DataFrame
243
269
group = pd .DataFrame (
244
270
dict (zip (self ._data_columns , sequence )),
245
- columns = self ._data_columns
271
+ columns = self ._data_columns ,
246
272
)
247
273
group [self ._entity_columns ] = entity_values
248
274
for column , value in zip (self ._context_columns , context_values ):
0 commit comments