forked from qubvel-org/segmentation_models.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_fpn.py
29 lines (24 loc) · 1000 Bytes
/
test_fpn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import segmentation_models_pytorch as smp
from tests.models import base
class TestFpnModel(base.BaseModelTester):
    """FPN-specific tests layered on top of ``base.BaseModelTester``."""

    # Model identifier passed to ``smp.create_model`` by the test below.
    test_model_type = "fpn"
    # NOTE(review): presumably path patterns the base tester uses to decide
    # which source changes are relevant to this suite -- confirm in
    # ``tests.models.base``.
    files_for_diff = [r"decoders/fpn/", r"base/"]

    def test_interpolation(self):
        """Check that ``decoder_interpolation`` is stored on decoder p2-p4.

        A model is created for each interpolation mode in turn (bilinear
        first, then bicubic), and the ``interpolation_mode`` attribute of
        the decoder's ``p2``, ``p3`` and ``p4`` members must equal the
        requested mode.
        """
        for mode in ("bilinear", "bicubic"):
            model = smp.create_model(
                self.test_model_type,
                self.test_encoder_name,
                decoder_interpolation=mode,
            )
            for level in ("p2", "p3", "p4"):
                assert getattr(model.decoder, level).interpolation_mode == mode