
Commit eb94791: update docs
1 parent 1e8bb59

31 files changed, +226 -48 lines changed

.gitignore (+2 -1)

@@ -18,4 +18,5 @@ MYIDEA.md
 **/.ipynb_checkpoints/
 /examples/data/crawer_data.py
 /reference/*
-/models/*
+/models/variables/*
+/data/*

README.md (+25 -6)

@@ -1,9 +1,10 @@
 # Time series prediction
 This repo implements the common methods of time series prediction, especially deep learning methods, in TensorFlow 2.
-It's highly welcomed to contribute if you have better idea, just create a PR. If any question, feel free to open an issue.
+Contributions are highly welcome: if you have a better idea, just create a PR. If you have any questions, feel free to open an issue.

+#### Ongoing project, welcome to join

-<table style="width:100%">
+<table style="width:100%" align="center">
 <tr>
 <th>
 <p align="center">
@@ -14,6 +15,8 @@
 <p align="center">
 <a href="./docs/arima.md" name="introduction">intro</a>
 </p>
+</th>
+<th>
 <p align="center">
 <a href="./deepts/models/arima.py" name="code">code</a>
 </p>
@@ -28,7 +31,9 @@
 <th>
 <p align="center">
 <a href="./docs/tree.md" name="introduction">intro</a>
-</p>
+</p>
+</th>
+<th>
 <p align="center">
 <a href="./deepts/models/tree.py" name="code">code</a>
 </p>
@@ -44,6 +49,8 @@
 <p align="center">
 <a href="./docs/rnn.md" name="introduction">intro</a>
 </p>
+</th>
+<th>
 <p align="center">
 <a href="./deepts/models/seq2seq.py" name="code">code</a>
 </p>
@@ -59,6 +66,8 @@
 <p align="center">
 <a href="./docs/wavenet.md" name="introduction">intro</a>
 </p>
+</th>
+<th>
 <p align="center">
 <a href="./deepts/models/wavenet.py" name="code">code</a>
 </p>
@@ -73,7 +82,9 @@
 <th>
 <p align="center">
 <a href="./docs/transformer.md" name="introduction">intro</a>
-</p>
+</p>
+</th>
+<th>
 <p align="center">
 <a href="./deepts/models/transformer.py" name="code">code</a>
 </p>
@@ -89,6 +100,8 @@
 <p align="center">
 <a href="./docs/unet.md" name="introduction">intro</a>
 </p>
+</th>
+<th>
 <p align="center">
 <a href="./deepts/models/unet.py" name="code">code</a>
 </p>
@@ -104,6 +117,8 @@
 <p align="center">
 <a href="./docs/nbeats.md" name="introduction">intro</a>
 </p>
+</th>
+<th>
 <p align="center">
 <a href="./deepts/models/nbeats.py" name="code">code</a>
 </p>
@@ -119,6 +134,8 @@
 <p align="center">
 <a href="./docs/gan.md" name="introduction">intro</a>
 </p>
+</th>
+<th>
 <p align="center">
 <a href="./deepts/models/gan.py" name="code">code</a>
 </p>
@@ -136,7 +153,7 @@ pip install -r requirements.txt
 ```bash
 bash ./data/download_passenger.sh
 ```
-3. Train the model, set `custom_model_params` if you want
+3. Train the model; set `custom_model_params` if you want, and pay attention to your own feature engineering
 ```bash
 cd examples
 python run_train.py --use_model seq2seq
@@ -147,7 +164,9 @@ python run_test.py
 ```

 ## Further reading
-https://github.com/awslabs/gluon-ts/
+- https://github.com/awslabs/gluon-ts/
+- https://github.com/Azure/DeepLearningForTimeSeriesForecasting

 ## Contributor
 - [LongxingTan](https://longxingtan.github.io/)
+

deepts/layers/__init__.py (+1)

@@ -0,0 +1 @@
+# coding=utf-8

deepts/layers/attention_layer.py (+3 -1)

@@ -215,7 +215,9 @@ def build(self,input_shape):
         super(PositionEncoding,self).build(input_shape)

     def get_config(self):
-        pass
+        return {
+            'max_len': self.max_len
+        }

     def call(self,x,masking=True):
         E = x.get_shape().as_list()[-1] # static

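For layers with constructor state like `max_len`, Keras expects `get_config` to also carry the base-layer config so a saved model can rebuild the layer on load. A minimal sketch of that pattern (the class and attribute mirror the layer above; the merge with `super().get_config()` is standard Keras API, not part of this commit):

```python
import tensorflow as tf

class PositionEncoding(tf.keras.layers.Layer):
    def __init__(self, max_len=5000, **kwargs):
        super(PositionEncoding, self).__init__(**kwargs)
        self.max_len = max_len

    def get_config(self):
        # Merge the base config (name, dtype, ...) so that
        # tf.keras.models.load_model can reconstruct the layer.
        config = super(PositionEncoding, self).get_config()
        config.update({'max_len': self.max_len})
        return config
```
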
deepts/layers/rnn_layer.py (+1 -1)

@@ -9,4 +9,4 @@
 class RNNLayer(Layer):
     def __init__(self):
-        super(RNNLayer,self).__init__()
+        super(RNNLayer,self).__init__()

deepts/model.py (+25 -9)

@@ -11,7 +11,7 @@
 from deepts.models.unet import Unet
 from deepts.models.nbeats import NBeatsNet
 from deepts.models.gan import GAN
-assert tf.__version__>"2.0.0"
+assert tf.__version__>"2.0.0", "Should you consider to use TensorFlow 2?"


 class Loss(object):
@@ -29,11 +29,11 @@ class Optimizer(object):
     def __init__(self,use_optimizer):
         self.use_optimizer=use_optimizer

-    def __call__(self,):
+    def __call__(self,learning_rate):
         if self.use_optimizer == 'adam':
-            return tf.keras.optimizers.Adam(lr=0.001)
+            return tf.keras.optimizers.Adam(lr=learning_rate)
         elif self.use_optimizer == 'sgd':
-            return tf.keras.optimizers.SGD(lr=0.001)
+            return tf.keras.optimizers.SGD(lr=learning_rate)


 class Model(object):
@@ -66,7 +66,7 @@ def __init__(self,params, use_model, use_loss='mse',use_optimizer='adam', custom
         self.use_loss = use_loss
         self.use_optimizer = use_optimizer
         self.loss_fn = Loss(use_loss)()
-        self.optimizer_fn = Optimizer(use_optimizer)()
+        self.optimizer_fn = Optimizer(use_optimizer)(learning_rate=params['learning_rate'])
         self.model = tf.keras.Model(inputs, outputs, name=use_model)

     def train(self, dataset, n_epochs, mode='eager', export_model=False):
@@ -105,25 +105,41 @@ def train_step(self, x, y):

     def eval(self, valid_dataset):
         for step,(x,y) in enumerate(valid_dataset.take(-1)):
-            metrics=self.test_step(x,y)
+            metrics=self.dev_step(x,y)
             print("=> STEP %4d Metrics: %4.2f"%(step, metrics))

-    def test_step(self, x, y):
+    def dev_step(self, x, y):
+        '''
+        Evaluation step function.
+        :param x: batch of inputs
+        :param y: batch of targets
+        :return: metrics value
+        '''
         x=tf.cast(x, tf.float32)
         y=tf.cast(y, tf.float32)
-        y_pred=self.model(x,training=False)
+        try:
+            y_pred=self.model(x,training=False)
+        except:
+            y_pred=self.model((x,tf.ones([tf.shape(x)[0],self.params['output_seq_length'],1],tf.float32)))
         metrics=self.loss_fn(y, y_pred).numpy()
         return metrics

     def predict(self, x_test, model_dir, use_model='pb'):
+        '''
+        Predict function; don't use self.model here, but the saved checkpoint or pb.
+        :param x_test: test inputs
+        :param model_dir: directory of the saved model
+        :param use_model: 'pb' to load a SavedModel, otherwise load checkpoint weights
+        :return: model predictions
+        '''
         if use_model=='pb':
             print('Load saved pb model ...')
             model=tf.saved_model.load(model_dir)
         else:
             print('Load checkpoint model ...')
             model=self.model.load_weights(model_dir)

-        y_pred=model(tf.constant(x_test),True,None) # To be clarified
+        y_pred=model(x_test,True,None) # To be clarified, not sure why additional args are necessary here
         return y_pred

     def export_model(self):

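The optimizer change above threads `params['learning_rate']` through instead of a hard-coded 0.001. A standalone sketch of the same selection logic (note that `learning_rate` is the canonical TF2 keyword; `lr`, as used in the diff, is a legacy alias):

```python
import tensorflow as tf

def make_optimizer(use_optimizer, learning_rate):
    # Same dispatch as the Optimizer class above, as a plain function.
    if use_optimizer == 'adam':
        return tf.keras.optimizers.Adam(learning_rate=learning_rate)
    elif use_optimizer == 'sgd':
        return tf.keras.optimizers.SGD(learning_rate=learning_rate)
    raise ValueError('Unsupported optimizer: {}'.format(use_optimizer))

optimizer = make_optimizer('adam', learning_rate=1e-3)
```
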
deepts/models/tft.py (+1 -1)

@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # @author: Longxing Tan, [email protected]
-# @date: 2020-05
+# @date: 2020-06
 # paper: https://arxiv.org/pdf/1912.09363v1.pdf
 # other implementations: https://github.com/google-research/google-research/blob/master/tft/libs/tft_model.py

deepts/models/transformer.py (+1)

@@ -3,6 +3,7 @@
 # @date: 2020-01
 # paper:
 # other implementations: https://github.com/maxjcohen/transformer
+# https://github.com/Trigram19/m5-python-starter


 import tensorflow as tf

deepts/models/wavenet.py (-1)

@@ -9,7 +9,6 @@
 import tensorflow as tf
 from tensorflow.keras.layers import Dense
 from deepts.layers.wavenet_layer import Dense3D, ConvTime
-#tf.config.experimental_run_functions_eagerly(True) # ??


 params={

docs/arima.md (+1 -1)

@@ -2,7 +2,7 @@

 ## Introduction
-ARIMA is a short for "Autoregressive Integrated Moving Average model", it's a traditional time-series-prediction model. Basically it's a linear model combined the auto regression model and moving average model.
+ARIMA is short for "Auto-regressive Integrated Moving Average model"; it's a traditional time-series-prediction model. Basically, it's a linear model combining an auto-regression model and a moving-average model.
 - The auto-regression model is a linear regression using the past values of the series as its features. The important hyper-parameter is how many past steps are used.
 - The moving-average model is a linear regression using the past residual errors as its features. The important hyper-parameter is again the history length.
 - arima

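To make the AR and MA pieces concrete, here is a hedged sketch of fitting an ARIMA(p, d, q) model with statsmodels (the library and the toy series are assumptions for illustration, not part of this repo):

```python
import pandas as pd
from statsmodels.tsa.arima.model import ARIMA

# Any evenly spaced univariate series works here.
series = pd.Series([112.0, 118, 132, 129, 121, 135, 148, 148, 136, 119])

# order=(p, d, q): p past values (AR), d differencing steps, q past errors (MA).
fitted = ARIMA(series, order=(2, 1, 1)).fit()
print(fitted.forecast(steps=3))  # predict the next 3 points
```
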
docs/feature.md (+17)

@@ -0,0 +1,17 @@
+# Feature in Time series
+
+## Introduction
+I learned a lot about feature engineering for time series from data competitions, especially Kaggle; the competitors are remarkably good and creative at it.
+
+
+## Auto-regression
+
+## Statistical feature
+
+## Categorical feature
+
+## Embedding feature
+- https://zhuanlan.zhihu.com/p/144030067
+- https://zhuanlan.zhihu.com/p/142681935
+
+

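As a hedged preview of the auto-regression and statistical features these sections will cover (the DataFrame and column names are invented for illustration):

```python
import pandas as pd

df = pd.DataFrame({'y': range(100)})  # hypothetical target column

# Auto-regressive features: raw past values of the target.
for lag in (1, 7, 28):
    df['lag_{}'.format(lag)] = df['y'].shift(lag)

# Statistical features: rolling aggregates; shift(1) first so the
# window only sees the past and no target leakage occurs.
df['roll_mean_7'] = df['y'].shift(1).rolling(window=7).mean()
df['roll_std_7'] = df['y'].shift(1).rolling(window=7).std()
```
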
docs/gan.md (+1 -1)

@@ -1,7 +1,7 @@
 # GAN

 ## Introduction
-
+Generative Adversarial Network (GAN)

 ## Performance


+2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
# RNN
22

33
## Introduction
4+
RNN and its modification LSTM, GRU are good at modelling the sequence data. So it's natural to use them in time series application
45

56

67
## Performance
78

9+
810
## Further reading

docs/smoothing.md (+13)

@@ -0,0 +1,13 @@
+# Smoothing in Time series
+
+## Introduction
+Smoothing can help reduce some of the noise in a raw series.
+
+## Moving average
+
+## Exponential smoothing
+
+## Filter
+
+
+## Auto-encoder

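A hedged pandas sketch of the two simplest smoothers named above (the window size and alpha are illustrative):

```python
import pandas as pd

series = pd.Series([3.0, 4.5, 4.0, 5.5, 6.0, 5.0, 7.0])  # toy data

moving_avg = series.rolling(window=3).mean()  # simple moving average
exp_smooth = series.ewm(alpha=0.3).mean()     # exponential smoothing
print(moving_avg.tail(3))
print(exp_smooth.tail(3))
```
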
docs/transformer.md (+2 -1)

@@ -1,9 +1,10 @@
 # Transformer

 ## Introduction
-Transformer is introduced in [Attention is all your need]()
+The Transformer is introduced in [Attention is all you need](https://arxiv.org/abs/1706.03762) and has become the most popular NLP model.

 ## Performance

+
 ## Further reading

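The core of the Transformer is scaled dot-product attention; from the paper:

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
```
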
docs/tree.md (+8 -1)

@@ -1,11 +1,18 @@
 # GBDT

 ## Introduction of GBDT
+Gradient boosting decision trees are also a promising solution for time series problems.

 ## Introduction of XGBoost
+XGBoost is introduced in [XGBoost: A Scalable Tree Boosting System](https://arxiv.org/abs/1603.02754)

 ## Introduction of LightGBM
+LightGBM is introduced in [LightGBM: A Highly Efficient Gradient Boosting Decision Tree](http://papers.nips.cc/paper/6907-lightgbm-a-highly-efficient-gradient-boosting-decision-tree.pdf)

 ## Performance
+GBDT can also be tuned into a SOTA model for time series. Reading competition implementations convinced me that I'm not that good at tuning the parameters, so I believe the performance here could be further optimized.

-## Further reading
+## Further reading
+- https://www.kaggle.com/pureheart/1st-place-lgb-model-public-0-470-private-0-502
+- https://www.kaggle.com/plantsgo/solution-public-0-471-private-0-505

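A hedged sketch of the usual GBDT setup for forecasting, pairing lag features with LightGBM (the data and hyper-parameters are invented for illustration):

```python
import pandas as pd
import lightgbm as lgb

df = pd.DataFrame({'y': [float(i % 7) for i in range(200)]})  # toy series
for lag in (1, 7, 14):
    df['lag_{}'.format(lag)] = df['y'].shift(lag)
df = df.dropna()

# Split by time: hold out the last 30 points for validation.
train, valid = df.iloc[:-30], df.iloc[-30:]
features = [c for c in df.columns if c != 'y']

model = lgb.LGBMRegressor(n_estimators=200, learning_rate=0.05)
model.fit(train[features], train['y'])
print(model.predict(valid[features])[:5])
```
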
docs/unet.md (+4 -1)

@@ -2,7 +2,10 @@

 ## Introduction
 U-Net is introduced in [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/pdf/1505.04597)
-used for image segmentation.
+and is arguably the most popular model for image segmentation.
+
+We can tune the model a little to use it for time series prediction.
+

 ## Performance

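One common adaptation, sketched here under the assumption that the 2D convolutions are simply swapped for their 1D counterparts (layer sizes are illustrative):

```python
import tensorflow as tf
from tensorflow.keras import layers

def tiny_unet_1d(seq_length=64, n_features=1):
    inputs = tf.keras.Input(shape=(seq_length, n_features))
    # Encoder: convolve, then downsample along the time axis.
    c1 = layers.Conv1D(16, 3, padding='same', activation='relu')(inputs)
    p1 = layers.MaxPooling1D(2)(c1)
    c2 = layers.Conv1D(32, 3, padding='same', activation='relu')(p1)
    # Decoder: upsample and concatenate the skip connection.
    u1 = layers.UpSampling1D(2)(c2)
    merged = layers.Concatenate()([u1, c1])
    outputs = layers.Conv1D(1, 3, padding='same')(merged)
    return tf.keras.Model(inputs, outputs)

model = tiny_unet_1d()
```
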
docs/validation.md (+15)

@@ -0,0 +1,15 @@
+# Validation in Time series
+
+Validation is difficult in time series.
+
+Basically, the data can be split by time, or by examples.
+
+
+## Examples in Kaggle
+
+There are some popular time-series-prediction competitions on Kaggle and in the M-competition series.
+
+
+
+
+

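For the split-by-time option, scikit-learn's `TimeSeriesSplit` yields expanding-window folds where the validation block always comes after the training prefix; a hedged sketch:

```python
import numpy as np
from sklearn.model_selection import TimeSeriesSplit

X = np.arange(12).reshape(-1, 1)  # toy ordered data

for fold, (train_idx, valid_idx) in enumerate(TimeSeriesSplit(n_splits=3).split(X)):
    # Each fold trains on an expanding prefix, validates on the next block.
    print(fold, train_idx, valid_idx)
```
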
docs/wavenet.md (+5 -5)

@@ -1,21 +1,21 @@
 # Wavenet

 ## Introduction
-Wavenet is introduced in [WaveNet: A Generative Model for Raw Audio](https://arxiv.org/abs/1609.03499) by DeepMind, first used for audio generation. The main components use the causal dilated convolutional neutral network. The kernel of CNN layer share the same weights, so it can also be used to percept the seasonality of time series issue.
+Wavenet is introduced in [WaveNet: A Generative Model for Raw Audio](https://arxiv.org/abs/1609.03499) by DeepMind, first used for audio generation. Its main component is the causal dilated convolutional neural network. The kernels of a CNN layer share the same weights, so the model can also pick up the seasonality of a time series.

 The dilated causal convolutional layer
 ![wavenet](https://github.com/LongxingTan/Time-series-prediction/blob/master/docs/assets/wavenet.gif)

+It became popular for time series applications after sjvasquez open-sourced [web-traffic-forecasting](https://github.com/sjvasquez/web-traffic-forecasting)
+
 ## Some detail
-### casual dilated convolutional neutral network
+#### Causal dilated convolutional neural network
 Causal: makes sure that future information won't leak

-Normal convolution
-
 Dilated: extends the receptive field to track long-term dependencies
 Implementation:

-
+Normal convolution:


 ## Performance
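In Keras, the causal and dilated parts are each one argument away; a hedged sketch of a small WaveNet-style stack (filter counts and dilation rates are illustrative):

```python
import tensorflow as tf
from tensorflow.keras import layers

inputs = tf.keras.Input(shape=(64, 1))  # (time steps, features)
x = inputs
for rate in (1, 2, 4, 8):  # doubling dilations grow the receptive field fast
    # padding='causal' left-pads, so output at step t only sees inputs <= t.
    x = layers.Conv1D(16, kernel_size=2, dilation_rate=rate,
                      padding='causal', activation='relu')(x)
outputs = layers.Dense(1)(x)
model = tf.keras.Model(inputs, outputs)
```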