diff --git a/examples/Tutorial_1_Quick_start.ipynb b/examples/Tutorial_1_Quick_start.ipynb index 751fc53..4fd0c82 100644 --- a/examples/Tutorial_1_Quick_start.ipynb +++ b/examples/Tutorial_1_Quick_start.ipynb @@ -17,15 +17,16 @@ " - [Strategy](#toc2_1_2_6_) \n", " - [Backtest validation of pipeline](#toc2_2_) \n", " - [Sliding Window Validation](#toc2_3_) \n", - " - [Working with raw time series' granularity](#toc2_4_) \n", + " - [Working with raw time series' granularity](#toc2_4_)\n", + " - [Subsampling from Dataset](#toc2_5_) \n", "\n", "\n", + " numbering=false\n", + " anchor=true\n", + " flat=false\n", + " minLevel=1\n", + " maxLevel=6\n", + " /vscode-jupyter-toc-config -->\n", "" ] }, @@ -52,7 +53,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -74,7 +75,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -153,7 +154,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -177,7 +178,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -273,7 +274,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -290,7 +291,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -306,7 +307,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -352,7 +353,7 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -377,7 +378,7 @@ }, { "cell_type": "code", - "execution_count": 63, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -424,7 +425,7 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -483,7 +484,7 @@ }, { "cell_type": "code", - "execution_count": 65, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -493,7 +494,7 @@ }, { "cell_type": "code", - "execution_count": 66, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -502,24 +503,24 @@ }, { "cell_type": "code", - "execution_count": 67, + "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "0:\tlearn: 0.9606080\ttest: 0.9667407\tbest: 0.9667407 (0)\ttotal: 2.64ms\tremaining: 2.64s\n", - "500:\tlearn: 0.0051947\ttest: 0.0053699\tbest: 0.0053699 (500)\ttotal: 458ms\tremaining: 456ms\n", - "999:\tlearn: 0.0031608\ttest: 0.0033676\tbest: 0.0033676 (999)\ttotal: 973ms\tremaining: 0us\n", + "0:\tlearn: 0.9606080\ttest: 0.9667407\tbest: 0.9667407 (0)\ttotal: 58.2ms\tremaining: 58.1s\n", + "500:\tlearn: 0.0051947\ttest: 0.0053699\tbest: 0.0053699 (500)\ttotal: 544ms\tremaining: 542ms\n", + "999:\tlearn: 0.0031608\ttest: 0.0033676\tbest: 0.0033676 (999)\ttotal: 988ms\tremaining: 0us\n", "\n", "bestTest = 0.003367620955\n", "bestIteration = 999\n", "\n", "Fold 0. Score: 0.0033676209549416128\n", - "0:\tlearn: 0.9659554\ttest: 0.9614093\tbest: 0.9614093 (0)\ttotal: 1.56ms\tremaining: 1.56s\n", - "500:\tlearn: 0.0052698\ttest: 0.0054766\tbest: 0.0054766 (500)\ttotal: 445ms\tremaining: 443ms\n", - "999:\tlearn: 0.0031515\ttest: 0.0033391\tbest: 0.0033391 (999)\ttotal: 912ms\tremaining: 0us\n", + "0:\tlearn: 0.9659554\ttest: 0.9614093\tbest: 0.9614093 (0)\ttotal: 1.03ms\tremaining: 1.03s\n", + "500:\tlearn: 0.0052698\ttest: 0.0054766\tbest: 0.0054766 (500)\ttotal: 554ms\tremaining: 552ms\n", + "999:\tlearn: 0.0031515\ttest: 0.0033391\tbest: 0.0033391 (999)\ttotal: 1.03s\tremaining: 0us\n", "\n", "bestTest = 0.003339095317\n", "bestIteration = 999\n", @@ -536,7 +537,7 @@ }, { "cell_type": "code", - "execution_count": 68, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -553,7 +554,7 @@ }, { "cell_type": "code", - "execution_count": 69, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -801,7 +802,7 @@ "29 9 2022-09-29 10981.801303" ] }, - "execution_count": 69, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -826,7 +827,7 @@ }, { "cell_type": "code", - "execution_count": 70, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -834,17 +835,17 @@ "output_type": "stream", "text": [ "freq: Day; period: 1\n", - "0:\tlearn: 0.9618043\ttest: 0.9656878\tbest: 0.9656878 (0)\ttotal: 2.18ms\tremaining: 2.18s\n", - "500:\tlearn: 0.0051787\ttest: 0.0052308\tbest: 0.0052308 (500)\ttotal: 455ms\tremaining: 453ms\n", - "999:\tlearn: 0.0030887\ttest: 0.0032962\tbest: 0.0032962 (999)\ttotal: 956ms\tremaining: 0us\n", + "0:\tlearn: 0.9618043\ttest: 0.9656878\tbest: 0.9656878 (0)\ttotal: 870us\tremaining: 870ms\n", + "500:\tlearn: 0.0051787\ttest: 0.0052308\tbest: 0.0052308 (500)\ttotal: 532ms\tremaining: 530ms\n", + "999:\tlearn: 0.0030887\ttest: 0.0032962\tbest: 0.0032962 (999)\ttotal: 985ms\tremaining: 0us\n", "\n", "bestTest = 0.00329621851\n", "bestIteration = 999\n", "\n", "Fold 0. Score: 0.0032962185095815415\n", - "0:\tlearn: 0.9647141\ttest: 0.9623340\tbest: 0.9623340 (0)\ttotal: 1.5ms\tremaining: 1.5s\n", - "500:\tlearn: 0.0055104\ttest: 0.0057187\tbest: 0.0057187 (500)\ttotal: 440ms\tremaining: 438ms\n", - "999:\tlearn: 0.0033140\ttest: 0.0035696\tbest: 0.0035696 (999)\ttotal: 866ms\tremaining: 0us\n", + "0:\tlearn: 0.9647141\ttest: 0.9623340\tbest: 0.9623340 (0)\ttotal: 1.11ms\tremaining: 1.11s\n", + "500:\tlearn: 0.0055104\ttest: 0.0057187\tbest: 0.0057187 (500)\ttotal: 427ms\tremaining: 425ms\n", + "999:\tlearn: 0.0033140\ttest: 0.0035696\tbest: 0.0035696 (999)\ttotal: 861ms\tremaining: 0us\n", "\n", "bestTest = 0.003569589613\n", "bestIteration = 999\n", @@ -854,17 +855,17 @@ "Std: 0.0001\n", "freq: Day; period: 1\n", "freq: Day; period: 1\n", - "0:\tlearn: 0.9635630\ttest: 0.9636501\tbest: 0.9636501 (0)\ttotal: 959us\tremaining: 958ms\n", - "500:\tlearn: 0.0051732\ttest: 0.0053000\tbest: 0.0053000 (500)\ttotal: 446ms\tremaining: 444ms\n", - "999:\tlearn: 0.0031204\ttest: 0.0032928\tbest: 0.0032928 (999)\ttotal: 889ms\tremaining: 0us\n", + "0:\tlearn: 0.9635630\ttest: 0.9636501\tbest: 0.9636501 (0)\ttotal: 1.02ms\tremaining: 1.01s\n", + "500:\tlearn: 0.0051732\ttest: 0.0053000\tbest: 0.0053000 (500)\ttotal: 449ms\tremaining: 447ms\n", + "999:\tlearn: 0.0031204\ttest: 0.0032928\tbest: 0.0032928 (999)\ttotal: 878ms\tremaining: 0us\n", "\n", "bestTest = 0.003292811776\n", "bestIteration = 999\n", "\n", "Fold 0. Score: 0.0032928117763904234\n", - "0:\tlearn: 0.9632624\ttest: 0.9635496\tbest: 0.9635496 (0)\ttotal: 2.48ms\tremaining: 2.48s\n", - "500:\tlearn: 0.0053458\ttest: 0.0056056\tbest: 0.0056056 (500)\ttotal: 545ms\tremaining: 543ms\n", - "999:\tlearn: 0.0032541\ttest: 0.0035628\tbest: 0.0035628 (999)\ttotal: 979ms\tremaining: 0us\n", + "0:\tlearn: 0.9632624\ttest: 0.9635496\tbest: 0.9635496 (0)\ttotal: 1.02ms\tremaining: 1.02s\n", + "500:\tlearn: 0.0053458\ttest: 0.0056056\tbest: 0.0056056 (500)\ttotal: 432ms\tremaining: 431ms\n", + "999:\tlearn: 0.0032541\ttest: 0.0035628\tbest: 0.0035628 (999)\ttotal: 866ms\tremaining: 0us\n", "\n", "bestTest = 0.003562794598\n", "bestIteration = 999\n", @@ -874,17 +875,17 @@ "Std: 0.0001\n", "freq: Day; period: 1\n", "freq: Day; period: 1\n", - "0:\tlearn: 0.9672679\ttest: 0.9599529\tbest: 0.9599529 (0)\ttotal: 1.4ms\tremaining: 1.4s\n", - "500:\tlearn: 0.0052718\ttest: 0.0054426\tbest: 0.0054426 (500)\ttotal: 447ms\tremaining: 445ms\n", - "999:\tlearn: 0.0030843\ttest: 0.0033111\tbest: 0.0033111 (999)\ttotal: 860ms\tremaining: 0us\n", + "0:\tlearn: 0.9672679\ttest: 0.9599529\tbest: 0.9599529 (0)\ttotal: 903us\tremaining: 902ms\n", + "500:\tlearn: 0.0052718\ttest: 0.0054426\tbest: 0.0054426 (500)\ttotal: 421ms\tremaining: 420ms\n", + "999:\tlearn: 0.0030843\ttest: 0.0033111\tbest: 0.0033111 (999)\ttotal: 865ms\tremaining: 0us\n", "\n", "bestTest = 0.003311056772\n", "bestIteration = 999\n", "\n", "Fold 0. Score: 0.003311056771697718\n", - "0:\tlearn: 0.9591189\ttest: 0.9679140\tbest: 0.9679140 (0)\ttotal: 1.44ms\tremaining: 1.44s\n", - "500:\tlearn: 0.0053303\ttest: 0.0056656\tbest: 0.0056656 (500)\ttotal: 429ms\tremaining: 427ms\n", - "999:\tlearn: 0.0031389\ttest: 0.0034467\tbest: 0.0034467 (999)\ttotal: 982ms\tremaining: 0us\n", + "0:\tlearn: 0.9591189\ttest: 0.9679140\tbest: 0.9679140 (0)\ttotal: 1.16ms\tremaining: 1.16s\n", + "500:\tlearn: 0.0053303\ttest: 0.0056656\tbest: 0.0056656 (500)\ttotal: 431ms\tremaining: 429ms\n", + "999:\tlearn: 0.0031389\ttest: 0.0034467\tbest: 0.0034467 (999)\ttotal: 854ms\tremaining: 0us\n", "\n", "bestTest = 0.003446668742\n", "bestIteration = 999\n", @@ -902,7 +903,7 @@ }, { "cell_type": "code", - "execution_count": 71, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -1368,7 +1369,7 @@ "29 10984.582094 9 " ] }, - "execution_count": 71, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -1395,7 +1396,7 @@ }, { "cell_type": "code", - "execution_count": 72, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -1433,24 +1434,24 @@ }, { "cell_type": "code", - "execution_count": 73, + "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "0:\tlearn: 0.9606080\ttest: 0.9667407\tbest: 0.9667407 (0)\ttotal: 1.78ms\tremaining: 1.78s\n", - "500:\tlearn: 0.0051947\ttest: 0.0053699\tbest: 0.0053699 (500)\ttotal: 511ms\tremaining: 509ms\n", - "999:\tlearn: 0.0031608\ttest: 0.0033676\tbest: 0.0033676 (999)\ttotal: 1.03s\tremaining: 0us\n", + "0:\tlearn: 0.9606080\ttest: 0.9667407\tbest: 0.9667407 (0)\ttotal: 981us\tremaining: 980ms\n", + "500:\tlearn: 0.0051947\ttest: 0.0053699\tbest: 0.0053699 (500)\ttotal: 421ms\tremaining: 419ms\n", + "999:\tlearn: 0.0031608\ttest: 0.0033676\tbest: 0.0033676 (999)\ttotal: 860ms\tremaining: 0us\n", "\n", "bestTest = 0.003367620955\n", "bestIteration = 999\n", "\n", "Fold 0. Score: 0.0033676209549416128\n", - "0:\tlearn: 0.9659554\ttest: 0.9614093\tbest: 0.9614093 (0)\ttotal: 1.47ms\tremaining: 1.47s\n", - "500:\tlearn: 0.0052698\ttest: 0.0054766\tbest: 0.0054766 (500)\ttotal: 431ms\tremaining: 429ms\n", - "999:\tlearn: 0.0031515\ttest: 0.0033391\tbest: 0.0033391 (999)\ttotal: 857ms\tremaining: 0us\n", + "0:\tlearn: 0.9659554\ttest: 0.9614093\tbest: 0.9614093 (0)\ttotal: 2.27ms\tremaining: 2.27s\n", + "500:\tlearn: 0.0052698\ttest: 0.0054766\tbest: 0.0054766 (500)\ttotal: 453ms\tremaining: 451ms\n", + "999:\tlearn: 0.0031515\ttest: 0.0033391\tbest: 0.0033391 (999)\ttotal: 912ms\tremaining: 0us\n", "\n", "bestTest = 0.003339095317\n", "bestIteration = 999\n", @@ -1467,7 +1468,7 @@ }, { "cell_type": "code", - "execution_count": 74, + "execution_count": 23, "metadata": {}, "outputs": [ { @@ -1496,7 +1497,7 @@ }, { "cell_type": "code", - "execution_count": 75, + "execution_count": 24, "metadata": {}, "outputs": [ { @@ -1546,205 +1547,75 @@ " \n", " \n", " 3\n", - " 1\n", - " 2020-01-08\n", - " 2008.084646\n", - " \n", - " \n", - " 4\n", - " 1\n", - " 2020-01-09\n", - " 2008.425281\n", - " \n", - " \n", - " 5\n", - " 1\n", - " 2020-01-10\n", - " 2009.385426\n", - " \n", - " \n", - " 6\n", - " 2\n", - " 2020-01-08\n", - " 3007.931338\n", - " \n", - " \n", - " 7\n", - " 2\n", - " 2020-01-09\n", - " 3008.294974\n", - " \n", - " \n", - " 8\n", - " 2\n", - " 2020-01-10\n", - " 3009.266806\n", - " \n", - " \n", - " 9\n", - " 3\n", - " 2020-01-08\n", - " 4007.738118\n", - " \n", - " \n", - " 10\n", - " 3\n", - " 2020-01-09\n", - " 4008.102793\n", - " \n", - " \n", - " 11\n", - " 3\n", - " 2020-01-10\n", - " 4009.074625\n", - " \n", - " \n", - " 12\n", - " 4\n", - " 2020-01-08\n", - " 5007.752016\n", - " \n", - " \n", - " 13\n", - " 4\n", - " 2020-01-09\n", - " 5008.118656\n", - " \n", - " \n", - " 14\n", - " 4\n", - " 2020-01-10\n", - " 5009.092952\n", - " \n", - " \n", - " 15\n", - " 5\n", - " 2020-01-08\n", - " 6007.75671\n", - " \n", - " \n", - " 16\n", - " 5\n", - " 2020-01-09\n", - " 6008.12335\n", - " \n", - " \n", - " 17\n", - " 5\n", - " 2020-01-10\n", - " 6009.096733\n", - " \n", - " \n", - " 18\n", - " 6\n", - " 2020-01-08\n", - " 7007.789288\n", - " \n", - " \n", - " 19\n", - " 6\n", - " 2020-01-09\n", - " 7008.155928\n", - " \n", - " \n", - " 20\n", - " 6\n", - " 2020-01-10\n", - " 7009.129311\n", - " \n", - " \n", - " 21\n", - " 7\n", - " 2020-01-08\n", - " 8007.861426\n", - " \n", - " \n", - " 22\n", - " 7\n", + " 0\n", " 2020-01-09\n", - " 8008.228066\n", + " 1008.769569\n", " \n", " \n", - " 23\n", - " 7\n", + " 4\n", + " 0\n", " 2020-01-10\n", - " 8009.201449\n", + " 1009.71512\n", " \n", " \n", - " 24\n", - " 8\n", - " 2020-01-08\n", - " 9007.995032\n", + " ...\n", + " ...\n", + " ...\n", + " ...\n", " \n", " \n", - " 25\n", - " 8\n", - " 2020-01-09\n", - " 9008.359142\n", + " 29725\n", + " 9\n", + " 2022-09-24\n", + " 10996.789104\n", " \n", " \n", - " 26\n", - " 8\n", - " 2020-01-10\n", - " 9009.332526\n", + " 29726\n", + " 9\n", + " 2022-09-25\n", + " 10997.211277\n", " \n", " \n", - " 27\n", + " 29727\n", " 9\n", - " 2020-01-08\n", - " 10008.069487\n", + " 2022-09-24\n", + " 10996.789104\n", " \n", " \n", - " 28\n", + " 29728\n", " 9\n", - " 2020-01-09\n", - " 10008.436654\n", + " 2022-09-25\n", + " 10997.211277\n", " \n", " \n", - " 29\n", + " 29729\n", " 9\n", - " 2020-01-10\n", - " 10009.410037\n", + " 2022-09-26\n", + " 10997.953405\n", " \n", " \n", "\n", + "

29730 rows × 3 columns

\n", "" ], "text/plain": [ - " id date value\n", - "0 0 2020-01-08 1008.4361\n", - "1 0 2020-01-09 1008.769569\n", - "2 0 2020-01-10 1009.748942\n", - "3 1 2020-01-08 2008.084646\n", - "4 1 2020-01-09 2008.425281\n", - "5 1 2020-01-10 2009.385426\n", - "6 2 2020-01-08 3007.931338\n", - "7 2 2020-01-09 3008.294974\n", - "8 2 2020-01-10 3009.266806\n", - "9 3 2020-01-08 4007.738118\n", - "10 3 2020-01-09 4008.102793\n", - "11 3 2020-01-10 4009.074625\n", - "12 4 2020-01-08 5007.752016\n", - "13 4 2020-01-09 5008.118656\n", - "14 4 2020-01-10 5009.092952\n", - "15 5 2020-01-08 6007.75671\n", - "16 5 2020-01-09 6008.12335\n", - "17 5 2020-01-10 6009.096733\n", - "18 6 2020-01-08 7007.789288\n", - "19 6 2020-01-09 7008.155928\n", - "20 6 2020-01-10 7009.129311\n", - "21 7 2020-01-08 8007.861426\n", - "22 7 2020-01-09 8008.228066\n", - "23 7 2020-01-10 8009.201449\n", - "24 8 2020-01-08 9007.995032\n", - "25 8 2020-01-09 9008.359142\n", - "26 8 2020-01-10 9009.332526\n", - "27 9 2020-01-08 10008.069487\n", - "28 9 2020-01-09 10008.436654\n", - "29 9 2020-01-10 10009.410037" + " id date value\n", + "0 0 2020-01-08 1008.4361\n", + "1 0 2020-01-09 1008.769569\n", + "2 0 2020-01-10 1009.748942\n", + "3 0 2020-01-09 1008.769569\n", + "4 0 2020-01-10 1009.71512\n", + "... .. ... ...\n", + "29725 9 2022-09-24 10996.789104\n", + "29726 9 2022-09-25 10997.211277\n", + "29727 9 2022-09-24 10996.789104\n", + "29728 9 2022-09-25 10997.211277\n", + "29729 9 2022-09-26 10997.953405\n", + "\n", + "[29730 rows x 3 columns]" ] }, - "execution_count": 75, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } @@ -1799,7 +1670,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -1824,14 +1695,20 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "freq: Month; period: 1.0\n", + "freq: Month; period: 1.0\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "\n", " It seems that the data is not regular. Please, check the data and the frequency info. \n", " For multivariate regime it is critical to have regular data.\n", @@ -1857,7 +1734,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 27, "metadata": {}, "outputs": [ { @@ -1883,13 +1760,345 @@ "source": [ "Now it's all detected correctly." ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## [Subsampling from Dataset](#toc0_)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "strategy = RecursiveStrategy(horizon, history, trainer, pipeline)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0:\tlearn: 1.0234961\ttest: 0.8247936\tbest: 0.8247936 (0)\ttotal: 812us\tremaining: 812ms\n", + "500:\tlearn: 0.0222610\ttest: 0.5437334\tbest: 0.5437334 (500)\ttotal: 26.3ms\tremaining: 26.2ms\n", + "999:\tlearn: 0.0005205\ttest: 0.5384852\tbest: 0.5384852 (999)\ttotal: 62.1ms\tremaining: 0us\n", + "\n", + "bestTest = 0.5384851551\n", + "bestIteration = 999\n", + "\n", + "Fold 0. Score: 0.5384851550735784\n", + "0:\tlearn: 0.8131044\ttest: 1.0295534\tbest: 1.0295534 (0)\ttotal: 96us\tremaining: 96.2ms\n", + "500:\tlearn: 0.0180513\ttest: 0.6959095\tbest: 0.6959095 (500)\ttotal: 40.9ms\tremaining: 40.8ms\n", + "999:\tlearn: 0.0004345\ttest: 0.6932387\tbest: 0.6932387 (999)\ttotal: 78.1ms\tremaining: 0us\n", + "\n", + "bestTest = 0.6932386702\n", + "bestIteration = 999\n", + "\n", + "Fold 1. Score: 0.6932386702268397\n", + "Mean score: 0.6159\n", + "Std: 0.0774\n" + ] + } + ], + "source": [ + "fit_time, _ = strategy.fit(dataset, subsampling_rate=0.001, subsampling_seed=42)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "freq: Day; period: 1\n" + ] + } + ], + "source": [ + "forecast_time, current_pred = strategy.predict(dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
iddatevalue
002022-09-271639.249496
102022-09-281620.282791
202022-09-291555.419756
312022-09-272639.249496
412022-09-282620.282791
512022-09-292555.419756
622022-09-273639.836497
722022-09-283620.990255
822022-09-293573.857553
932022-09-274635.954675
1032022-09-284614.542752
1132022-09-294546.682434
1242022-09-275642.56791
1342022-09-285622.667564
1442022-09-295573.142299
1552022-09-276642.56791
1652022-09-286622.667564
1752022-09-296573.142299
1862022-09-277638.729381
1962022-09-287619.679409
2062022-09-297553.241637
2172022-09-278636.741438
2272022-09-288614.503698
2372022-09-298548.285629
2482022-09-279636.741438
2582022-09-289614.503698
2682022-09-299548.285629
2792022-09-2710636.741438
2892022-09-2810614.503698
2992022-09-2910548.285629
\n", + "
" + ], + "text/plain": [ + " id date value\n", + "0 0 2022-09-27 1639.249496\n", + "1 0 2022-09-28 1620.282791\n", + "2 0 2022-09-29 1555.419756\n", + "3 1 2022-09-27 2639.249496\n", + "4 1 2022-09-28 2620.282791\n", + "5 1 2022-09-29 2555.419756\n", + "6 2 2022-09-27 3639.836497\n", + "7 2 2022-09-28 3620.990255\n", + "8 2 2022-09-29 3573.857553\n", + "9 3 2022-09-27 4635.954675\n", + "10 3 2022-09-28 4614.542752\n", + "11 3 2022-09-29 4546.682434\n", + "12 4 2022-09-27 5642.56791\n", + "13 4 2022-09-28 5622.667564\n", + "14 4 2022-09-29 5573.142299\n", + "15 5 2022-09-27 6642.56791\n", + "16 5 2022-09-28 6622.667564\n", + "17 5 2022-09-29 6573.142299\n", + "18 6 2022-09-27 7638.729381\n", + "19 6 2022-09-28 7619.679409\n", + "20 6 2022-09-29 7553.241637\n", + "21 7 2022-09-27 8636.741438\n", + "22 7 2022-09-28 8614.503698\n", + "23 7 2022-09-29 8548.285629\n", + "24 8 2022-09-27 9636.741438\n", + "25 8 2022-09-28 9614.503698\n", + "26 8 2022-09-29 9548.285629\n", + "27 9 2022-09-27 10636.741438\n", + "28 9 2022-09-28 10614.503698\n", + "29 9 2022-09-29 10548.285629" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "current_pred" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python (tsururu-dev)", + "display_name": ".venv", "language": "python", - "name": "tsururu-dev" + "name": "python3" }, "language_info": { "codemirror_mode": { diff --git a/examples/Tutorial_2_Strategies.ipynb b/examples/Tutorial_2_Strategies.ipynb index 2b11bbf..ea3c790 100644 --- a/examples/Tutorial_2_Strategies.ipynb +++ b/examples/Tutorial_2_Strategies.ipynb @@ -72,7 +72,7 @@ "import pandas as pd\n", "\n", "from tsururu.dataset import IndexSlicer, Pipeline, TSDataset\n", - "from tsururu.model_training import DLTrainer, KFoldCrossValidator\n", + "#from tsururu.model_training import DLTrainer, KFoldCrossValidator\n", "from tsururu.model_training.validator import Validator\n", "from tsururu.models import CatBoost, Estimator\n", "from tsururu.strategies.base import Strategy\n", @@ -94,9 +94,22 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'DLTrainer' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mNameError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 67\u001b[39m\n\u001b[32m 63\u001b[39m y_pred = y_pred.reshape(pipeline.y_original_shape)\n\u001b[32m 65\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m y_pred\n\u001b[32m---> \u001b[39m\u001b[32m67\u001b[39m \u001b[38;5;28;43;01mclass\u001b[39;49;00m\u001b[38;5;250;43m \u001b[39;49m\u001b[34;43;01mRecursiveStrategy\u001b[39;49;00m\u001b[43m(\u001b[49m\u001b[43mStrategy\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 68\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mdef\u001b[39;49;00m\u001b[38;5;250;43m \u001b[39;49m\u001b[34;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\n\u001b[32m 69\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m 70\u001b[39m \u001b[43m \u001b[49m\u001b[43mhorizon\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mint\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m (...)\u001b[39m\u001b[32m 76\u001b[39m \u001b[43m \u001b[49m\u001b[43mreduced\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mbool\u001b[39;49m\u001b[43m \u001b[49m\u001b[43m=\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[32m 77\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 78\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[34;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mhorizon\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhistory\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpipeline\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstep\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 72\u001b[39m, in \u001b[36mRecursiveStrategy\u001b[39m\u001b[34m()\u001b[39m\n\u001b[32m 67\u001b[39m \u001b[38;5;28;01mclass\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mRecursiveStrategy\u001b[39;00m(Strategy):\n\u001b[32m 68\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__init__\u001b[39m(\n\u001b[32m 69\u001b[39m \u001b[38;5;28mself\u001b[39m,\n\u001b[32m 70\u001b[39m horizon: \u001b[38;5;28mint\u001b[39m,\n\u001b[32m 71\u001b[39m history: \u001b[38;5;28mint\u001b[39m,\n\u001b[32m---> \u001b[39m\u001b[32m72\u001b[39m trainer: Union[MLTrainer, \u001b[43mDLTrainer\u001b[49m],\n\u001b[32m 73\u001b[39m pipeline: Pipeline,\n\u001b[32m 74\u001b[39m step: \u001b[38;5;28mint\u001b[39m = \u001b[32m1\u001b[39m,\n\u001b[32m 75\u001b[39m model_horizon: \u001b[38;5;28mint\u001b[39m = \u001b[32m1\u001b[39m,\n\u001b[32m 76\u001b[39m reduced: \u001b[38;5;28mbool\u001b[39m = \u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[32m 77\u001b[39m ):\n\u001b[32m 78\u001b[39m \u001b[38;5;28msuper\u001b[39m().\u001b[34m__init__\u001b[39m(horizon, history, trainer, pipeline, step)\n\u001b[32m 79\u001b[39m \u001b[38;5;28mself\u001b[39m.model_horizon = model_horizon\n", + "\u001b[31mNameError\u001b[39m: name 'DLTrainer' is not defined" + ] + } + ], "source": [ "class MLTrainer:\n", " def __init__(\n", @@ -628,7 +641,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -856,7 +869,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -866,7 +879,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -902,7 +915,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1081,7 +1094,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1097,7 +1110,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1630,7 +1643,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2052,7 +2065,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2178,7 +2191,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2194,7 +2207,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2751,7 +2764,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -3010,7 +3023,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -3150,7 +3163,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -10200,7 +10213,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -10342,7 +10355,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -10358,7 +10371,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -10891,7 +10904,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -12546,7 +12559,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -12686,7 +12699,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -12703,7 +12716,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -13236,7 +13249,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -13658,7 +13671,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -13791,7 +13804,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -13808,7 +13821,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -14341,7 +14354,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -15174,7 +15187,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -15321,7 +15334,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -15337,7 +15350,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -16662,7 +16675,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -18317,7 +18330,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -18450,7 +18463,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -18467,7 +18480,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -19785,7 +19798,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -21446,7 +21459,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -21579,7 +21592,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -21596,7 +21609,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -22268,7 +22281,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -23101,7 +23114,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -23242,7 +23255,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -23258,7 +23271,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -23683,7 +23696,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -23997,7 +24010,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": null, "metadata": {}, "outputs": [ { diff --git a/pyproject.toml b/pyproject.toml index a779a28..ce761b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "tsururu" -version = "1.1.0" +version = "1.1.1" description = "Python tool for time series forecasting" readme = "README.md" requires-python = ">=3.9,<3.14" diff --git a/tsururu/models/torch_based/layers/convolution.py b/tsururu/models/torch_based/layers/convolution.py index df9c0ea..b8bd449 100644 --- a/tsururu/models/torch_based/layers/convolution.py +++ b/tsururu/models/torch_based/layers/convolution.py @@ -4,9 +4,10 @@ torch = OptionalImport("torch") nn = OptionalImport("torch.nn") +Module = OptionalImport("torch.nn.Module") -class Inception_Block_V1(nn.Module): +class Inception_Block_V1(Module): """Inception Block Version 1. Args: @@ -43,7 +44,7 @@ def _initialize_weights(self): if m.bias is not None: nn.init.constant_(m.bias, 0) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: "torch.Tensor") -> "torch.Tensor": """Forward pass of the Inception block. Args: diff --git a/tsururu/models/torch_based/layers/embedding.py b/tsururu/models/torch_based/layers/embedding.py index eb71db3..f8ac3ea 100644 --- a/tsururu/models/torch_based/layers/embedding.py +++ b/tsururu/models/torch_based/layers/embedding.py @@ -7,9 +7,10 @@ torch = OptionalImport("torch") nn = OptionalImport("torch.nn") +Module = OptionalImport("torch.nn.Module") -class TokenEmbedding(nn.Module): +class TokenEmbedding(Module): """Token embedding layer using 1D convolution. Args: @@ -33,7 +34,7 @@ def __init__(self, c_in: int, d_model: int): if isinstance(m, nn.Conv1d): nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="leaky_relu") - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: "torch.Tensor") -> "torch.Tensor": """Forward pass of the token embedding. Args: @@ -47,7 +48,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class PositionalEmbedding(nn.Module): +class PositionalEmbedding(Module): """Positional encoding using sine and cosine functions. Args: @@ -71,7 +72,7 @@ def __init__(self, d_model: int, max_len: int = 5000): pe = pe.unsqueeze(0) self.register_buffer("pe", pe) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: "torch.Tensor") -> "torch.Tensor": """Forward pass of the positional embedding. Args: @@ -84,7 +85,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.pe[:, : x.size(1)] -class FixedEmbedding(nn.Module): +class FixedEmbedding(Module): """Fixed embedding layer using precomputed sine and cosine values. Args: @@ -108,7 +109,7 @@ def __init__(self, c_in: int, d_model: int): self.emb = nn.Embedding(c_in, d_model) self.emb.weight = nn.Parameter(w, requires_grad=False) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: "torch.Tensor") -> "torch.Tensor": """Forward pass of the fixed embedding. Args: @@ -121,7 +122,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.emb(x).detach() -class TemporalEmbedding(nn.Module): +class TemporalEmbedding(Module): """Temporal embedding layer for time-related features. Args: @@ -148,7 +149,7 @@ def __init__(self, d_model: int, embed_type: str = "fixed", freq: str = "h"): self.day_embed = Embed(day_size, d_model) self.month_embed = Embed(month_size, d_model) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: "torch.Tensor") -> "torch.Tensor": """Forward pass of the temporal embedding. Args: @@ -168,7 +169,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return hour_x + weekday_x + day_x + month_x + minute_x -class TimeFeatureEmbedding(nn.Module): +class TimeFeatureEmbedding(Module): """Time feature embedding layer using linear transformation. Args: @@ -181,7 +182,7 @@ def __init__(self, d_model: int, d_inp: int = 1): super(TimeFeatureEmbedding, self).__init__() self.embed = nn.Linear(d_inp, d_model, bias=False) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: "torch.Tensor") -> "torch.Tensor": """Forward pass of the time feature embedding. Args: @@ -194,7 +195,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.embed(x) -class Embedding(nn.Module): +class Embedding(Module): """Data embedding layer combining token, positional, and temporal embeddings. Args: @@ -242,7 +243,7 @@ def __init__( self.dropout = nn.Dropout(p=dropout) - def forward(self, x: torch.Tensor, x_mark: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, x: "torch.Tensor", x_mark: Optional["torch.Tensor"] = None) -> "torch.Tensor": """Forward pass of the data embedding. Args: diff --git a/tsururu/models/torch_based/layers/patch_tst.py b/tsururu/models/torch_based/layers/patch_tst.py index f2c601e..af7b705 100644 --- a/tsururu/models/torch_based/layers/patch_tst.py +++ b/tsururu/models/torch_based/layers/patch_tst.py @@ -12,7 +12,6 @@ torch = OptionalImport("torch") nn = OptionalImport("torch.nn") F = OptionalImport("torch.nn.functional") -Tensor = OptionalImport("torch.Tensor") Module = OptionalImport("torch.nn.Module") @@ -82,7 +81,7 @@ def __init__( act: str = "gelu", key_padding_mask: Union[bool, str] = "auto", padding_var: Optional[int] = None, - attn_mask: Optional[Tensor] = None, + attn_mask: Optional["torch.Tensor"] = None, res_attention: bool = True, pre_norm: bool = False, store_attn: bool = False, @@ -166,7 +165,7 @@ def __init__( head_dropout=head_dropout, ) - def forward(self, z: Tensor) -> Tensor: + def forward(self, z: "torch.Tensor") -> "torch.Tensor": """Forward pass of the PatchTST backbone. Args: @@ -330,7 +329,7 @@ def __init__( store_attn: bool = False, key_padding_mask: Union[bool, str] = "auto", padding_var: Optional[int] = None, - attn_mask: Optional[Tensor] = None, + attn_mask: Optional["torch.Tensor"] = None, res_attention: bool = True, pre_norm: bool = False, pe: str = "zeros", @@ -383,7 +382,7 @@ def __init__( store_attn=store_attn, ) - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: "torch.Tensor") -> "torch.Tensor": """Forward pass of the TSTi encoder. Args: @@ -423,7 +422,7 @@ def forward(self, x: Tensor) -> Tensor: return x @staticmethod - def _reorganize_tensor(x: Tensor, n_vars: int) -> Tensor: + def _reorganize_tensor(x: "torch.Tensor", n_vars: int) -> "torch.Tensor": """Reorganizes input tensor to pair time series channels with exogenous features. Args: @@ -520,10 +519,10 @@ def __init__( def forward( self, - src: Tensor, - key_padding_mask: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - ) -> Tensor: + src: "torch.Tensor", + key_padding_mask: Optional["torch.Tensor"] = None, + attn_mask: Optional["torch.Tensor"] = None, + ) -> "torch.Tensor": """Forward pass of the TST encoder. Args: @@ -639,11 +638,11 @@ def __init__( def forward( self, - src: Tensor, - prev: Optional[Tensor] = None, - key_padding_mask: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + src: "torch.Tensor", + prev: Optional["torch.Tensor"] = None, + key_padding_mask: Optional["torch.Tensor"] = None, + attn_mask: Optional["torch.Tensor"] = None, + ) -> Union["torch.Tensor", Tuple["torch.Tensor", "torch.Tensor"]]: """Forward pass of the TST encoder layer. Args: @@ -747,13 +746,13 @@ def __init__( def forward( self, - Q: Tensor, - K: Optional[Tensor] = None, - V: Optional[Tensor] = None, - prev: Optional[Tensor] = None, - key_padding_mask: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + Q: "torch.Tensor", + K: Optional["torch.Tensor"] = None, + V: Optional["torch.Tensor"] = None, + prev: Optional["torch.Tensor"] = None, + key_padding_mask: Optional["torch.Tensor"] = None, + attn_mask: Optional["torch.Tensor"] = None, + ) -> Union["torch.Tensor", Tuple["torch.Tensor", "torch.Tensor"]]: """Forward pass of the multi-head attention layer. Args: @@ -845,13 +844,13 @@ def __init__( def forward( self, - q: Tensor, - k: Tensor, - v: Tensor, - prev: Optional[Tensor] = None, - key_padding_mask: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - ) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]: + q: "torch.Tensor", + k: "torch.Tensor", + v: "torch.Tensor", + prev: Optional["torch.Tensor"] = None, + key_padding_mask: Optional["torch.Tensor"] = None, + attn_mask: Optional["torch.Tensor"] = None, + ) -> Union["torch.Tensor", Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]]: """Forward pass of the scaled dot-product attention. Args: diff --git a/tsururu/models/torch_based/layers/positional_encoding.py b/tsururu/models/torch_based/layers/positional_encoding.py index 3155bac..8b3831b 100644 --- a/tsururu/models/torch_based/layers/positional_encoding.py +++ b/tsururu/models/torch_based/layers/positional_encoding.py @@ -78,9 +78,7 @@ def Coord2dPosEncoding( return cpe -def Coord1dPosEncoding( - q_len: int, exponential: bool = False, normalize: bool = True -) -> "torch.Tensor": +def Coord1dPosEncoding(q_len: int, exponential: bool = False, normalize: bool = True) -> "torch.Tensor": """Generate 1D coordinate positional encoding. Args: diff --git a/tsururu/models/torch_based/times_net.py b/tsururu/models/torch_based/times_net.py index ecc8f94..94d134b 100644 --- a/tsururu/models/torch_based/times_net.py +++ b/tsururu/models/torch_based/times_net.py @@ -14,9 +14,10 @@ nn = OptionalImport("torch.nn") F = OptionalImport("torch.nn.functional") rearrange = OptionalImport("einops.rearrange") +Module = OptionalImport("torch.nn.Module") -def FFT_for_Period(x: torch.Tensor, k: int = 2) -> Tuple[np.ndarray, torch.Tensor]: +def FFT_for_Period(x: "torch.Tensor", k: int = 2) -> Tuple[np.ndarray, "torch.Tensor"]: """Compute the FFT for the input tensor and find the top-k periods. Args: @@ -41,7 +42,7 @@ def FFT_for_Period(x: torch.Tensor, k: int = 2) -> Tuple[np.ndarray, torch.Tenso return period, abs(xf).mean(-1)[:, top_list] -class TimesBlock(nn.Module): +class TimesBlock(Module): """TimesBlock module for time series forecasting. Args: @@ -69,7 +70,7 @@ def __init__( Inception_Block_V1(d_ff, d_model, num_kernels=num_kernels), ) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: "torch.Tensor") -> "torch.Tensor": """Forward pass for the TimesBlock module. Args: diff --git a/tsururu/models/torch_based/utils.py b/tsururu/models/torch_based/utils.py index 735a899..a93294b 100644 --- a/tsururu/models/torch_based/utils.py +++ b/tsururu/models/torch_based/utils.py @@ -33,10 +33,10 @@ def adjust_features_groups(features_groups: Dict[str, int], num_lags: int) -> Di def slice_features( - X: torch.Tensor, + X: "torch.Tensor", feature_list: List[str], features_groups_corrected, -) -> torch.Tensor: +) -> "torch.Tensor": """Slice the input tensor X based on the corrected feature groups. Args: @@ -88,11 +88,11 @@ def slice_features( def slice_features_4d( - X: torch.Tensor, + X: "torch.Tensor", features_list: List[str], features_groups_corrected, num_series, -) -> torch.Tensor: +) -> "torch.Tensor": """Slice the input tensor X based on the corrected feature groups and reshape it to 4D.""" groups_order: List[str] = [ "series", diff --git a/tsururu/strategies/direct.py b/tsururu/strategies/direct.py index 56644d3..29c6ba6 100644 --- a/tsururu/strategies/direct.py +++ b/tsururu/strategies/direct.py @@ -1,6 +1,8 @@ from copy import deepcopy from typing import Union +import numpy as np + from tsururu.dataset.dataset import TSDataset from tsururu.dataset.pipeline import Pipeline from tsururu.dataset.slice import IndexSlicer @@ -56,6 +58,8 @@ def __init__( def fit( self, dataset: TSDataset, + subsampling_rate: float = 1.0, + subsampling_seed: int = 42, ) -> "DirectStrategy": """Fits the direct strategy to the given dataset. @@ -92,6 +96,15 @@ def fit( delta=dataset.delta, ) + if subsampling_rate < 1.0: + all_idx = np.arange(features_idx.shape[0]) + np.random.seed(subsampling_seed) + sampled_idx = np.random.choice( + all_idx, size=int(subsampling_rate * len(all_idx)), replace=False + ) + features_idx = features_idx[sampled_idx] + target_idx = target_idx[sampled_idx] + data = self.pipeline.create_data_dict_for_pipeline(dataset, features_idx, target_idx) data = self.pipeline.fit_transform(data, self.strategy_name) diff --git a/tsururu/strategies/recursive.py b/tsururu/strategies/recursive.py index 2cb9ee6..0184221 100644 --- a/tsururu/strategies/recursive.py +++ b/tsururu/strategies/recursive.py @@ -1,6 +1,7 @@ from copy import deepcopy from typing import Union +import numpy as np import pandas as pd from tsururu.dataset.dataset import TSDataset @@ -60,11 +61,15 @@ def __init__( def fit( self, dataset: TSDataset, + subsampling_rate: float = 1.0, + subsampling_seed: int = 42, ) -> "RecursiveStrategy": """Fits the recursive strategy to the given dataset. Args: dataset: The dataset to fit the strategy on. + subsampling_rate: The rate at which to subsample the data for training. + A value of 1.0 means no subsampling. Returns: self. @@ -88,6 +93,15 @@ def fit( delta=dataset.delta, ) + if subsampling_rate < 1.0: + all_idx = np.arange(features_idx.shape[0]) + np.random.seed(subsampling_seed) + sampled_idx = np.random.choice( + all_idx, size=int(subsampling_rate * len(all_idx)), replace=False + ) + features_idx = features_idx[sampled_idx] + target_idx = target_idx[sampled_idx] + data = self.pipeline.create_data_dict_for_pipeline(dataset, features_idx, target_idx) data = self.pipeline.fit_transform(data, self.strategy_name)