forked from qubvel-org/segmentation_models.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path__init__.py
93 lines (84 loc) · 2.07 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from . import datasets
from . import encoders
from . import decoders
from . import losses
from . import metrics
from .decoders.unet import Unet
from .decoders.unetplusplus import UnetPlusPlus
from .decoders.manet import MAnet
from .decoders.linknet import Linknet
from .decoders.fpn import FPN
from .decoders.pspnet import PSPNet
from .decoders.deeplabv3 import DeepLabV3, DeepLabV3Plus
from .decoders.pan import PAN
from .decoders.upernet import UPerNet
from .decoders.segformer import Segformer
from .decoders.dpt import DPT
from .base.hub_mixin import from_pretrained
from .__version__ import __version__
# some private imports for create_model function
from typing import Optional as _Optional
import torch as _torch
# Every segmentation architecture class that ``create_model`` can build.
_MODEL_ARCHITECTURES = [
    Unet,
    UnetPlusPlus,
    MAnet,
    Linknet,
    FPN,
    PSPNet,
    DeepLabV3,
    DeepLabV3Plus,
    PAN,
    UPerNet,
    Segformer,
    DPT,
]

# Lookup table from lowercased class name (e.g. "unetplusplus") to class.
# ``create_model`` lowercases its ``arch`` argument before indexing this,
# which makes architecture selection case-insensitive.
MODEL_ARCHITECTURES_MAPPING = {
    model_cls.__name__.lower(): model_cls for model_cls in _MODEL_ARCHITECTURES
}
def create_model(
    arch: str,
    encoder_name: str = "resnet34",
    encoder_weights: _Optional[str] = "imagenet",
    in_channels: int = 3,
    classes: int = 1,
    **kwargs,
) -> _torch.nn.Module:
    """Models entrypoint: create any registered architecture by name.

    Looks up the model class in ``MODEL_ARCHITECTURES_MAPPING`` using the
    lowercased ``arch`` and instantiates it. Callers therefore do not need
    to import the architecture class themselves.

    Args:
        arch: Architecture name such as ``"unet"`` or ``"FPN"``. It is
            lowercased before lookup, so matching is case-insensitive.
        encoder_name: Encoder identifier, forwarded unchanged to the model.
        encoder_weights: Encoder weights identifier, or ``None``; forwarded
            unchanged to the model.
        in_channels: Number of input channels, forwarded unchanged.
        classes: Number of output classes, forwarded unchanged.
        **kwargs: Additional architecture-specific keyword arguments,
            forwarded unchanged to the model constructor.

    Returns:
        An instance of the selected model class.

    Raises:
        KeyError: If ``arch`` (lowercased) is not a key of
            ``MODEL_ARCHITECTURES_MAPPING``. The message lists the
            available options.
    """
    try:
        model_class = MODEL_ARCHITECTURES_MAPPING[arch.lower()]
    except KeyError:
        # ``from None`` drops the internal lookup KeyError from the
        # traceback, so users see only this explanatory error.
        raise KeyError(
            "Wrong architecture type `{}`. Available options are: {}".format(
                arch, list(MODEL_ARCHITECTURES_MAPPING.keys())
            )
        ) from None
    return model_class(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=classes,
        **kwargs,
    )
# Public names exported by ``from <package> import *``.
__all__ = [
    # Submodules
    "datasets",
    "encoders",
    "decoders",
    "losses",
    "metrics",
    # Model architectures
    "Unet",
    "UnetPlusPlus",
    "MAnet",
    "Linknet",
    "FPN",
    "PSPNet",
    "DeepLabV3",
    "DeepLabV3Plus",
    "PAN",
    "UPerNet",
    "Segformer",
    "DPT",
    # Helpers and package metadata
    "from_pretrained",
    "create_model",
    "__version__",
]