diff --git a/segmentation_models_pytorch/decoders/mask2former/__init__.py b/segmentation_models_pytorch/decoders/mask2former/__init__.py new file mode 100644 index 00000000..ef576370 --- /dev/null +++ b/segmentation_models_pytorch/decoders/mask2former/__init__.py @@ -0,0 +1,3 @@ +from .model import Mask2Former + +__all__ = [Mask2Former] diff --git a/segmentation_models_pytorch/decoders/mask2former/decoder.py b/segmentation_models_pytorch/decoders/mask2former/decoder.py new file mode 100644 index 00000000..914d4013 --- /dev/null +++ b/segmentation_models_pytorch/decoders/mask2former/decoder.py @@ -0,0 +1,11 @@ +import torch.nn as nn + + +class Mask2FormerPixelModule(nn.Module): + def __init__(self): + super().__init__() + + +class Mask2FormerTransformerModule(nn.Module): + def __init__(self): + super().__init__() diff --git a/segmentation_models_pytorch/decoders/mask2former/model.py b/segmentation_models_pytorch/decoders/mask2former/model.py new file mode 100644 index 00000000..895917b3 --- /dev/null +++ b/segmentation_models_pytorch/decoders/mask2former/model.py @@ -0,0 +1,15 @@ +from segmentation_models_pytorch.base import ( + SegmentationModel, +) + +from .decoder import Mask2FormerPixelModule, Mask2FormerTransformerModule + + +class Mask2Former(SegmentationModel): + def __init__(self): + super().__init__() + pixel_module = Mask2FormerPixelModule() + transformer_module = Mask2FormerTransformerModule() + + def forward(self, x): + return x