Skip to content

Commit be40e83

Browse files
committed
Support for timm_3d models
1 parent c4ce774 commit be40e83

File tree

4 files changed

+31
-3
lines changed

4 files changed

+31
-3
lines changed

README.md

+14
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,20 @@ Note: In the official github repo the s0 variant has additional num_conv_branche
198198
</div>
199199
</details>
200200

201+
### Timm 3D encoders
202+
203+
We now support encoders from [timm_3d](https://github.com/ZFTurbo/timm_3d) library. Full list available [here](https://github.com/ZFTurbo/timm_3d/blob/main/docs/models_list.md). To use them add `tu-` before encoder name.
204+
Example:
205+
206+
```python
207+
encoder_name = 'tu-maxvit_base_tf_224.in21k'
208+
model = smp.Unet(
209+
encoder_name=encoder_name,
210+
encoder_weights=None,
211+
in_channels=3,
212+
classes=1,
213+
)
214+
```
201215

202216
## Notes for 3D version
203217

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ torchvision>=0.5.0
33
pretrainedmodels==0.7.4
44
efficientnet-pytorch==0.7.1
55
timm==0.9.7
6+
timm_3d==1.0.1
67
tqdm
78
pillow
89
six

segmentation_models_pytorch_3d/encoders/timm_universal.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import timm
1+
import timm_3d
22
import torch.nn as nn
33

44

@@ -17,7 +17,7 @@ def __init__(self, name, pretrained=True, in_channels=3, depth=5, output_stride=
1717
if output_stride == 32:
1818
kwargs.pop("output_stride")
1919

20-
self.model = timm.create_model(name, **kwargs)
20+
self.model = timm_3d.create_model(name, **kwargs)
2121

2222
self._in_channels = in_channels
2323
self._out_channels = [

test.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -244,4 +244,17 @@
244244
strides=((1, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2)),
245245
)
246246
o = model(torch.randn(2, 3, 32, 64, 128))
247-
print(f'Result shape: {o.shape}')
247+
print(f'Result shape: {o.shape}')
248+
249+
if 1:
250+
encoder_name = 'tu-maxvit_base_tf_224.in21k'
251+
print('Test Timm 3d model: {}...'.format(encoder_name))
252+
print('Go for {}'.format(encoder_name))
253+
model = smp.Unet(
254+
encoder_name=encoder_name,
255+
encoder_weights=None,
256+
in_channels=3,
257+
classes=1,
258+
)
259+
o = model(torch.randn(2, 3, 128, 64, 64))
260+
print(f'Result shape: {o.shape}')

0 commit comments

Comments
 (0)