
Commit 93323ed

wendili-cs authored and you-n-g committed
Add TFT benchmark
1 parent c897eca commit 93323ed

15 files changed: +3971 -0 lines changed
@@ -0,0 +1,15 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,235 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Default data formatting functions for experiments.

For new datasets, inherit from GenericDataFormatter and implement
all abstract functions.

These dataset-specific methods:
1) Define the column and input types for tabular dataframes used by the model
2) Perform the necessary input feature engineering & normalisation steps
3) Revert the normalisation for predictions
4) Are responsible for train, validation and test splits
"""
import abc
import enum


# Type definitions
class DataTypes(enum.IntEnum):
  """Defines numerical types of each column."""
  REAL_VALUED = 0
  CATEGORICAL = 1
  DATE = 2


class InputTypes(enum.IntEnum):
  """Defines input types of each column."""
  TARGET = 0
  OBSERVED_INPUT = 1
  KNOWN_INPUT = 2
  STATIC_INPUT = 3
  ID = 4  # Single column used as an entity identifier
  TIME = 5  # Single column exclusively used as a time index
class GenericDataFormatter(abc.ABC):
  """Abstract base class for all data formatters.

  Users can implement the abstract methods below to perform dataset-specific
  manipulations.
  """

  @abc.abstractmethod
  def set_scalers(self, df):
    """Calibrates scalers using the data supplied."""
    raise NotImplementedError()

  @abc.abstractmethod
  def transform_inputs(self, df):
    """Performs feature transformation."""
    raise NotImplementedError()

  @abc.abstractmethod
  def format_predictions(self, df):
    """Reverts any normalisation to give predictions in original scale."""
    raise NotImplementedError()

  @abc.abstractmethod
  def split_data(self, df):
    """Performs the default train, validation and test splits."""
    raise NotImplementedError()

  @property
  @abc.abstractmethod
  def _column_definition(self):
    """Defines order, input type and data type of each column."""
    raise NotImplementedError()
  @abc.abstractmethod
  def get_fixed_params(self):
    """Defines the fixed parameters used by the model for training.

    Requires the following keys:
      'total_time_steps': Defines the total number of time steps used by TFT
      'num_encoder_steps': Determines length of LSTM encoder (i.e. history)
      'num_epochs': Maximum number of epochs for training
      'early_stopping_patience': Early stopping parameter for keras
      'multiprocessing_workers': Number of CPUs for data processing

    Returns:
      A dictionary of fixed parameters, e.g.:

      fixed_params = {
          'total_time_steps': 252 + 5,
          'num_encoder_steps': 252,
          'num_epochs': 100,
          'early_stopping_patience': 5,
          'multiprocessing_workers': 5,
      }
    """
    raise NotImplementedError()
  # Shared functions across data formatters
  @property
  def num_classes_per_cat_input(self):
    """Returns number of categories per relevant input.

    This is subsequently required for keras embedding layers.
    """
    return self._num_classes_per_cat_input

  def get_num_samples_for_calibration(self):
    """Gets the default number of training and validation samples.

    Used to sub-sample the data for network calibration; a value of -1 uses
    all available samples.

    Returns:
      Tuple of (training samples, validation samples)
    """
    return -1, -1
  def get_column_definition(self):
    """Returns formatted column definition in the order expected by the TFT."""

    column_definition = self._column_definition

    # Sanity checks first.
    # Ensure only one ID and one time column exist.
    def _check_single_column(input_type):

      length = len([tup for tup in column_definition if tup[2] == input_type])

      if length != 1:
        raise ValueError('Illegal number of inputs ({}) of type {}'.format(
            length, input_type))

    _check_single_column(InputTypes.ID)
    _check_single_column(InputTypes.TIME)

    identifier = [tup for tup in column_definition if tup[2] == InputTypes.ID]
    time = [tup for tup in column_definition if tup[2] == InputTypes.TIME]
    real_inputs = [
        tup for tup in column_definition if tup[1] == DataTypes.REAL_VALUED and
        tup[2] not in {InputTypes.ID, InputTypes.TIME}
    ]
    categorical_inputs = [
        tup for tup in column_definition if tup[1] == DataTypes.CATEGORICAL and
        tup[2] not in {InputTypes.ID, InputTypes.TIME}
    ]

    return identifier + time + real_inputs + categorical_inputs
  def _get_input_columns(self):
    """Returns names of all input columns."""
    return [
        tup[0]
        for tup in self.get_column_definition()
        if tup[2] not in {InputTypes.ID, InputTypes.TIME}
    ]
  def _get_tft_input_indices(self):
    """Returns the relevant indexes and input sizes required by TFT."""

    # Functions
    def _extract_tuples_from_data_type(data_type, defn):
      return [
          tup for tup in defn if tup[1] == data_type and
          tup[2] not in {InputTypes.ID, InputTypes.TIME}
      ]

    def _get_locations(input_types, defn):
      return [i for i, tup in enumerate(defn) if tup[2] in input_types]

    # Start extraction
    column_definition = [
        tup for tup in self.get_column_definition()
        if tup[2] not in {InputTypes.ID, InputTypes.TIME}
    ]

    categorical_inputs = _extract_tuples_from_data_type(DataTypes.CATEGORICAL,
                                                        column_definition)
    real_inputs = _extract_tuples_from_data_type(DataTypes.REAL_VALUED,
                                                 column_definition)

    locations = {
        'input_size':
            len(self._get_input_columns()),
        'output_size':
            len(_get_locations({InputTypes.TARGET}, column_definition)),
        'category_counts':
            self.num_classes_per_cat_input,
        'input_obs_loc':
            _get_locations({InputTypes.TARGET}, column_definition),
        'static_input_loc':
            _get_locations({InputTypes.STATIC_INPUT}, column_definition),
        'known_regular_inputs':
            _get_locations({InputTypes.STATIC_INPUT, InputTypes.KNOWN_INPUT},
                           real_inputs),
        'known_categorical_inputs':
            _get_locations({InputTypes.STATIC_INPUT, InputTypes.KNOWN_INPUT},
                           categorical_inputs),
    }

    return locations
  def get_experiment_params(self):
    """Returns fixed model parameters for experiments."""

    required_keys = [
        'total_time_steps', 'num_encoder_steps', 'num_epochs',
        'early_stopping_patience', 'multiprocessing_workers'
    ]

    fixed_params = self.get_fixed_params()

    for k in required_keys:
      if k not in fixed_params:
        raise ValueError('Field {}'.format(k) +
                         ' missing from fixed parameter definitions!')

    fixed_params['column_definition'] = self.get_column_definition()

    fixed_params.update(self._get_tft_input_indices())

    return fixed_params
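
For readers new to this interface, here is a minimal sketch of a concrete formatter built on it. It is not part of the commit: the import path (data_formatters.base), the ToyFormatter name, the column names, the sklearn-based scaling and the split dates are all assumptions used purely for illustration.

# Illustrative only -- module path, column names and preprocessing choices
# are assumptions, not part of this commit.
import sklearn.preprocessing

from data_formatters.base import DataTypes, GenericDataFormatter, InputTypes


class ToyFormatter(GenericDataFormatter):
  """Hypothetical formatter for a small daily dataset."""

  _column_definition = [
      ('id', DataTypes.CATEGORICAL, InputTypes.ID),
      ('date', DataTypes.DATE, InputTypes.TIME),
      ('value', DataTypes.REAL_VALUED, InputTypes.TARGET),
      ('day_of_week', DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT),
      ('region', DataTypes.CATEGORICAL, InputTypes.STATIC_INPUT),
  ]

  def set_scalers(self, df):
    # Calibrate a scaler for the real-valued target and label encoders for
    # the categorical inputs.
    self._target_scaler = sklearn.preprocessing.StandardScaler().fit(
        df[['value']])
    self._cat_encoders = {
        col: sklearn.preprocessing.LabelEncoder().fit(df[col])
        for col in ['day_of_week', 'region']
    }
    # Feeds the shared num_classes_per_cat_input property.
    self._num_classes_per_cat_input = [
        len(self._cat_encoders[col].classes_)
        for col in ['day_of_week', 'region']
    ]

  def transform_inputs(self, df):
    out = df.copy()
    out[['value']] = self._target_scaler.transform(df[['value']])
    for col, enc in self._cat_encoders.items():
      out[col] = enc.transform(df[col])
    return out

  def format_predictions(self, df):
    # Assumes prediction columns (here just 'forecast') share the target scale.
    out = df.copy()
    out['forecast'] = self._target_scaler.inverse_transform(
        df[['forecast']]).ravel()
    return out

  def split_data(self, df):
    # Chronological split; the boundary dates are arbitrary.
    train = df[df['date'] < '2019-01-01']
    valid = df[(df['date'] >= '2019-01-01') & (df['date'] < '2019-07-01')]
    test = df[df['date'] >= '2019-07-01']
    self.set_scalers(train)
    return tuple(self.transform_inputs(part) for part in (train, valid, test))

  def get_fixed_params(self):
    return {
        'total_time_steps': 30 + 5,
        'num_encoder_steps': 30,
        'num_epochs': 100,
        'early_stopping_patience': 5,
        'multiprocessing_workers': 4,
    }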

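Once a formatter like this exists, the training code can take everything it needs from get_experiment_params(), which merges the fixed parameters with the TFT input indices derived above. A rough usage sketch under the same assumptions:

# Usage sketch (hypothetical ToyFormatter and a pandas DataFrame `raw_df`
# containing the columns named in its _column_definition).
formatter = ToyFormatter()
train, valid, test = formatter.split_data(raw_df)  # also calibrates scalers

params = formatter.get_experiment_params()
# Contains the fixed training parameters plus 'column_definition',
# 'input_size', 'output_size', 'category_counts', 'input_obs_loc',
# 'static_input_loc', 'known_regular_inputs' and 'known_categorical_inputs'.
print(params['category_counts'])  # e.g. [7, 3] for day_of_week and region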