Commit eb94791

update docs
1 parent 1e8bb59 commit eb94791

31 files changed, +226 −48 lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
@@ -18,4 +18,5 @@ MYIDEA.md
 **/.ipynb_checkpoints/
 /examples/data/crawer_data.py
 /reference/*
-/models/*
+/models/variables/*
+/data/*

README.md

Lines changed: 25 additions & 6 deletions
@@ -1,9 +1,10 @@
 # Time series prediction
 This repo implements the common methods of time series prediction, especially deep learning methods in TensorFlow2.
-It's highly welcomed to contribute if you have better idea, just create a PR. If any question, feel free to open an issue.
+It's highly welcomed to contribute if you have any better idea, just create a PR. If any question, feel free to open an issue.
 
+#### Ongoing project, welcome to join
 
-<table style="width:100%">
+<table style="width:100%" align="center">
 <tr>
 <th>
 <p align="center">
@@ -14,6 +15,8 @@
 <p align="center">
 <a href="./docs/arima.md" name="introduction">intro</a>
 </p>
+</th>
+<th>
 <p align="center">
 <a href="./deepts/models/arima.py" name="code">code</a>
 </p>
@@ -28,7 +31,9 @@
 <th>
 <p align="center">
 <a href="./docs/tree.md" name="introduction">intro</a>
-</p>
+</p>
+</th>
+<th>
 <p align="center">
 <a href="./deepts/models/tree.py" name="code">code</a>
 </p>
@@ -44,6 +49,8 @@
 <p align="center">
 <a href="./docs/rnn.md" name="introduction">intro</a>
 </p>
+</th>
+<th>
 <p align="center">
 <a href="./deepts/models/seq2seq.py" name="code">code</a>
 </p>
@@ -59,6 +66,8 @@
 <p align="center">
 <a href="./docs/wavenet.md" name="introduction">intro</a>
 </p>
+</th>
+<th>
 <p align="center">
 <a href="./deepts/models/wavenet.py" name="code">code</a>
 </p>
@@ -73,7 +82,9 @@
 <th>
 <p align="center">
 <a href="./docs/transformer.md" name="introduction">intro</a>
-</p>
+</p>
+</th>
+<th>
 <p align="center">
 <a href="./deepts/models/transformer.py" name="code">code</a>
 </p>
@@ -89,6 +100,8 @@
 <p align="center">
 <a href="./docs/unet.md" name="introduction">intro</a>
 </p>
+</th>
+<th>
 <p align="center">
 <a href="./deepts/models/unet.py" name="code">code</a>
 </p>
@@ -104,6 +117,8 @@
 <p align="center">
 <a href="./docs/nbeats.md" name="introduction">intro</a>
 </p>
+</th>
+<th>
 <p align="center">
 <a href="./deepts/models/nbeats.py" name="code">code</a>
 </p>
@@ -119,6 +134,8 @@
 <p align="center">
 <a href="./docs/gan.md" name="introduction">intro</a>
 </p>
+</th>
+<th>
 <p align="center">
 <a href="./deepts/models/gan.py" name="code">code</a>
 </p>
@@ -136,7 +153,7 @@ pip install -r requirements.txt
 ```bash
 bash ./data/download_passenger.sh
 ```
-3. Train the model, set `custom_model_params` if you want
+3. Train the model, set `custom_model_params` if you want, and pay attention to your own feature engineering
 ```bash
 cd examples
 python run_train.py --use_model seq2seq
@@ -147,7 +164,9 @@ python run_test.py
 ```
 
 ## Further reading
-https://github.com/awslabs/gluon-ts/
+- https://github.com/awslabs/gluon-ts/
+- https://github.com/Azure/DeepLearningForTimeSeriesForecasting
 
 ## Contributor
 - [LongxingTan](https://longxingtan.github.io/)
+

deepts/layers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+#encodeing=utf-8

deepts/layers/attention_layer.py

Lines changed: 3 additions & 1 deletion
@@ -215,7 +215,9 @@ def build(self,input_shape):
         super(PositionEncoding,self).build(input_shape)
 
     def get_config(self):
-        pass
+        return {
+            'max_len': self.max_len
+        }
 
     def call(self,x,masking=True):
         E = x.get_shape().as_list()[-1]  # static
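
Context for the `get_config` change: returning the layer's constructor arguments is what lets Keras re-create the layer when a model is serialized and reloaded. Below is a minimal, hypothetical sketch of the pattern, not the repo's full `PositionEncoding` (the encoding itself is omitted, and merging with `super().get_config()` is assumed here as the slightly more complete variant):

```python
import tensorflow as tf

# Hypothetical minimal layer illustrating the get_config pattern from the diff above.
class PositionEncoding(tf.keras.layers.Layer):
    def __init__(self, max_len=512, **kwargs):
        super(PositionEncoding, self).__init__(**kwargs)
        self.max_len = max_len

    def call(self, x, masking=True):
        return x  # placeholder; the real layer adds positional encodings here

    def get_config(self):
        # Merge the custom argument with the base Layer config (name, dtype, ...).
        config = super(PositionEncoding, self).get_config()
        config.update({'max_len': self.max_len})
        return config

# Round trip through the config dict, which is what model serialization relies on:
layer = PositionEncoding(max_len=100)
restored = PositionEncoding.from_config(layer.get_config())
assert restored.max_len == 100
```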

deepts/layers/rnn_layer.py

Lines changed: 1 addition & 1 deletion
@@ -9,4 +9,4 @@
 
 class RNNLayer(Layer):
     def __init__(self):
-        super(RNNLayer,self).__init__()
+        super(RNNLayer,self).__init__()

deepts/model.py

Lines changed: 25 additions & 9 deletions
@@ -11,7 +11,7 @@
 from deepts.models.unet import Unet
 from deepts.models.nbeats import NBeatsNet
 from deepts.models.gan import GAN
-assert tf.__version__>"2.0.0"
+assert tf.__version__>"2.0.0", "Should you consider to use TensorFlow 2?"
 
 
 class Loss(object):
@@ -29,11 +29,11 @@ class Optimizer(object):
     def __init__(self,use_optimizer):
         self.use_optimizer=use_optimizer
 
-    def __call__(self,):
+    def __call__(self,learning_rate):
         if self.use_optimizer == 'adam':
-            return tf.keras.optimizers.Adam(lr=0.001)
+            return tf.keras.optimizers.Adam(lr=learning_rate)
         elif self.use_optimizer == 'sgd':
-            return tf.keras.optimizers.SGD(lr=0.001)
+            return tf.keras.optimizers.SGD(lr=learning_rate)
 
 
 class Model(object):
@@ -66,7 +66,7 @@ def __init__(self,params, use_model, use_loss='mse',use_optimizer='adam', custom
         self.use_loss = use_loss
         self.use_optimizer = use_optimizer
         self.loss_fn = Loss(use_loss)()
-        self.optimizer_fn = Optimizer(use_optimizer)()
+        self.optimizer_fn = Optimizer(use_optimizer)(learning_rate=params['learning_rate'])
         self.model = tf.keras.Model(inputs, outputs, name=use_model)
 
     def train(self, dataset, n_epochs, mode='eager', export_model=False):
@@ -105,25 +105,41 @@ def train_step(self, x, y):
 
     def eval(self, valid_dataset):
         for step,(x,y) in enumerate(valid_dataset.take(-1)):
-            metrics=self.test_step(x,y)
+            metrics=self.dev_step(x,y)
             print("=> STEP %4d Metrics: %4.2f"%(step, metrics))
 
-    def test_step(self, x, y):
+    def dev_step(self, x, y):
+        '''
+        evaluation step function
+        :param x:
+        :param y:
+        :return:
+        '''
         x=tf.cast(x, tf.float32)
         y=tf.cast(y, tf.float32)
-        y_pred=self.model(x,training=False)
+        try:
+            y_pred=self.model(x,training=False)
+        except:
+            y_pred=self.model((x,tf.ones([tf.shape(x)[0],self.params['output_seq_length'],1],tf.float32)))
         metrics=self.loss_fn(y, y_pred).numpy()
         return metrics
 
     def predict(self, x_test, model_dir, use_model='pb'):
+        '''
+        predict function, don't use self.model here, but saved checkpoint or pb
+        :param x_test:
+        :param model_dir:
+        :param use_model:
+        :return:
+        '''
        if use_model=='pb':
            print('Load saved pb model ...')
            model=tf.saved_model.load(model_dir)
        else:
            print('Load checkpoint model ...')
            model=self.model.load_weights(model_dir)
 
-        y_pred=model(tf.constant(x_test),True,None) # To be clarified
+        y_pred=model(x_test,True,None) # To be clarified, not sure why additional args are necessary here
         return y_pred
 
     def export_model(self):
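
Two notes on the `deepts/model.py` changes, each with a hedged sketch. First, the optimizer: the learning rate is now threaded in from `params['learning_rate']` rather than hard-coded to 0.001. A sketch of that factory pattern, using the non-deprecated `learning_rate=` keyword instead of `lr=` (the `params` key is taken from the diff above):

```python
import tensorflow as tf

class Optimizer(object):
    def __init__(self, use_optimizer):
        self.use_optimizer = use_optimizer

    def __call__(self, learning_rate):
        # The caller supplies the learning rate, typically from a params dict.
        if self.use_optimizer == 'adam':
            return tf.keras.optimizers.Adam(learning_rate=learning_rate)
        elif self.use_optimizer == 'sgd':
            return tf.keras.optimizers.SGD(learning_rate=learning_rate)
        raise ValueError('unsupported optimizer: %s' % self.use_optimizer)

params = {'learning_rate': 0.001}  # same key Model.__init__ reads in the diff
optimizer_fn = Optimizer('adam')(learning_rate=params['learning_rate'])
```

Second, `predict` loads either a SavedModel ("pb") directory or a weights checkpoint. A hypothetical sketch of those two routes, assuming the SavedModel was exported with a default serving signature (the `serving_default` key is an assumption, not something this repo guarantees):

```python
import tensorflow as tf

def load_and_predict(x_test, model_dir, keras_model=None, use_model='pb'):
    if use_model == 'pb':
        loaded = tf.saved_model.load(model_dir)
        infer = loaded.signatures['serving_default']   # assumed signature name
        return infer(tf.constant(x_test, tf.float32))  # returns a dict of named outputs
    # Checkpoint route: load_weights restores weights in place and returns a status
    # object, so the model itself (not load_weights' return value) is what gets called.
    keras_model.load_weights(model_dir)
    return keras_model(tf.constant(x_test, tf.float32), training=False)
```
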

deepts/models/tft.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # @author: Longxing Tan, [email protected]
-# @date: 2020-05
+# @date: 2020-06
 # paper: https://arxiv.org/pdf/1912.09363v1.pdf
 # other implementations: https://github.com/google-research/google-research/blob/master/tft/libs/tft_model.py

deepts/models/transformer.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 # @date: 2020-01
 # paper:
 # other implementations: https://github.com/maxjcohen/transformer
+# https://github.com/Trigram19/m5-python-starter
 
 
 import tensorflow as tf

deepts/models/wavenet.py

Lines changed: 0 additions & 1 deletion
@@ -9,7 +9,6 @@
 import tensorflow as tf
 from tensorflow.keras.layers import Dense
 from deepts.layers.wavenet_layer import Dense3D, ConvTime
-#tf.config.experimental_run_functions_eagerly(True) # ??
 
 
 params={
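
The line removed above was a debugging switch for forcing `tf.function`-decorated code to run eagerly. For reference, in newer TF 2.x releases the non-experimental equivalent is `tf.config.run_functions_eagerly` (a general TensorFlow note, not something this commit adds):

```python
import tensorflow as tf

# Run @tf.function graphs eagerly, e.g. to step through them with a debugger.
tf.config.run_functions_eagerly(True)
# ... debug ...
tf.config.run_functions_eagerly(False)  # switch graph execution back on
```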

docs/arima.md

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 
 
 ## Introduction
-ARIMA is a short for "Autoregressive Integrated Moving Average model", it's a traditional time-series-prediction model. Basically it's a linear model combined the auto regression model and moving average model.
+ARIMA is a short for "Auto-regressive Integrated Moving Average model", it's a traditional time-series-prediction model. Basically it's a linear model combined the auto regression model and moving average model.
 - Auto regression model is a linear regression model using the history data as its feature. The important hyper parameter is how many days from history are used.
 - Moving average model is a linear regression model using the history residual error as its feature. The important hyper parameter is also history data length.
 - arima
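
To make the auto-regression / moving-average split above concrete, here is a small illustrative sketch using statsmodels (purely an example of the (p, d, q) idea; it is not how ./deepts/models/arima.py is implemented):

```python
import numpy as np
from statsmodels.tsa.arima.model import ARIMA

np.random.seed(0)
series = np.cumsum(np.random.randn(200))  # synthetic non-stationary series

# order=(p, d, q): p lagged values (auto-regression), d differencing steps,
# q lagged residual errors (moving average).
model = ARIMA(series, order=(2, 1, 1)).fit()
print(model.forecast(steps=10))  # predict the next 10 points
```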
