|
18 | 18 | "#hide\n",
|
19 | 19 | "import torch\n",
|
20 | 20 | "import torch.nn as nn\n",
|
21 |
| - "import torch.nn.functional as F\n", |
22 |
| - "import numpy as np\n", |
23 |
| - "from collections import OrderedDict\n", |
24 | 21 | "\n",
|
25 | 22 | "from model_constructor.net import Net\n",
|
26 |
| - "from model_constructor.layers import ConvLayer, noop, act_fn, SimpleSelfAttention\n", |
27 |
| - "\n", |
28 |
| - "from nbdev.showdoc import show_doc\n", |
29 |
| - "from IPython.display import Markdown, display" |
30 |
| - ] |
31 |
| - }, |
32 |
| - { |
33 |
| - "cell_type": "code", |
34 |
| - "execution_count": null, |
35 |
| - "metadata": {}, |
36 |
| - "outputs": [], |
37 |
| - "source": [ |
38 |
| - "# hide\n", |
39 |
| - "def print_doc(func_name):\n", |
40 |
| - " doc = show_doc(func_name, title_level=4, disp=False)\n", |
41 |
| - " display(Markdown(doc))" |
| 23 | + "from model_constructor.layers import ConvLayer\n" |
42 | 24 | ]
|
43 | 25 | },
|
44 | 26 | {
|
|
57 | 39 | "from model_constructor.twist import ConvTwist"
|
58 | 40 | ]
|
59 | 41 | },
|
60 |
| - { |
61 |
| - "cell_type": "code", |
62 |
| - "execution_count": null, |
63 |
| - "metadata": {}, |
64 |
| - "outputs": [], |
65 |
| - "source": [ |
66 |
| - "# hide_input\n", |
67 |
| - "# print_doc(ConvTwist)" |
68 |
| - ] |
69 |
| - }, |
70 | 42 | {
|
71 | 43 | "cell_type": "code",
|
72 | 44 | "execution_count": null,
|
|
150 | 122 | }
|
151 | 123 | ],
|
152 | 124 | "source": [
|
153 |
| - "ConvTwist(64,64)" |
| 125 | + "ConvTwist(64, 64)" |
154 | 126 | ]
|
155 | 127 | },
|
156 | 128 | {
|
|
175 | 147 | "source": [
|
176 | 148 | "ConvTwist.twist = True\n",
|
177 | 149 | "ConvTwist.permute = False\n",
|
178 |
| - "ConvTwist(64,64)" |
| 150 | + "ConvTwist(64, 64)" |
179 | 151 | ]
|
180 | 152 | },
|
181 | 153 | {
|
|
219 | 191 | }
|
220 | 192 | ],
|
221 | 193 | "source": [
|
222 |
| - "ConvLayerTwist(64,64, stride=1)" |
| 194 | + "ConvLayerTwist(64, 64, stride=1)" |
223 | 195 | ]
|
224 | 196 | },
|
225 | 197 | {
|
|
558 | 530 | "from model_constructor.twist import NewResBlockTwist"
|
559 | 531 | ]
|
560 | 532 | },
|
561 |
| - { |
562 |
| - "cell_type": "code", |
563 |
| - "execution_count": null, |
564 |
| - "metadata": {}, |
565 |
| - "outputs": [], |
566 |
| - "source": [ |
567 |
| - "# hide_input\n", |
568 |
| - "# print_doc(NewResBlockTwist)" |
569 |
| - ] |
570 |
| - }, |
571 | 533 | {
|
572 | 534 | "cell_type": "code",
|
573 | 535 | "execution_count": null,
|
|
610 | 572 | ],
|
611 | 573 | "source": [
|
612 | 574 | "#collapse_output\n",
|
613 |
| - "bl = NewResBlockTwist(4,64,64,sa=True)\n", |
| 575 | + "bl = NewResBlockTwist(4, 64, 64, sa=True)\n", |
614 | 576 | "bl"
|
615 | 577 | ]
|
616 | 578 | },
|
|
676 | 638 | ],
|
677 | 639 | "source": [
|
678 | 640 | "#collapse_output\n",
|
679 |
| - "bl = NewResBlockTwist(4,64,64,stride=2)\n", |
| 641 | + "bl = NewResBlockTwist(4, 64, 64, stride=2)\n", |
680 | 642 | "bl"
|
681 | 643 | ]
|
682 | 644 | },
|
|
746 | 708 | ],
|
747 | 709 | "source": [
|
748 | 710 | "#collapse_output\n",
|
749 |
| - "bl = NewResBlockTwist(4,64,128,stride=2)\n", |
| 711 | + "bl = NewResBlockTwist(4, 64, 128, stride=2)\n", |
750 | 712 | "bl"
|
751 | 713 | ]
|
752 | 714 | },
|
|
806 | 768 | ],
|
807 | 769 | "source": [
|
808 | 770 | "#hide\n",
|
809 |
| - "bl = NewResBlockTwist(1,64,64,sa=True)\n", |
| 771 | + "bl = NewResBlockTwist(1, 64, 64, sa=True)\n", |
810 | 772 | "bl"
|
811 | 773 | ]
|
812 | 774 | },
|
|
876 | 838 | ],
|
877 | 839 | "source": [
|
878 | 840 | "#collapse_output\n",
|
879 |
| - "bl = NewResBlockTwist(4,64,128,stride=2,act_fn=nn.LeakyReLU(), bn_1st=False)\n", |
| 841 | + "bl = NewResBlockTwist(\n", |
| 842 | + " 4,\n", |
| 843 | + " 64,\n", |
| 844 | + " 128,\n", |
| 845 | + " stride=2,\n", |
| 846 | + " act_fn=nn.LeakyReLU(),\n", |
| 847 | + " bn_1st=False,\n", |
| 848 | + ")\n", |
880 | 849 | "bl"
|
881 | 850 | ]
|
882 | 851 | },
|
|
918 | 887 | "from model_constructor.twist import ResBlockTwist"
|
919 | 888 | ]
|
920 | 889 | },
|
921 |
| - { |
922 |
| - "cell_type": "code", |
923 |
| - "execution_count": null, |
924 |
| - "metadata": {}, |
925 |
| - "outputs": [], |
926 |
| - "source": [ |
927 |
| - "#hide_input\n", |
928 |
| - "# print_doc(ResBlockTwist)" |
929 |
| - ] |
930 |
| - }, |
931 | 890 | {
|
932 | 891 | "cell_type": "code",
|
933 | 892 | "execution_count": null,
|
|
970 | 929 | ],
|
971 | 930 | "source": [
|
972 | 931 | "#collapse_output\n",
|
973 |
| - "bl = ResBlockTwist(4,64,64,sa=True)\n", |
| 932 | + "bl = ResBlockTwist(4, 64, 64, sa=True)\n", |
974 | 933 | "bl"
|
975 | 934 | ]
|
976 | 935 | },
|
|
1036 | 995 | ],
|
1037 | 996 | "source": [
|
1038 | 997 | "#collapse_output\n",
|
1039 |
| - "bl = ResBlockTwist(4,64,64,stride=2)\n", |
| 998 | + "bl = ResBlockTwist(4, 64, 64, stride=2)\n", |
1040 | 999 | "bl"
|
1041 | 1000 | ]
|
1042 | 1001 | },
|
|
1106 | 1065 | ],
|
1107 | 1066 | "source": [
|
1108 | 1067 | "#collapse_output\n",
|
1109 |
| - "bl = ResBlockTwist(4,64,128,stride=2)\n", |
| 1068 | + "bl = ResBlockTwist(4, 64, 128, stride=2)\n", |
1110 | 1069 | "bl"
|
1111 | 1070 | ]
|
1112 | 1071 | },
|
|
1145 | 1104 | "metadata": {},
|
1146 | 1105 | "outputs": [],
|
1147 | 1106 | "source": [
|
1148 |
| - "model = Net(expansion=4, layers=[3,4,6,3])" |
| 1107 | + "model = Net(expansion=4, layers=[3, 4, 6, 3])" |
1149 | 1108 | ]
|
1150 | 1109 | },
|
1151 | 1110 | {
|
|
1569 | 1528 | "text": [
|
1570 | 1529 | "torch.Size([16, 64, 32, 32])\n"
|
1571 | 1530 | ]
|
1572 |
| - }, |
1573 |
| - { |
1574 |
| - "name": "stderr", |
1575 |
| - "output_type": "stream", |
1576 |
| - "text": [ |
1577 |
| - "/home/jzz/anaconda3/envs/mc_dev/lib/python3.9/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1623448255797/work/c10/core/TensorImpl.h:1156.)\n", |
1578 |
| - " return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)\n" |
1579 |
| - ] |
1580 | 1531 | }
|
1581 | 1532 | ],
|
1582 | 1533 | "source": [
|
|
0 commit comments