diff --git a/examples/Tutorial_1_Quick_start.ipynb b/examples/Tutorial_1_Quick_start.ipynb
index 751fc53..4fd0c82 100644
--- a/examples/Tutorial_1_Quick_start.ipynb
+++ b/examples/Tutorial_1_Quick_start.ipynb
@@ -17,15 +17,16 @@
" - [Strategy](#toc2_1_2_6_) \n",
" - [Backtest validation of pipeline](#toc2_2_) \n",
" - [Sliding Window Validation](#toc2_3_) \n",
- " - [Working with raw time series' granularity](#toc2_4_) \n",
+ " - [Working with raw time series' granularity](#toc2_4_)\n",
+ " - [Subsampling from Dataset](#toc2_5_) \n",
"\n",
"\n",
+ " numbering=false\n",
+ " anchor=true\n",
+ " flat=false\n",
+ " minLevel=1\n",
+ " maxLevel=6\n",
+ " /vscode-jupyter-toc-config -->\n",
""
]
},
@@ -52,7 +53,7 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -74,7 +75,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -153,7 +154,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -177,7 +178,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -273,7 +274,7 @@
},
{
"cell_type": "code",
- "execution_count": 39,
+ "execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@@ -290,7 +291,7 @@
},
{
"cell_type": "code",
- "execution_count": 40,
+ "execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -306,7 +307,7 @@
},
{
"cell_type": "code",
- "execution_count": 41,
+ "execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@@ -352,7 +353,7 @@
},
{
"cell_type": "code",
- "execution_count": 62,
+ "execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -377,7 +378,7 @@
},
{
"cell_type": "code",
- "execution_count": 63,
+ "execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -424,7 +425,7 @@
},
{
"cell_type": "code",
- "execution_count": 64,
+ "execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@@ -483,7 +484,7 @@
},
{
"cell_type": "code",
- "execution_count": 65,
+ "execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
@@ -493,7 +494,7 @@
},
{
"cell_type": "code",
- "execution_count": 66,
+ "execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
@@ -502,24 +503,24 @@
},
{
"cell_type": "code",
- "execution_count": 67,
+ "execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "0:\tlearn: 0.9606080\ttest: 0.9667407\tbest: 0.9667407 (0)\ttotal: 2.64ms\tremaining: 2.64s\n",
- "500:\tlearn: 0.0051947\ttest: 0.0053699\tbest: 0.0053699 (500)\ttotal: 458ms\tremaining: 456ms\n",
- "999:\tlearn: 0.0031608\ttest: 0.0033676\tbest: 0.0033676 (999)\ttotal: 973ms\tremaining: 0us\n",
+ "0:\tlearn: 0.9606080\ttest: 0.9667407\tbest: 0.9667407 (0)\ttotal: 58.2ms\tremaining: 58.1s\n",
+ "500:\tlearn: 0.0051947\ttest: 0.0053699\tbest: 0.0053699 (500)\ttotal: 544ms\tremaining: 542ms\n",
+ "999:\tlearn: 0.0031608\ttest: 0.0033676\tbest: 0.0033676 (999)\ttotal: 988ms\tremaining: 0us\n",
"\n",
"bestTest = 0.003367620955\n",
"bestIteration = 999\n",
"\n",
"Fold 0. Score: 0.0033676209549416128\n",
- "0:\tlearn: 0.9659554\ttest: 0.9614093\tbest: 0.9614093 (0)\ttotal: 1.56ms\tremaining: 1.56s\n",
- "500:\tlearn: 0.0052698\ttest: 0.0054766\tbest: 0.0054766 (500)\ttotal: 445ms\tremaining: 443ms\n",
- "999:\tlearn: 0.0031515\ttest: 0.0033391\tbest: 0.0033391 (999)\ttotal: 912ms\tremaining: 0us\n",
+ "0:\tlearn: 0.9659554\ttest: 0.9614093\tbest: 0.9614093 (0)\ttotal: 1.03ms\tremaining: 1.03s\n",
+ "500:\tlearn: 0.0052698\ttest: 0.0054766\tbest: 0.0054766 (500)\ttotal: 554ms\tremaining: 552ms\n",
+ "999:\tlearn: 0.0031515\ttest: 0.0033391\tbest: 0.0033391 (999)\ttotal: 1.03s\tremaining: 0us\n",
"\n",
"bestTest = 0.003339095317\n",
"bestIteration = 999\n",
@@ -536,7 +537,7 @@
},
{
"cell_type": "code",
- "execution_count": 68,
+ "execution_count": 17,
"metadata": {},
"outputs": [
{
@@ -553,7 +554,7 @@
},
{
"cell_type": "code",
- "execution_count": 69,
+ "execution_count": 18,
"metadata": {},
"outputs": [
{
@@ -801,7 +802,7 @@
"29 9 2022-09-29 10981.801303"
]
},
- "execution_count": 69,
+ "execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
@@ -826,7 +827,7 @@
},
{
"cell_type": "code",
- "execution_count": 70,
+ "execution_count": 19,
"metadata": {},
"outputs": [
{
@@ -834,17 +835,17 @@
"output_type": "stream",
"text": [
"freq: Day; period: 1\n",
- "0:\tlearn: 0.9618043\ttest: 0.9656878\tbest: 0.9656878 (0)\ttotal: 2.18ms\tremaining: 2.18s\n",
- "500:\tlearn: 0.0051787\ttest: 0.0052308\tbest: 0.0052308 (500)\ttotal: 455ms\tremaining: 453ms\n",
- "999:\tlearn: 0.0030887\ttest: 0.0032962\tbest: 0.0032962 (999)\ttotal: 956ms\tremaining: 0us\n",
+ "0:\tlearn: 0.9618043\ttest: 0.9656878\tbest: 0.9656878 (0)\ttotal: 870us\tremaining: 870ms\n",
+ "500:\tlearn: 0.0051787\ttest: 0.0052308\tbest: 0.0052308 (500)\ttotal: 532ms\tremaining: 530ms\n",
+ "999:\tlearn: 0.0030887\ttest: 0.0032962\tbest: 0.0032962 (999)\ttotal: 985ms\tremaining: 0us\n",
"\n",
"bestTest = 0.00329621851\n",
"bestIteration = 999\n",
"\n",
"Fold 0. Score: 0.0032962185095815415\n",
- "0:\tlearn: 0.9647141\ttest: 0.9623340\tbest: 0.9623340 (0)\ttotal: 1.5ms\tremaining: 1.5s\n",
- "500:\tlearn: 0.0055104\ttest: 0.0057187\tbest: 0.0057187 (500)\ttotal: 440ms\tremaining: 438ms\n",
- "999:\tlearn: 0.0033140\ttest: 0.0035696\tbest: 0.0035696 (999)\ttotal: 866ms\tremaining: 0us\n",
+ "0:\tlearn: 0.9647141\ttest: 0.9623340\tbest: 0.9623340 (0)\ttotal: 1.11ms\tremaining: 1.11s\n",
+ "500:\tlearn: 0.0055104\ttest: 0.0057187\tbest: 0.0057187 (500)\ttotal: 427ms\tremaining: 425ms\n",
+ "999:\tlearn: 0.0033140\ttest: 0.0035696\tbest: 0.0035696 (999)\ttotal: 861ms\tremaining: 0us\n",
"\n",
"bestTest = 0.003569589613\n",
"bestIteration = 999\n",
@@ -854,17 +855,17 @@
"Std: 0.0001\n",
"freq: Day; period: 1\n",
"freq: Day; period: 1\n",
- "0:\tlearn: 0.9635630\ttest: 0.9636501\tbest: 0.9636501 (0)\ttotal: 959us\tremaining: 958ms\n",
- "500:\tlearn: 0.0051732\ttest: 0.0053000\tbest: 0.0053000 (500)\ttotal: 446ms\tremaining: 444ms\n",
- "999:\tlearn: 0.0031204\ttest: 0.0032928\tbest: 0.0032928 (999)\ttotal: 889ms\tremaining: 0us\n",
+ "0:\tlearn: 0.9635630\ttest: 0.9636501\tbest: 0.9636501 (0)\ttotal: 1.02ms\tremaining: 1.01s\n",
+ "500:\tlearn: 0.0051732\ttest: 0.0053000\tbest: 0.0053000 (500)\ttotal: 449ms\tremaining: 447ms\n",
+ "999:\tlearn: 0.0031204\ttest: 0.0032928\tbest: 0.0032928 (999)\ttotal: 878ms\tremaining: 0us\n",
"\n",
"bestTest = 0.003292811776\n",
"bestIteration = 999\n",
"\n",
"Fold 0. Score: 0.0032928117763904234\n",
- "0:\tlearn: 0.9632624\ttest: 0.9635496\tbest: 0.9635496 (0)\ttotal: 2.48ms\tremaining: 2.48s\n",
- "500:\tlearn: 0.0053458\ttest: 0.0056056\tbest: 0.0056056 (500)\ttotal: 545ms\tremaining: 543ms\n",
- "999:\tlearn: 0.0032541\ttest: 0.0035628\tbest: 0.0035628 (999)\ttotal: 979ms\tremaining: 0us\n",
+ "0:\tlearn: 0.9632624\ttest: 0.9635496\tbest: 0.9635496 (0)\ttotal: 1.02ms\tremaining: 1.02s\n",
+ "500:\tlearn: 0.0053458\ttest: 0.0056056\tbest: 0.0056056 (500)\ttotal: 432ms\tremaining: 431ms\n",
+ "999:\tlearn: 0.0032541\ttest: 0.0035628\tbest: 0.0035628 (999)\ttotal: 866ms\tremaining: 0us\n",
"\n",
"bestTest = 0.003562794598\n",
"bestIteration = 999\n",
@@ -874,17 +875,17 @@
"Std: 0.0001\n",
"freq: Day; period: 1\n",
"freq: Day; period: 1\n",
- "0:\tlearn: 0.9672679\ttest: 0.9599529\tbest: 0.9599529 (0)\ttotal: 1.4ms\tremaining: 1.4s\n",
- "500:\tlearn: 0.0052718\ttest: 0.0054426\tbest: 0.0054426 (500)\ttotal: 447ms\tremaining: 445ms\n",
- "999:\tlearn: 0.0030843\ttest: 0.0033111\tbest: 0.0033111 (999)\ttotal: 860ms\tremaining: 0us\n",
+ "0:\tlearn: 0.9672679\ttest: 0.9599529\tbest: 0.9599529 (0)\ttotal: 903us\tremaining: 902ms\n",
+ "500:\tlearn: 0.0052718\ttest: 0.0054426\tbest: 0.0054426 (500)\ttotal: 421ms\tremaining: 420ms\n",
+ "999:\tlearn: 0.0030843\ttest: 0.0033111\tbest: 0.0033111 (999)\ttotal: 865ms\tremaining: 0us\n",
"\n",
"bestTest = 0.003311056772\n",
"bestIteration = 999\n",
"\n",
"Fold 0. Score: 0.003311056771697718\n",
- "0:\tlearn: 0.9591189\ttest: 0.9679140\tbest: 0.9679140 (0)\ttotal: 1.44ms\tremaining: 1.44s\n",
- "500:\tlearn: 0.0053303\ttest: 0.0056656\tbest: 0.0056656 (500)\ttotal: 429ms\tremaining: 427ms\n",
- "999:\tlearn: 0.0031389\ttest: 0.0034467\tbest: 0.0034467 (999)\ttotal: 982ms\tremaining: 0us\n",
+ "0:\tlearn: 0.9591189\ttest: 0.9679140\tbest: 0.9679140 (0)\ttotal: 1.16ms\tremaining: 1.16s\n",
+ "500:\tlearn: 0.0053303\ttest: 0.0056656\tbest: 0.0056656 (500)\ttotal: 431ms\tremaining: 429ms\n",
+ "999:\tlearn: 0.0031389\ttest: 0.0034467\tbest: 0.0034467 (999)\ttotal: 854ms\tremaining: 0us\n",
"\n",
"bestTest = 0.003446668742\n",
"bestIteration = 999\n",
@@ -902,7 +903,7 @@
},
{
"cell_type": "code",
- "execution_count": 71,
+ "execution_count": 20,
"metadata": {},
"outputs": [
{
@@ -1368,7 +1369,7 @@
"29 10984.582094 9 "
]
},
- "execution_count": 71,
+ "execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
@@ -1395,7 +1396,7 @@
},
{
"cell_type": "code",
- "execution_count": 72,
+ "execution_count": 21,
"metadata": {},
"outputs": [
{
@@ -1433,24 +1434,24 @@
},
{
"cell_type": "code",
- "execution_count": 73,
+ "execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "0:\tlearn: 0.9606080\ttest: 0.9667407\tbest: 0.9667407 (0)\ttotal: 1.78ms\tremaining: 1.78s\n",
- "500:\tlearn: 0.0051947\ttest: 0.0053699\tbest: 0.0053699 (500)\ttotal: 511ms\tremaining: 509ms\n",
- "999:\tlearn: 0.0031608\ttest: 0.0033676\tbest: 0.0033676 (999)\ttotal: 1.03s\tremaining: 0us\n",
+ "0:\tlearn: 0.9606080\ttest: 0.9667407\tbest: 0.9667407 (0)\ttotal: 981us\tremaining: 980ms\n",
+ "500:\tlearn: 0.0051947\ttest: 0.0053699\tbest: 0.0053699 (500)\ttotal: 421ms\tremaining: 419ms\n",
+ "999:\tlearn: 0.0031608\ttest: 0.0033676\tbest: 0.0033676 (999)\ttotal: 860ms\tremaining: 0us\n",
"\n",
"bestTest = 0.003367620955\n",
"bestIteration = 999\n",
"\n",
"Fold 0. Score: 0.0033676209549416128\n",
- "0:\tlearn: 0.9659554\ttest: 0.9614093\tbest: 0.9614093 (0)\ttotal: 1.47ms\tremaining: 1.47s\n",
- "500:\tlearn: 0.0052698\ttest: 0.0054766\tbest: 0.0054766 (500)\ttotal: 431ms\tremaining: 429ms\n",
- "999:\tlearn: 0.0031515\ttest: 0.0033391\tbest: 0.0033391 (999)\ttotal: 857ms\tremaining: 0us\n",
+ "0:\tlearn: 0.9659554\ttest: 0.9614093\tbest: 0.9614093 (0)\ttotal: 2.27ms\tremaining: 2.27s\n",
+ "500:\tlearn: 0.0052698\ttest: 0.0054766\tbest: 0.0054766 (500)\ttotal: 453ms\tremaining: 451ms\n",
+ "999:\tlearn: 0.0031515\ttest: 0.0033391\tbest: 0.0033391 (999)\ttotal: 912ms\tremaining: 0us\n",
"\n",
"bestTest = 0.003339095317\n",
"bestIteration = 999\n",
@@ -1467,7 +1468,7 @@
},
{
"cell_type": "code",
- "execution_count": 74,
+ "execution_count": 23,
"metadata": {},
"outputs": [
{
@@ -1496,7 +1497,7 @@
},
{
"cell_type": "code",
- "execution_count": 75,
+ "execution_count": 24,
"metadata": {},
"outputs": [
{
@@ -1546,205 +1547,75 @@
" \n",
"
\n",
" | 3 | \n",
- " 1 | \n",
- " 2020-01-08 | \n",
- " 2008.084646 | \n",
- "
\n",
- " \n",
- " | 4 | \n",
- " 1 | \n",
- " 2020-01-09 | \n",
- " 2008.425281 | \n",
- "
\n",
- " \n",
- " | 5 | \n",
- " 1 | \n",
- " 2020-01-10 | \n",
- " 2009.385426 | \n",
- "
\n",
- " \n",
- " | 6 | \n",
- " 2 | \n",
- " 2020-01-08 | \n",
- " 3007.931338 | \n",
- "
\n",
- " \n",
- " | 7 | \n",
- " 2 | \n",
- " 2020-01-09 | \n",
- " 3008.294974 | \n",
- "
\n",
- " \n",
- " | 8 | \n",
- " 2 | \n",
- " 2020-01-10 | \n",
- " 3009.266806 | \n",
- "
\n",
- " \n",
- " | 9 | \n",
- " 3 | \n",
- " 2020-01-08 | \n",
- " 4007.738118 | \n",
- "
\n",
- " \n",
- " | 10 | \n",
- " 3 | \n",
- " 2020-01-09 | \n",
- " 4008.102793 | \n",
- "
\n",
- " \n",
- " | 11 | \n",
- " 3 | \n",
- " 2020-01-10 | \n",
- " 4009.074625 | \n",
- "
\n",
- " \n",
- " | 12 | \n",
- " 4 | \n",
- " 2020-01-08 | \n",
- " 5007.752016 | \n",
- "
\n",
- " \n",
- " | 13 | \n",
- " 4 | \n",
- " 2020-01-09 | \n",
- " 5008.118656 | \n",
- "
\n",
- " \n",
- " | 14 | \n",
- " 4 | \n",
- " 2020-01-10 | \n",
- " 5009.092952 | \n",
- "
\n",
- " \n",
- " | 15 | \n",
- " 5 | \n",
- " 2020-01-08 | \n",
- " 6007.75671 | \n",
- "
\n",
- " \n",
- " | 16 | \n",
- " 5 | \n",
- " 2020-01-09 | \n",
- " 6008.12335 | \n",
- "
\n",
- " \n",
- " | 17 | \n",
- " 5 | \n",
- " 2020-01-10 | \n",
- " 6009.096733 | \n",
- "
\n",
- " \n",
- " | 18 | \n",
- " 6 | \n",
- " 2020-01-08 | \n",
- " 7007.789288 | \n",
- "
\n",
- " \n",
- " | 19 | \n",
- " 6 | \n",
- " 2020-01-09 | \n",
- " 7008.155928 | \n",
- "
\n",
- " \n",
- " | 20 | \n",
- " 6 | \n",
- " 2020-01-10 | \n",
- " 7009.129311 | \n",
- "
\n",
- " \n",
- " | 21 | \n",
- " 7 | \n",
- " 2020-01-08 | \n",
- " 8007.861426 | \n",
- "
\n",
- " \n",
- " | 22 | \n",
- " 7 | \n",
+ " 0 | \n",
" 2020-01-09 | \n",
- " 8008.228066 | \n",
+ " 1008.769569 | \n",
"
\n",
" \n",
- " | 23 | \n",
- " 7 | \n",
+ " 4 | \n",
+ " 0 | \n",
" 2020-01-10 | \n",
- " 8009.201449 | \n",
+ " 1009.71512 | \n",
"
\n",
" \n",
- " | 24 | \n",
- " 8 | \n",
- " 2020-01-08 | \n",
- " 9007.995032 | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
"
\n",
" \n",
- " | 25 | \n",
- " 8 | \n",
- " 2020-01-09 | \n",
- " 9008.359142 | \n",
+ " 29725 | \n",
+ " 9 | \n",
+ " 2022-09-24 | \n",
+ " 10996.789104 | \n",
"
\n",
" \n",
- " | 26 | \n",
- " 8 | \n",
- " 2020-01-10 | \n",
- " 9009.332526 | \n",
+ " 29726 | \n",
+ " 9 | \n",
+ " 2022-09-25 | \n",
+ " 10997.211277 | \n",
"
\n",
" \n",
- " | 27 | \n",
+ " 29727 | \n",
" 9 | \n",
- " 2020-01-08 | \n",
- " 10008.069487 | \n",
+ " 2022-09-24 | \n",
+ " 10996.789104 | \n",
"
\n",
" \n",
- " | 28 | \n",
+ " 29728 | \n",
" 9 | \n",
- " 2020-01-09 | \n",
- " 10008.436654 | \n",
+ " 2022-09-25 | \n",
+ " 10997.211277 | \n",
"
\n",
" \n",
- " | 29 | \n",
+ " 29729 | \n",
" 9 | \n",
- " 2020-01-10 | \n",
- " 10009.410037 | \n",
+ " 2022-09-26 | \n",
+ " 10997.953405 | \n",
"
\n",
" \n",
"\n",
+ "29730 rows × 3 columns
\n",
""
],
"text/plain": [
- " id date value\n",
- "0 0 2020-01-08 1008.4361\n",
- "1 0 2020-01-09 1008.769569\n",
- "2 0 2020-01-10 1009.748942\n",
- "3 1 2020-01-08 2008.084646\n",
- "4 1 2020-01-09 2008.425281\n",
- "5 1 2020-01-10 2009.385426\n",
- "6 2 2020-01-08 3007.931338\n",
- "7 2 2020-01-09 3008.294974\n",
- "8 2 2020-01-10 3009.266806\n",
- "9 3 2020-01-08 4007.738118\n",
- "10 3 2020-01-09 4008.102793\n",
- "11 3 2020-01-10 4009.074625\n",
- "12 4 2020-01-08 5007.752016\n",
- "13 4 2020-01-09 5008.118656\n",
- "14 4 2020-01-10 5009.092952\n",
- "15 5 2020-01-08 6007.75671\n",
- "16 5 2020-01-09 6008.12335\n",
- "17 5 2020-01-10 6009.096733\n",
- "18 6 2020-01-08 7007.789288\n",
- "19 6 2020-01-09 7008.155928\n",
- "20 6 2020-01-10 7009.129311\n",
- "21 7 2020-01-08 8007.861426\n",
- "22 7 2020-01-09 8008.228066\n",
- "23 7 2020-01-10 8009.201449\n",
- "24 8 2020-01-08 9007.995032\n",
- "25 8 2020-01-09 9008.359142\n",
- "26 8 2020-01-10 9009.332526\n",
- "27 9 2020-01-08 10008.069487\n",
- "28 9 2020-01-09 10008.436654\n",
- "29 9 2020-01-10 10009.410037"
+ " id date value\n",
+ "0 0 2020-01-08 1008.4361\n",
+ "1 0 2020-01-09 1008.769569\n",
+ "2 0 2020-01-10 1009.748942\n",
+ "3 0 2020-01-09 1008.769569\n",
+ "4 0 2020-01-10 1009.71512\n",
+ "... .. ... ...\n",
+ "29725 9 2022-09-24 10996.789104\n",
+ "29726 9 2022-09-25 10997.211277\n",
+ "29727 9 2022-09-24 10996.789104\n",
+ "29728 9 2022-09-25 10997.211277\n",
+ "29729 9 2022-09-26 10997.953405\n",
+ "\n",
+ "[29730 rows x 3 columns]"
]
},
- "execution_count": 75,
+ "execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
@@ -1799,7 +1670,7 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
@@ -1824,14 +1695,20 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "freq: Month; period: 1.0\n",
+ "freq: Month; period: 1.0\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
"\n",
" It seems that the data is not regular. Please, check the data and the frequency info. \n",
" For multivariate regime it is critical to have regular data.\n",
@@ -1857,7 +1734,7 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 27,
"metadata": {},
"outputs": [
{
@@ -1883,13 +1760,345 @@
"source": [
"Now it's all detected correctly."
]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "---"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## [Subsampling from Dataset](#toc0_)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "strategy = RecursiveStrategy(horizon, history, trainer, pipeline)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "0:\tlearn: 1.0234961\ttest: 0.8247936\tbest: 0.8247936 (0)\ttotal: 812us\tremaining: 812ms\n",
+ "500:\tlearn: 0.0222610\ttest: 0.5437334\tbest: 0.5437334 (500)\ttotal: 26.3ms\tremaining: 26.2ms\n",
+ "999:\tlearn: 0.0005205\ttest: 0.5384852\tbest: 0.5384852 (999)\ttotal: 62.1ms\tremaining: 0us\n",
+ "\n",
+ "bestTest = 0.5384851551\n",
+ "bestIteration = 999\n",
+ "\n",
+ "Fold 0. Score: 0.5384851550735784\n",
+ "0:\tlearn: 0.8131044\ttest: 1.0295534\tbest: 1.0295534 (0)\ttotal: 96us\tremaining: 96.2ms\n",
+ "500:\tlearn: 0.0180513\ttest: 0.6959095\tbest: 0.6959095 (500)\ttotal: 40.9ms\tremaining: 40.8ms\n",
+ "999:\tlearn: 0.0004345\ttest: 0.6932387\tbest: 0.6932387 (999)\ttotal: 78.1ms\tremaining: 0us\n",
+ "\n",
+ "bestTest = 0.6932386702\n",
+ "bestIteration = 999\n",
+ "\n",
+ "Fold 1. Score: 0.6932386702268397\n",
+ "Mean score: 0.6159\n",
+ "Std: 0.0774\n"
+ ]
+ }
+ ],
+ "source": [
+ "fit_time, _ = strategy.fit(dataset, subsampling_rate=0.001, subsampling_seed=42)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "freq: Day; period: 1\n"
+ ]
+ }
+ ],
+ "source": [
+ "forecast_time, current_pred = strategy.predict(dataset)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " id | \n",
+ " date | \n",
+ " value | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 0 | \n",
+ " 2022-09-27 | \n",
+ " 1639.249496 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 0 | \n",
+ " 2022-09-28 | \n",
+ " 1620.282791 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 0 | \n",
+ " 2022-09-29 | \n",
+ " 1555.419756 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 1 | \n",
+ " 2022-09-27 | \n",
+ " 2639.249496 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 1 | \n",
+ " 2022-09-28 | \n",
+ " 2620.282791 | \n",
+ "
\n",
+ " \n",
+ " | 5 | \n",
+ " 1 | \n",
+ " 2022-09-29 | \n",
+ " 2555.419756 | \n",
+ "
\n",
+ " \n",
+ " | 6 | \n",
+ " 2 | \n",
+ " 2022-09-27 | \n",
+ " 3639.836497 | \n",
+ "
\n",
+ " \n",
+ " | 7 | \n",
+ " 2 | \n",
+ " 2022-09-28 | \n",
+ " 3620.990255 | \n",
+ "
\n",
+ " \n",
+ " | 8 | \n",
+ " 2 | \n",
+ " 2022-09-29 | \n",
+ " 3573.857553 | \n",
+ "
\n",
+ " \n",
+ " | 9 | \n",
+ " 3 | \n",
+ " 2022-09-27 | \n",
+ " 4635.954675 | \n",
+ "
\n",
+ " \n",
+ " | 10 | \n",
+ " 3 | \n",
+ " 2022-09-28 | \n",
+ " 4614.542752 | \n",
+ "
\n",
+ " \n",
+ " | 11 | \n",
+ " 3 | \n",
+ " 2022-09-29 | \n",
+ " 4546.682434 | \n",
+ "
\n",
+ " \n",
+ " | 12 | \n",
+ " 4 | \n",
+ " 2022-09-27 | \n",
+ " 5642.56791 | \n",
+ "
\n",
+ " \n",
+ " | 13 | \n",
+ " 4 | \n",
+ " 2022-09-28 | \n",
+ " 5622.667564 | \n",
+ "
\n",
+ " \n",
+ " | 14 | \n",
+ " 4 | \n",
+ " 2022-09-29 | \n",
+ " 5573.142299 | \n",
+ "
\n",
+ " \n",
+ " | 15 | \n",
+ " 5 | \n",
+ " 2022-09-27 | \n",
+ " 6642.56791 | \n",
+ "
\n",
+ " \n",
+ " | 16 | \n",
+ " 5 | \n",
+ " 2022-09-28 | \n",
+ " 6622.667564 | \n",
+ "
\n",
+ " \n",
+ " | 17 | \n",
+ " 5 | \n",
+ " 2022-09-29 | \n",
+ " 6573.142299 | \n",
+ "
\n",
+ " \n",
+ " | 18 | \n",
+ " 6 | \n",
+ " 2022-09-27 | \n",
+ " 7638.729381 | \n",
+ "
\n",
+ " \n",
+ " | 19 | \n",
+ " 6 | \n",
+ " 2022-09-28 | \n",
+ " 7619.679409 | \n",
+ "
\n",
+ " \n",
+ " | 20 | \n",
+ " 6 | \n",
+ " 2022-09-29 | \n",
+ " 7553.241637 | \n",
+ "
\n",
+ " \n",
+ " | 21 | \n",
+ " 7 | \n",
+ " 2022-09-27 | \n",
+ " 8636.741438 | \n",
+ "
\n",
+ " \n",
+ " | 22 | \n",
+ " 7 | \n",
+ " 2022-09-28 | \n",
+ " 8614.503698 | \n",
+ "
\n",
+ " \n",
+ " | 23 | \n",
+ " 7 | \n",
+ " 2022-09-29 | \n",
+ " 8548.285629 | \n",
+ "
\n",
+ " \n",
+ " | 24 | \n",
+ " 8 | \n",
+ " 2022-09-27 | \n",
+ " 9636.741438 | \n",
+ "
\n",
+ " \n",
+ " | 25 | \n",
+ " 8 | \n",
+ " 2022-09-28 | \n",
+ " 9614.503698 | \n",
+ "
\n",
+ " \n",
+ " | 26 | \n",
+ " 8 | \n",
+ " 2022-09-29 | \n",
+ " 9548.285629 | \n",
+ "
\n",
+ " \n",
+ " | 27 | \n",
+ " 9 | \n",
+ " 2022-09-27 | \n",
+ " 10636.741438 | \n",
+ "
\n",
+ " \n",
+ " | 28 | \n",
+ " 9 | \n",
+ " 2022-09-28 | \n",
+ " 10614.503698 | \n",
+ "
\n",
+ " \n",
+ " | 29 | \n",
+ " 9 | \n",
+ " 2022-09-29 | \n",
+ " 10548.285629 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " id date value\n",
+ "0 0 2022-09-27 1639.249496\n",
+ "1 0 2022-09-28 1620.282791\n",
+ "2 0 2022-09-29 1555.419756\n",
+ "3 1 2022-09-27 2639.249496\n",
+ "4 1 2022-09-28 2620.282791\n",
+ "5 1 2022-09-29 2555.419756\n",
+ "6 2 2022-09-27 3639.836497\n",
+ "7 2 2022-09-28 3620.990255\n",
+ "8 2 2022-09-29 3573.857553\n",
+ "9 3 2022-09-27 4635.954675\n",
+ "10 3 2022-09-28 4614.542752\n",
+ "11 3 2022-09-29 4546.682434\n",
+ "12 4 2022-09-27 5642.56791\n",
+ "13 4 2022-09-28 5622.667564\n",
+ "14 4 2022-09-29 5573.142299\n",
+ "15 5 2022-09-27 6642.56791\n",
+ "16 5 2022-09-28 6622.667564\n",
+ "17 5 2022-09-29 6573.142299\n",
+ "18 6 2022-09-27 7638.729381\n",
+ "19 6 2022-09-28 7619.679409\n",
+ "20 6 2022-09-29 7553.241637\n",
+ "21 7 2022-09-27 8636.741438\n",
+ "22 7 2022-09-28 8614.503698\n",
+ "23 7 2022-09-29 8548.285629\n",
+ "24 8 2022-09-27 9636.741438\n",
+ "25 8 2022-09-28 9614.503698\n",
+ "26 8 2022-09-29 9548.285629\n",
+ "27 9 2022-09-27 10636.741438\n",
+ "28 9 2022-09-28 10614.503698\n",
+ "29 9 2022-09-29 10548.285629"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "current_pred"
+ ]
}
],
"metadata": {
"kernelspec": {
- "display_name": "Python (tsururu-dev)",
+ "display_name": ".venv",
"language": "python",
- "name": "tsururu-dev"
+ "name": "python3"
},
"language_info": {
"codemirror_mode": {
diff --git a/examples/Tutorial_2_Strategies.ipynb b/examples/Tutorial_2_Strategies.ipynb
index 2b11bbf..ea3c790 100644
--- a/examples/Tutorial_2_Strategies.ipynb
+++ b/examples/Tutorial_2_Strategies.ipynb
@@ -72,7 +72,7 @@
"import pandas as pd\n",
"\n",
"from tsururu.dataset import IndexSlicer, Pipeline, TSDataset\n",
- "from tsururu.model_training import DLTrainer, KFoldCrossValidator\n",
+ "#from tsururu.model_training import DLTrainer, KFoldCrossValidator\n",
"from tsururu.model_training.validator import Validator\n",
"from tsururu.models import CatBoost, Estimator\n",
"from tsururu.strategies.base import Strategy\n",
@@ -94,9 +94,22 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 3,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "ename": "NameError",
+ "evalue": "name 'DLTrainer' is not defined",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
+ "\u001b[31mNameError\u001b[39m Traceback (most recent call last)",
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 67\u001b[39m\n\u001b[32m 63\u001b[39m y_pred = y_pred.reshape(pipeline.y_original_shape)\n\u001b[32m 65\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m y_pred\n\u001b[32m---> \u001b[39m\u001b[32m67\u001b[39m \u001b[38;5;28;43;01mclass\u001b[39;49;00m\u001b[38;5;250;43m \u001b[39;49m\u001b[34;43;01mRecursiveStrategy\u001b[39;49;00m\u001b[43m(\u001b[49m\u001b[43mStrategy\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 68\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mdef\u001b[39;49;00m\u001b[38;5;250;43m \u001b[39;49m\u001b[34;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\n\u001b[32m 69\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m 70\u001b[39m \u001b[43m \u001b[49m\u001b[43mhorizon\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mint\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m (...)\u001b[39m\u001b[32m 76\u001b[39m \u001b[43m \u001b[49m\u001b[43mreduced\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mbool\u001b[39;49m\u001b[43m \u001b[49m\u001b[43m=\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[32m 77\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 78\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[34;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mhorizon\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhistory\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpipeline\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstep\u001b[49m\u001b[43m)\u001b[49m\n",
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 72\u001b[39m, in \u001b[36mRecursiveStrategy\u001b[39m\u001b[34m()\u001b[39m\n\u001b[32m 67\u001b[39m \u001b[38;5;28;01mclass\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mRecursiveStrategy\u001b[39;00m(Strategy):\n\u001b[32m 68\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__init__\u001b[39m(\n\u001b[32m 69\u001b[39m \u001b[38;5;28mself\u001b[39m,\n\u001b[32m 70\u001b[39m horizon: \u001b[38;5;28mint\u001b[39m,\n\u001b[32m 71\u001b[39m history: \u001b[38;5;28mint\u001b[39m,\n\u001b[32m---> \u001b[39m\u001b[32m72\u001b[39m trainer: Union[MLTrainer, \u001b[43mDLTrainer\u001b[49m],\n\u001b[32m 73\u001b[39m pipeline: Pipeline,\n\u001b[32m 74\u001b[39m step: \u001b[38;5;28mint\u001b[39m = \u001b[32m1\u001b[39m,\n\u001b[32m 75\u001b[39m model_horizon: \u001b[38;5;28mint\u001b[39m = \u001b[32m1\u001b[39m,\n\u001b[32m 76\u001b[39m reduced: \u001b[38;5;28mbool\u001b[39m = \u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[32m 77\u001b[39m ):\n\u001b[32m 78\u001b[39m \u001b[38;5;28msuper\u001b[39m().\u001b[34m__init__\u001b[39m(horizon, history, trainer, pipeline, step)\n\u001b[32m 79\u001b[39m \u001b[38;5;28mself\u001b[39m.model_horizon = model_horizon\n",
+ "\u001b[31mNameError\u001b[39m: name 'DLTrainer' is not defined"
+ ]
+ }
+ ],
"source": [
"class MLTrainer:\n",
" def __init__(\n",
@@ -628,7 +641,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -856,7 +869,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -866,7 +879,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -902,7 +915,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1081,7 +1094,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1097,7 +1110,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -1630,7 +1643,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -2052,7 +2065,7 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -2178,7 +2191,7 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -2194,7 +2207,7 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -2751,7 +2764,7 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -3010,7 +3023,7 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -3150,7 +3163,7 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -10200,7 +10213,7 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -10342,7 +10355,7 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -10358,7 +10371,7 @@
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -10891,7 +10904,7 @@
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -12546,7 +12559,7 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -12686,7 +12699,7 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -12703,7 +12716,7 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -13236,7 +13249,7 @@
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -13658,7 +13671,7 @@
},
{
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -13791,7 +13804,7 @@
},
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -13808,7 +13821,7 @@
},
{
"cell_type": "code",
- "execution_count": 28,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -14341,7 +14354,7 @@
},
{
"cell_type": "code",
- "execution_count": 29,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -15174,7 +15187,7 @@
},
{
"cell_type": "code",
- "execution_count": 30,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -15321,7 +15334,7 @@
},
{
"cell_type": "code",
- "execution_count": 31,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -15337,7 +15350,7 @@
},
{
"cell_type": "code",
- "execution_count": 32,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -16662,7 +16675,7 @@
},
{
"cell_type": "code",
- "execution_count": 33,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -18317,7 +18330,7 @@
},
{
"cell_type": "code",
- "execution_count": 34,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -18450,7 +18463,7 @@
},
{
"cell_type": "code",
- "execution_count": 35,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -18467,7 +18480,7 @@
},
{
"cell_type": "code",
- "execution_count": 36,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -19785,7 +19798,7 @@
},
{
"cell_type": "code",
- "execution_count": 37,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -21446,7 +21459,7 @@
},
{
"cell_type": "code",
- "execution_count": 38,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -21579,7 +21592,7 @@
},
{
"cell_type": "code",
- "execution_count": 39,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -21596,7 +21609,7 @@
},
{
"cell_type": "code",
- "execution_count": 40,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -22268,7 +22281,7 @@
},
{
"cell_type": "code",
- "execution_count": 41,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -23101,7 +23114,7 @@
},
{
"cell_type": "code",
- "execution_count": 42,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -23242,7 +23255,7 @@
},
{
"cell_type": "code",
- "execution_count": 43,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -23258,7 +23271,7 @@
},
{
"cell_type": "code",
- "execution_count": 44,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -23683,7 +23696,7 @@
},
{
"cell_type": "code",
- "execution_count": 45,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -23997,7 +24010,7 @@
},
{
"cell_type": "code",
- "execution_count": 46,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
diff --git a/pyproject.toml b/pyproject.toml
index a779a28..ce761b9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "tsururu"
-version = "1.1.0"
+version = "1.1.1"
description = "Python tool for time series forecasting"
readme = "README.md"
requires-python = ">=3.9,<3.14"
diff --git a/tsururu/models/torch_based/layers/convolution.py b/tsururu/models/torch_based/layers/convolution.py
index df9c0ea..b8bd449 100644
--- a/tsururu/models/torch_based/layers/convolution.py
+++ b/tsururu/models/torch_based/layers/convolution.py
@@ -4,9 +4,10 @@
torch = OptionalImport("torch")
nn = OptionalImport("torch.nn")
+Module = OptionalImport("torch.nn.Module")
-class Inception_Block_V1(nn.Module):
+class Inception_Block_V1(Module):
"""Inception Block Version 1.
Args:
@@ -43,7 +44,7 @@ def _initialize_weights(self):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
- def forward(self, x: torch.Tensor) -> torch.Tensor:
+ def forward(self, x: "torch.Tensor") -> "torch.Tensor":
"""Forward pass of the Inception block.
Args:
diff --git a/tsururu/models/torch_based/layers/embedding.py b/tsururu/models/torch_based/layers/embedding.py
index eb71db3..f8ac3ea 100644
--- a/tsururu/models/torch_based/layers/embedding.py
+++ b/tsururu/models/torch_based/layers/embedding.py
@@ -7,9 +7,10 @@
torch = OptionalImport("torch")
nn = OptionalImport("torch.nn")
+Module = OptionalImport("torch.nn.Module")
-class TokenEmbedding(nn.Module):
+class TokenEmbedding(Module):
"""Token embedding layer using 1D convolution.
Args:
@@ -33,7 +34,7 @@ def __init__(self, c_in: int, d_model: int):
if isinstance(m, nn.Conv1d):
nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="leaky_relu")
- def forward(self, x: torch.Tensor) -> torch.Tensor:
+ def forward(self, x: "torch.Tensor") -> "torch.Tensor":
"""Forward pass of the token embedding.
Args:
@@ -47,7 +48,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x
-class PositionalEmbedding(nn.Module):
+class PositionalEmbedding(Module):
"""Positional encoding using sine and cosine functions.
Args:
@@ -71,7 +72,7 @@ def __init__(self, d_model: int, max_len: int = 5000):
pe = pe.unsqueeze(0)
self.register_buffer("pe", pe)
- def forward(self, x: torch.Tensor) -> torch.Tensor:
+ def forward(self, x: "torch.Tensor") -> "torch.Tensor":
"""Forward pass of the positional embedding.
Args:
@@ -84,7 +85,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.pe[:, : x.size(1)]
-class FixedEmbedding(nn.Module):
+class FixedEmbedding(Module):
"""Fixed embedding layer using precomputed sine and cosine values.
Args:
@@ -108,7 +109,7 @@ def __init__(self, c_in: int, d_model: int):
self.emb = nn.Embedding(c_in, d_model)
self.emb.weight = nn.Parameter(w, requires_grad=False)
- def forward(self, x: torch.Tensor) -> torch.Tensor:
+ def forward(self, x: "torch.Tensor") -> "torch.Tensor":
"""Forward pass of the fixed embedding.
Args:
@@ -121,7 +122,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.emb(x).detach()
-class TemporalEmbedding(nn.Module):
+class TemporalEmbedding(Module):
"""Temporal embedding layer for time-related features.
Args:
@@ -148,7 +149,7 @@ def __init__(self, d_model: int, embed_type: str = "fixed", freq: str = "h"):
self.day_embed = Embed(day_size, d_model)
self.month_embed = Embed(month_size, d_model)
- def forward(self, x: torch.Tensor) -> torch.Tensor:
+ def forward(self, x: "torch.Tensor") -> "torch.Tensor":
"""Forward pass of the temporal embedding.
Args:
@@ -168,7 +169,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return hour_x + weekday_x + day_x + month_x + minute_x
-class TimeFeatureEmbedding(nn.Module):
+class TimeFeatureEmbedding(Module):
"""Time feature embedding layer using linear transformation.
Args:
@@ -181,7 +182,7 @@ def __init__(self, d_model: int, d_inp: int = 1):
super(TimeFeatureEmbedding, self).__init__()
self.embed = nn.Linear(d_inp, d_model, bias=False)
- def forward(self, x: torch.Tensor) -> torch.Tensor:
+ def forward(self, x: "torch.Tensor") -> "torch.Tensor":
"""Forward pass of the time feature embedding.
Args:
@@ -194,7 +195,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.embed(x)
-class Embedding(nn.Module):
+class Embedding(Module):
"""Data embedding layer combining token, positional, and temporal embeddings.
Args:
@@ -242,7 +243,7 @@ def __init__(
self.dropout = nn.Dropout(p=dropout)
- def forward(self, x: torch.Tensor, x_mark: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def forward(self, x: "torch.Tensor", x_mark: Optional["torch.Tensor"] = None) -> "torch.Tensor":
"""Forward pass of the data embedding.
Args:
diff --git a/tsururu/models/torch_based/layers/patch_tst.py b/tsururu/models/torch_based/layers/patch_tst.py
index f2c601e..af7b705 100644
--- a/tsururu/models/torch_based/layers/patch_tst.py
+++ b/tsururu/models/torch_based/layers/patch_tst.py
@@ -12,7 +12,6 @@
torch = OptionalImport("torch")
nn = OptionalImport("torch.nn")
F = OptionalImport("torch.nn.functional")
-Tensor = OptionalImport("torch.Tensor")
Module = OptionalImport("torch.nn.Module")
@@ -82,7 +81,7 @@ def __init__(
act: str = "gelu",
key_padding_mask: Union[bool, str] = "auto",
padding_var: Optional[int] = None,
- attn_mask: Optional[Tensor] = None,
+ attn_mask: Optional["torch.Tensor"] = None,
res_attention: bool = True,
pre_norm: bool = False,
store_attn: bool = False,
@@ -166,7 +165,7 @@ def __init__(
head_dropout=head_dropout,
)
- def forward(self, z: Tensor) -> Tensor:
+ def forward(self, z: "torch.Tensor") -> "torch.Tensor":
"""Forward pass of the PatchTST backbone.
Args:
@@ -330,7 +329,7 @@ def __init__(
store_attn: bool = False,
key_padding_mask: Union[bool, str] = "auto",
padding_var: Optional[int] = None,
- attn_mask: Optional[Tensor] = None,
+ attn_mask: Optional["torch.Tensor"] = None,
res_attention: bool = True,
pre_norm: bool = False,
pe: str = "zeros",
@@ -383,7 +382,7 @@ def __init__(
store_attn=store_attn,
)
- def forward(self, x: Tensor) -> Tensor:
+ def forward(self, x: "torch.Tensor") -> "torch.Tensor":
"""Forward pass of the TSTi encoder.
Args:
@@ -423,7 +422,7 @@ def forward(self, x: Tensor) -> Tensor:
return x
@staticmethod
- def _reorganize_tensor(x: Tensor, n_vars: int) -> Tensor:
+ def _reorganize_tensor(x: "torch.Tensor", n_vars: int) -> "torch.Tensor":
"""Reorganizes input tensor to pair time series channels with exogenous features.
Args:
@@ -520,10 +519,10 @@ def __init__(
def forward(
self,
- src: Tensor,
- key_padding_mask: Optional[Tensor] = None,
- attn_mask: Optional[Tensor] = None,
- ) -> Tensor:
+ src: "torch.Tensor",
+ key_padding_mask: Optional["torch.Tensor"] = None,
+ attn_mask: Optional["torch.Tensor"] = None,
+ ) -> "torch.Tensor":
"""Forward pass of the TST encoder.
Args:
@@ -639,11 +638,11 @@ def __init__(
def forward(
self,
- src: Tensor,
- prev: Optional[Tensor] = None,
- key_padding_mask: Optional[Tensor] = None,
- attn_mask: Optional[Tensor] = None,
- ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
+ src: "torch.Tensor",
+ prev: Optional["torch.Tensor"] = None,
+ key_padding_mask: Optional["torch.Tensor"] = None,
+ attn_mask: Optional["torch.Tensor"] = None,
+ ) -> Union["torch.Tensor", Tuple["torch.Tensor", "torch.Tensor"]]:
"""Forward pass of the TST encoder layer.
Args:
@@ -747,13 +746,13 @@ def __init__(
def forward(
self,
- Q: Tensor,
- K: Optional[Tensor] = None,
- V: Optional[Tensor] = None,
- prev: Optional[Tensor] = None,
- key_padding_mask: Optional[Tensor] = None,
- attn_mask: Optional[Tensor] = None,
- ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
+ Q: "torch.Tensor",
+ K: Optional["torch.Tensor"] = None,
+ V: Optional["torch.Tensor"] = None,
+ prev: Optional["torch.Tensor"] = None,
+ key_padding_mask: Optional["torch.Tensor"] = None,
+ attn_mask: Optional["torch.Tensor"] = None,
+ ) -> Union["torch.Tensor", Tuple["torch.Tensor", "torch.Tensor"]]:
"""Forward pass of the multi-head attention layer.
Args:
@@ -845,13 +844,13 @@ def __init__(
def forward(
self,
- q: Tensor,
- k: Tensor,
- v: Tensor,
- prev: Optional[Tensor] = None,
- key_padding_mask: Optional[Tensor] = None,
- attn_mask: Optional[Tensor] = None,
- ) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]:
+ q: "torch.Tensor",
+ k: "torch.Tensor",
+ v: "torch.Tensor",
+ prev: Optional["torch.Tensor"] = None,
+ key_padding_mask: Optional["torch.Tensor"] = None,
+ attn_mask: Optional["torch.Tensor"] = None,
+ ) -> Union["torch.Tensor", Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]]:
"""Forward pass of the scaled dot-product attention.
Args:
diff --git a/tsururu/models/torch_based/layers/positional_encoding.py b/tsururu/models/torch_based/layers/positional_encoding.py
index 3155bac..8b3831b 100644
--- a/tsururu/models/torch_based/layers/positional_encoding.py
+++ b/tsururu/models/torch_based/layers/positional_encoding.py
@@ -78,9 +78,7 @@ def Coord2dPosEncoding(
return cpe
-def Coord1dPosEncoding(
- q_len: int, exponential: bool = False, normalize: bool = True
-) -> "torch.Tensor":
+def Coord1dPosEncoding(q_len: int, exponential: bool = False, normalize: bool = True) -> "torch.Tensor":
"""Generate 1D coordinate positional encoding.
Args:
diff --git a/tsururu/models/torch_based/times_net.py b/tsururu/models/torch_based/times_net.py
index ecc8f94..94d134b 100644
--- a/tsururu/models/torch_based/times_net.py
+++ b/tsururu/models/torch_based/times_net.py
@@ -14,9 +14,10 @@
nn = OptionalImport("torch.nn")
F = OptionalImport("torch.nn.functional")
rearrange = OptionalImport("einops.rearrange")
+Module = OptionalImport("torch.nn.Module")
-def FFT_for_Period(x: torch.Tensor, k: int = 2) -> Tuple[np.ndarray, torch.Tensor]:
+def FFT_for_Period(x: "torch.Tensor", k: int = 2) -> Tuple[np.ndarray, "torch.Tensor"]:
"""Compute the FFT for the input tensor and find the top-k periods.
Args:
@@ -41,7 +42,7 @@ def FFT_for_Period(x: torch.Tensor, k: int = 2) -> Tuple[np.ndarray, torch.Tenso
return period, abs(xf).mean(-1)[:, top_list]
-class TimesBlock(nn.Module):
+class TimesBlock(Module):
"""TimesBlock module for time series forecasting.
Args:
@@ -69,7 +70,7 @@ def __init__(
Inception_Block_V1(d_ff, d_model, num_kernels=num_kernels),
)
- def forward(self, x: torch.Tensor) -> torch.Tensor:
+ def forward(self, x: "torch.Tensor") -> "torch.Tensor":
"""Forward pass for the TimesBlock module.
Args:
diff --git a/tsururu/models/torch_based/utils.py b/tsururu/models/torch_based/utils.py
index 735a899..a93294b 100644
--- a/tsururu/models/torch_based/utils.py
+++ b/tsururu/models/torch_based/utils.py
@@ -33,10 +33,10 @@ def adjust_features_groups(features_groups: Dict[str, int], num_lags: int) -> Di
def slice_features(
- X: torch.Tensor,
+ X: "torch.Tensor",
feature_list: List[str],
features_groups_corrected,
-) -> torch.Tensor:
+) -> "torch.Tensor":
"""Slice the input tensor X based on the corrected feature groups.
Args:
@@ -88,11 +88,11 @@ def slice_features(
def slice_features_4d(
- X: torch.Tensor,
+ X: "torch.Tensor",
features_list: List[str],
features_groups_corrected,
num_series,
-) -> torch.Tensor:
+) -> "torch.Tensor":
"""Slice the input tensor X based on the corrected feature groups and reshape it to 4D."""
groups_order: List[str] = [
"series",
diff --git a/tsururu/strategies/direct.py b/tsururu/strategies/direct.py
index 56644d3..29c6ba6 100644
--- a/tsururu/strategies/direct.py
+++ b/tsururu/strategies/direct.py
@@ -1,6 +1,8 @@
from copy import deepcopy
from typing import Union
+import numpy as np
+
from tsururu.dataset.dataset import TSDataset
from tsururu.dataset.pipeline import Pipeline
from tsururu.dataset.slice import IndexSlicer
@@ -56,6 +58,8 @@ def __init__(
def fit(
self,
dataset: TSDataset,
+ subsampling_rate: float = 1.0,
+ subsampling_seed: int = 42,
) -> "DirectStrategy":
"""Fits the direct strategy to the given dataset.
@@ -92,6 +96,15 @@ def fit(
delta=dataset.delta,
)
+ if subsampling_rate < 1.0:
+ all_idx = np.arange(features_idx.shape[0])
+ np.random.seed(subsampling_seed)
+ sampled_idx = np.random.choice(
+ all_idx, size=int(subsampling_rate * len(all_idx)), replace=False
+ )
+ features_idx = features_idx[sampled_idx]
+ target_idx = target_idx[sampled_idx]
+
data = self.pipeline.create_data_dict_for_pipeline(dataset, features_idx, target_idx)
data = self.pipeline.fit_transform(data, self.strategy_name)
diff --git a/tsururu/strategies/recursive.py b/tsururu/strategies/recursive.py
index 2cb9ee6..0184221 100644
--- a/tsururu/strategies/recursive.py
+++ b/tsururu/strategies/recursive.py
@@ -1,6 +1,7 @@
from copy import deepcopy
from typing import Union
+import numpy as np
import pandas as pd
from tsururu.dataset.dataset import TSDataset
@@ -60,11 +61,15 @@ def __init__(
def fit(
self,
dataset: TSDataset,
+ subsampling_rate: float = 1.0,
+ subsampling_seed: int = 42,
) -> "RecursiveStrategy":
"""Fits the recursive strategy to the given dataset.
Args:
dataset: The dataset to fit the strategy on.
+ subsampling_rate: The rate at which to subsample the data for training.
+ A value of 1.0 means no subsampling.
Returns:
self.
@@ -88,6 +93,15 @@ def fit(
delta=dataset.delta,
)
+ if subsampling_rate < 1.0:
+ all_idx = np.arange(features_idx.shape[0])
+ np.random.seed(subsampling_seed)
+ sampled_idx = np.random.choice(
+ all_idx, size=int(subsampling_rate * len(all_idx)), replace=False
+ )
+ features_idx = features_idx[sampled_idx]
+ target_idx = target_idx[sampled_idx]
+
data = self.pipeline.create_data_dict_for_pipeline(dataset, features_idx, target_idx)
data = self.pipeline.fit_transform(data, self.strategy_name)