|
18 | 18 | "outputs": [],
|
19 | 19 | "source": [
|
20 | 20 | "#hide\n",
|
21 |
| - "import torch" |
| 21 | + "import torch\n", |
| 22 | + "from torch import nn" |
22 | 23 | ]
|
23 | 24 | },
|
24 | 25 | {
|
|
1326 | 1327 | "metadata": {},
|
1327 | 1328 | "outputs": [],
|
1328 | 1329 | "source": [
|
| 1330 | + "mc = ModelConstructor(name=\"YaResNet\")\n", |
1329 | 1331 | "mc.block = YaResBlock"
|
1330 | 1332 | ]
|
1331 | 1333 | },
|
| 1334 | + { |
| 1335 | + "attachments": {}, |
| 1336 | + "cell_type": "markdown", |
| 1337 | + "metadata": {}, |
| 1338 | + "source": [ |
| 1339 | + "Or in one line:" |
| 1340 | + ] |
| 1341 | + }, |
| 1342 | + { |
| 1343 | + "cell_type": "code", |
| 1344 | + "execution_count": null, |
| 1345 | + "metadata": {}, |
| 1346 | + "outputs": [], |
| 1347 | + "source": [ |
| 1348 | + "mc = ModelConstructor(name=\"YaResNet\", block=YaResBlock)" |
| 1349 | + ] |
| 1350 | + }, |
1332 | 1351 | {
|
1333 | 1352 | "cell_type": "markdown",
|
1334 | 1353 | "metadata": {},
|
|
1352 | 1371 | " block='YaResBlock'\n",
|
1353 | 1372 | " conv_layer='ConvBnAct'\n",
|
1354 | 1373 | " block_sizes=[64, 128, 256, 512]\n",
|
1355 |
| - " layers=[3, 4, 6, 3]\n", |
| 1374 | + " layers=[2, 2, 2, 2]\n", |
1356 | 1375 | " norm='BatchNorm2d'\n",
|
1357 |
| - " act_fn='Mish'\n", |
| 1376 | + " act_fn='ReLU'\n", |
1358 | 1377 | " pool=\"AvgPool2d {'kernel_size': 2, 'ceil_mode': True}\"\n",
|
1359 |
| - " expansion=4\n", |
| 1378 | + " expansion=1\n", |
1360 | 1379 | " groups=1\n",
|
1361 | 1380 | " bn_1st=True\n",
|
1362 | 1381 | " zero_bn=True\n",
|
1363 |
| - " stem_sizes=[32, 64, 64]\n", |
| 1382 | + " stem_sizes=[32, 32, 64]\n", |
1364 | 1383 | " stem_pool=\"MaxPool2d {'kernel_size': 3, 'stride': 2, 'padding': 1}\"\n",
|
1365 | 1384 | " init_cnn='init_cnn'\n",
|
1366 | 1385 | " make_stem='make_stem'\n",
|
|
1372 | 1391 | ],
|
1373 | 1392 | "source": [
|
1374 | 1393 | "#collapse_output\n",
|
1375 |
| - "mc.name = 'YaResNet'\n", |
1376 | 1394 | "mc.print_cfg()"
|
1377 | 1395 | ]
|
1378 | 1396 | },
|
|
1395 | 1413 | " (reduce): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
|
1396 | 1414 | " (convs): Sequential(\n",
|
1397 | 1415 | " (conv_0): ConvBnAct(\n",
|
1398 |
| - " (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", |
| 1416 | + " (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", |
1399 | 1417 | " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
1400 |
| - " (act_fn): Mish(inplace=True)\n", |
| 1418 | + " (act_fn): ReLU(inplace=True)\n", |
1401 | 1419 | " )\n",
|
1402 | 1420 | " (conv_1): ConvBnAct(\n",
|
1403 | 1421 | " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
1404 | 1422 | " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
1405 |
| - " (act_fn): Mish(inplace=True)\n", |
1406 |
| - " )\n", |
1407 |
| - " (conv_2): ConvBnAct(\n", |
1408 |
| - " (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", |
1409 |
| - " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", |
1410 | 1423 | " )\n",
|
1411 | 1424 | " )\n",
|
1412 | 1425 | " (id_conv): ConvBnAct(\n",
|
1413 |
| - " (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", |
1414 |
| - " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", |
| 1426 | + " (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", |
| 1427 | + " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", |
1415 | 1428 | " )\n",
|
1416 |
| - " (merge): Mish(inplace=True)\n", |
| 1429 | + " (merge): ReLU(inplace=True)\n", |
1417 | 1430 | ")"
|
1418 | 1431 | ]
|
1419 | 1432 | },
|
|
1426 | 1439 | "#collapse_output\n",
|
1427 | 1440 | "mc.body.l_1.bl_0"
|
1428 | 1441 | ]
|
| 1442 | + }, |
| 1443 | + { |
| 1444 | + "attachments": {}, |
| 1445 | + "cell_type": "markdown", |
| 1446 | + "metadata": {}, |
| 1447 | + "source": [ |
| 1448 | + "Let's create a `Resnet34`-like model constructor:" |
| 1449 | + ] |
| 1450 | + }, |
| 1451 | + { |
| 1452 | + "cell_type": "code", |
| 1453 | + "execution_count": null, |
| 1454 | + "metadata": {}, |
| 1455 | + "outputs": [], |
| 1456 | + "source": [ |
| 1457 | + "class YaResnet34(ModelConstructor):\n", |
| 1458 | + " block: type[nn.Module] = YaResBlock\n", |
| 1459 | + " layers: list[int] = [3, 4, 6, 3]" |
| 1460 | + ] |
| 1461 | + }, |
| 1462 | + { |
| 1463 | + "cell_type": "code", |
| 1464 | + "execution_count": null, |
| 1465 | + "metadata": {}, |
| 1466 | + "outputs": [ |
| 1467 | + { |
| 1468 | + "name": "stdout", |
| 1469 | + "output_type": "stream", |
| 1470 | + "text": [ |
| 1471 | + "YaResnet34(\n", |
| 1472 | + " in_chans=3\n", |
| 1473 | + " num_classes=1000\n", |
| 1474 | + " block='YaResBlock'\n", |
| 1475 | + " conv_layer='ConvBnAct'\n", |
| 1476 | + " block_sizes=[64, 128, 256, 512]\n", |
| 1477 | + " layers=[3, 4, 6, 3]\n", |
| 1478 | + " norm='BatchNorm2d'\n", |
| 1479 | + " act_fn='ReLU'\n", |
| 1480 | + " pool=\"AvgPool2d {'kernel_size': 2, 'ceil_mode': True}\"\n", |
| 1481 | + " expansion=1\n", |
| 1482 | + " groups=1\n", |
| 1483 | + " bn_1st=True\n", |
| 1484 | + " zero_bn=True\n", |
| 1485 | + " stem_sizes=[32, 32, 64]\n", |
| 1486 | + " stem_pool=\"MaxPool2d {'kernel_size': 3, 'stride': 2, 'padding': 1}\"\n", |
| 1487 | + " init_cnn='init_cnn'\n", |
| 1488 | + " make_stem='make_stem'\n", |
| 1489 | + " make_layer='make_layer'\n", |
| 1490 | + " make_body='make_body'\n", |
| 1491 | + " make_head='make_head')\n" |
| 1492 | + ] |
| 1493 | + } |
| 1494 | + ], |
| 1495 | + "source": [ |
| 1496 | + "mc = YaResnet34()\n", |
| 1497 | + "mc.print_cfg()" |
| 1498 | + ] |
| 1499 | + }, |
| 1500 | + { |
| 1501 | + "attachments": {}, |
| 1502 | + "cell_type": "markdown", |
| 1503 | + "metadata": {}, |
| 1504 | + "source": [ |
| 1505 | + "A `Resnet50`-like model can then be inherited from `YaResnet34`:" |
| 1506 | + ] |
| 1507 | + }, |
| 1508 | + { |
| 1509 | + "cell_type": "code", |
| 1510 | + "execution_count": null, |
| 1511 | + "metadata": {}, |
| 1512 | + "outputs": [], |
| 1513 | + "source": [ |
| 1514 | + "class YaResnet50(YaResnet34):\n", |
| 1515 | + " expansion = 4" |
| 1516 | + ] |
| 1517 | + }, |
| 1518 | + { |
| 1519 | + "cell_type": "code", |
| 1520 | + "execution_count": null, |
| 1521 | + "metadata": {}, |
| 1522 | + "outputs": [ |
| 1523 | + { |
| 1524 | + "data": { |
| 1525 | + "text/plain": [ |
| 1526 | + "YaResnet50\n", |
| 1527 | + " in_chans: 3, num_classes: 1000\n", |
| 1528 | + " expansion: 4, groups: 1, dw: False, div_groups: None\n", |
| 1529 | + " act_fn: ReLU, sa: False, se: False\n", |
| 1530 | + " stem sizes: [32, 32, 64], stride on 0\n", |
| 1531 | + " body sizes [64, 128, 256, 512]\n", |
| 1532 | + " layers: [3, 4, 6, 3]" |
| 1533 | + ] |
| 1534 | + }, |
| 1535 | + "execution_count": null, |
| 1536 | + "metadata": {}, |
| 1537 | + "output_type": "execute_result" |
| 1538 | + } |
| 1539 | + ], |
| 1540 | + "source": [ |
| 1541 | + "mc = YaResnet50()\n", |
| 1542 | + "mc" |
| 1543 | + ] |
1429 | 1544 | }
|
1430 | 1545 | ],
|
1431 | 1546 | "metadata": {
|
|
0 commit comments