Skip to content

Commit bdc58db

Browse files
committed
Fix typehints, improve error raising for encoders
1 parent aa705da commit bdc58db

File tree

4 files changed

+15
-5
lines changed

4 files changed

+15
-5
lines changed

segmentation_models_pytorch/encoders/__init__.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,23 @@
3737

3838

3939
def get_encoder(name, in_channels=3, depth=5, weights=None):
40-
Encoder = encoders[name]["encoder"]
40+
41+
try:
42+
Encoder = encoders[name]["encoder"]
43+
except KeyError:
44+
raise KeyError("Wrong encoder name `{}`, supported encoders: {}".format(name, list(encoders.keys())))
45+
4146
params = encoders[name]["params"]
4247
params.update(depth=depth)
4348
encoder = Encoder(**params)
4449

4550
if weights is not None:
46-
settings = encoders[name]["pretrained_settings"][weights]
51+
try:
52+
settings = encoders[name]["pretrained_settings"][weights]
53+
except KeyError:
54+
raise KeyError("Wrong pretrained weights `{}` for encoder `{}`. Avaliable options are: {}".format(
55+
weights, name, list(encoders[name]["pretrained_settings"].keys()),
56+
))
4757
encoder.load_state_dict(model_zoo.load_url(settings["url"]))
4858

4959
encoder.set_in_channels(in_channels)

segmentation_models_pytorch/pan/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class PAN(SegmentationModel):
4444
def __init__(
4545
self,
4646
encoder_name: str = "resnet34",
47-
encoder_weights: str = "imagenet",
47+
encoder_weights: Optional[str] = "imagenet",
4848
encoder_dilation: bool = True,
4949
decoder_channels: int = 32,
5050
in_channels: int = 3,

segmentation_models_pytorch/unet/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __init__(
5151
self,
5252
encoder_name: str = "resnet34",
5353
encoder_depth: int = 5,
54-
encoder_weights: str = "imagenet",
54+
encoder_weights: Optional[str] = "imagenet",
5555
decoder_use_batchnorm: bool = True,
5656
decoder_channels: List[int] = (256, 128, 64, 32, 16),
5757
decoder_attention_type: Optional[str] = None,

segmentation_models_pytorch/unetplusplus/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __init__(
5151
self,
5252
encoder_name: str = "resnet34",
5353
encoder_depth: int = 5,
54-
encoder_weights: str = "imagenet",
54+
encoder_weights: Optional[str] = "imagenet",
5555
decoder_use_batchnorm: bool = True,
5656
decoder_channels: List[int] = (256, 128, 64, 32, 16),
5757
decoder_attention_type: Optional[str] = None,

0 commit comments

Comments
 (0)