Skip to content

Commit 92de2d9

Browse files
authored
Merge pull request #85 from ayasyrev/ayasyrev/issue78
docs yaresnet
2 parents 85aa2a9 + ba63db7 commit 92de2d9

File tree

6 files changed

+346
-51
lines changed

6 files changed

+346
-51
lines changed

Nbs/04_YaResNet.ipynb

+35-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,16 @@
1313
"cell_type": "code",
1414
"execution_count": null,
1515
"metadata": {},
16-
"outputs": [],
16+
"outputs": [
17+
{
18+
"name": "stdout",
19+
"output_type": "stream",
20+
"text": [
21+
"The autoreload extension is already loaded. To reload it, use:\n",
22+
" %reload_ext autoreload\n"
23+
]
24+
}
25+
],
1726
"source": [
1827
"#hide\n",
1928
"%load_ext autoreload\n",
@@ -1001,13 +1010,27 @@
10011010
"# YaResnet34, YaResnet50"
10021011
]
10031012
},
1013+
{
1014+
"attachments": {},
1015+
"cell_type": "markdown",
1016+
"metadata": {},
1017+
"source": [
1018+
"We has `Resnet34` and `Resnet50` like models predefined, we can impoer it as: \n",
1019+
"`from model_constructor.yaresnet import YaResNet34, YaResNet50` \n",
1020+
"But lets create it."
1021+
]
1022+
},
10041023
{
10051024
"cell_type": "code",
10061025
"execution_count": null,
10071026
"metadata": {},
10081027
"outputs": [],
10091028
"source": [
1010-
"from model_constructor.yaresnet import YaResNet34, YaResNet50"
1029+
"class YaResNet34(ModelConstructor):\n",
1030+
" block: type[nn.Module] = YaResBlock\n",
1031+
" expansion: int = 1\n",
1032+
" layers: list[int] = [3, 4, 6, 3]\n",
1033+
" act_fn: type[nn.Module] = nn.Mish"
10111034
]
10121035
},
10131036
{
@@ -1037,6 +1060,16 @@
10371060
"yaresnet34"
10381061
]
10391062
},
1063+
{
1064+
"cell_type": "code",
1065+
"execution_count": null,
1066+
"metadata": {},
1067+
"outputs": [],
1068+
"source": [
1069+
"class YaResNet50(YaResNet34):\n",
1070+
" expansion: int = 4"
1071+
]
1072+
},
10401073
{
10411074
"cell_type": "code",
10421075
"execution_count": null,

Nbs/index.ipynb Nbs/README.ipynb

+131-16
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
"outputs": [],
1919
"source": [
2020
"#hide\n",
21-
"import torch"
21+
"import torch\n",
22+
"from torch import nn"
2223
]
2324
},
2425
{
@@ -1326,9 +1327,27 @@
13261327
"metadata": {},
13271328
"outputs": [],
13281329
"source": [
1330+
"mc = ModelConstructor(name=\"YaResNet\")\n",
13291331
"mc.block = YaResBlock"
13301332
]
13311333
},
1334+
{
1335+
"attachments": {},
1336+
"cell_type": "markdown",
1337+
"metadata": {},
1338+
"source": [
1339+
"Or in one line:"
1340+
]
1341+
},
1342+
{
1343+
"cell_type": "code",
1344+
"execution_count": null,
1345+
"metadata": {},
1346+
"outputs": [],
1347+
"source": [
1348+
"mc = ModelConstructor(name=\"YaResNet\", block=YaResBlock)"
1349+
]
1350+
},
13321351
{
13331352
"cell_type": "markdown",
13341353
"metadata": {},
@@ -1352,15 +1371,15 @@
13521371
" block='YaResBlock'\n",
13531372
" conv_layer='ConvBnAct'\n",
13541373
" block_sizes=[64, 128, 256, 512]\n",
1355-
" layers=[3, 4, 6, 3]\n",
1374+
" layers=[2, 2, 2, 2]\n",
13561375
" norm='BatchNorm2d'\n",
1357-
" act_fn='Mish'\n",
1376+
" act_fn='ReLU'\n",
13581377
" pool=\"AvgPool2d {'kernel_size': 2, 'ceil_mode': True}\"\n",
1359-
" expansion=4\n",
1378+
" expansion=1\n",
13601379
" groups=1\n",
13611380
" bn_1st=True\n",
13621381
" zero_bn=True\n",
1363-
" stem_sizes=[32, 64, 64]\n",
1382+
" stem_sizes=[32, 32, 64]\n",
13641383
" stem_pool=\"MaxPool2d {'kernel_size': 3, 'stride': 2, 'padding': 1}\"\n",
13651384
" init_cnn='init_cnn'\n",
13661385
" make_stem='make_stem'\n",
@@ -1372,7 +1391,6 @@
13721391
],
13731392
"source": [
13741393
"#collapse_output\n",
1375-
"mc.name = 'YaResNet'\n",
13761394
"mc.print_cfg()"
13771395
]
13781396
},
@@ -1395,25 +1413,20 @@
13951413
" (reduce): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
13961414
" (convs): Sequential(\n",
13971415
" (conv_0): ConvBnAct(\n",
1398-
" (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
1416+
" (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
13991417
" (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
1400-
" (act_fn): Mish(inplace=True)\n",
1418+
" (act_fn): ReLU(inplace=True)\n",
14011419
" )\n",
14021420
" (conv_1): ConvBnAct(\n",
14031421
" (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
14041422
" (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
1405-
" (act_fn): Mish(inplace=True)\n",
1406-
" )\n",
1407-
" (conv_2): ConvBnAct(\n",
1408-
" (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
1409-
" (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
14101423
" )\n",
14111424
" )\n",
14121425
" (id_conv): ConvBnAct(\n",
1413-
" (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
1414-
" (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
1426+
" (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
1427+
" (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
14151428
" )\n",
1416-
" (merge): Mish(inplace=True)\n",
1429+
" (merge): ReLU(inplace=True)\n",
14171430
")"
14181431
]
14191432
},
@@ -1426,6 +1439,108 @@
14261439
"#collapse_output\n",
14271440
"mc.body.l_1.bl_0"
14281441
]
1442+
},
1443+
{
1444+
"attachments": {},
1445+
"cell_type": "markdown",
1446+
"metadata": {},
1447+
"source": [
1448+
"Lets create `Resnet34` like model constructor:"
1449+
]
1450+
},
1451+
{
1452+
"cell_type": "code",
1453+
"execution_count": null,
1454+
"metadata": {},
1455+
"outputs": [],
1456+
"source": [
1457+
"class YaResnet34(ModelConstructor):\n",
1458+
" block: type[nn.Module] = YaResBlock\n",
1459+
" layers: list[int] = [3, 4, 6, 3]"
1460+
]
1461+
},
1462+
{
1463+
"cell_type": "code",
1464+
"execution_count": null,
1465+
"metadata": {},
1466+
"outputs": [
1467+
{
1468+
"name": "stdout",
1469+
"output_type": "stream",
1470+
"text": [
1471+
"YaResnet34(\n",
1472+
" in_chans=3\n",
1473+
" num_classes=1000\n",
1474+
" block='YaResBlock'\n",
1475+
" conv_layer='ConvBnAct'\n",
1476+
" block_sizes=[64, 128, 256, 512]\n",
1477+
" layers=[3, 4, 6, 3]\n",
1478+
" norm='BatchNorm2d'\n",
1479+
" act_fn='ReLU'\n",
1480+
" pool=\"AvgPool2d {'kernel_size': 2, 'ceil_mode': True}\"\n",
1481+
" expansion=1\n",
1482+
" groups=1\n",
1483+
" bn_1st=True\n",
1484+
" zero_bn=True\n",
1485+
" stem_sizes=[32, 32, 64]\n",
1486+
" stem_pool=\"MaxPool2d {'kernel_size': 3, 'stride': 2, 'padding': 1}\"\n",
1487+
" init_cnn='init_cnn'\n",
1488+
" make_stem='make_stem'\n",
1489+
" make_layer='make_layer'\n",
1490+
" make_body='make_body'\n",
1491+
" make_head='make_head')\n"
1492+
]
1493+
}
1494+
],
1495+
"source": [
1496+
"mc = YaResnet34()\n",
1497+
"mc.print_cfg()"
1498+
]
1499+
},
1500+
{
1501+
"attachments": {},
1502+
"cell_type": "markdown",
1503+
"metadata": {},
1504+
"source": [
1505+
"And `Resnet50` like model can be inherited from `YaResnet34`:"
1506+
]
1507+
},
1508+
{
1509+
"cell_type": "code",
1510+
"execution_count": null,
1511+
"metadata": {},
1512+
"outputs": [],
1513+
"source": [
1514+
"class YaResnet50(YaResnet34):\n",
1515+
" expansion = 4"
1516+
]
1517+
},
1518+
{
1519+
"cell_type": "code",
1520+
"execution_count": null,
1521+
"metadata": {},
1522+
"outputs": [
1523+
{
1524+
"data": {
1525+
"text/plain": [
1526+
"YaResnet50\n",
1527+
" in_chans: 3, num_classes: 1000\n",
1528+
" expansion: 4, groups: 1, dw: False, div_groups: None\n",
1529+
" act_fn: ReLU, sa: False, se: False\n",
1530+
" stem sizes: [32, 32, 64], stride on 0\n",
1531+
" body sizes [64, 128, 256, 512]\n",
1532+
" layers: [3, 4, 6, 3]"
1533+
]
1534+
},
1535+
"execution_count": null,
1536+
"metadata": {},
1537+
"output_type": "execute_result"
1538+
}
1539+
],
1540+
"source": [
1541+
"mc = YaResnet50()\n",
1542+
"mc"
1543+
]
14291544
}
14301545
],
14311546
"metadata": {

0 commit comments

Comments
 (0)