[WIP] Adding DPT #1079

Status: Open · wants to merge 11 commits into base: main

Conversation

@vedantdalimkar commented Mar 2, 2025

Tried to address issue #1073.

  1. Added a TimmViTEncoder class in the encoders package to support ViT models as encoders.
  2. Added the DPT model architecture.

I have one concern: right now the model only works for images with the same resolution as the one the chosen ViT encoder was originally trained on. Is this behaviour okay, or should I change the functionality so that dynamic image resolutions are supported? (An illustrative timm-based option is sketched at the end of this description.)

I have tested various ViT encoders at different encoder depths and the model seems to run correctly.
@qubvel Please let me know if you feel I should make any changes.
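
As an editorial aside on the dynamic-resolution question above (a minimal sketch, not part of this PR): recent timm releases (>= 1.0) can build ViT encoders with `dynamic_img_size=True`, which resizes the positional embeddings to the actual input size, so a 384-trained backbone can also run at other resolutions. The model name and block indices below are only examples.

```python
import timm
import torch

# dynamic_img_size=True lets the ViT interpolate its positional embeddings
# to whatever spatial size the input has.
encoder = timm.create_model(
    "vit_base_patch16_384", pretrained=False, dynamic_img_size=True
)

# forward_intermediates returns the selected intermediate feature maps
# (NCHW by default), which is what a DPT-style decoder consumes.
features = encoder.forward_intermediates(
    torch.rand(1, 3, 512, 512),
    indices=[2, 5, 8, 11],
    intermediates_only=True,
)
print([f.shape for f in features])  # each is (1, 768, 32, 32) for a 512x512 input
```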

@qubvel (Collaborator) commented Mar 3, 2025

Hi @vedantdalimkar. Thanks a lot for the PR 🤗 I am a bit too busy this week, so I will review it early next week. Thanks for your patience. Meanwhile, can you please set up tests for the encoder and the DPT model? Please see how other models are tested.

@vedantdalimkar (Author):

> Hi @vedantdalimkar. Thanks a lot for the PR 🤗 I am a bit too busy this week, so I will review it early next week. Thanks for your patience. Meanwhile, can you please set up tests for the encoder and the DPT model? Please see how other models are tested.

Hi @qubvel. Sure, I will set up the relevant tests.

@vedantdalimkar (Author) commented Mar 8, 2025

Hi @qubvel, I did some refactoring. This commit should pass the majority of the tests now. I had a few questions:

  1. The DPT model doesn't seem to be torch compilable or scriptable since it has graph breaks. Should I skip those tests?
  2. Currently, the DPT model is not on the HF Hub, which is required for the test_preserve_forward_output test. Should I skip this test as well?

Some issues I faced with the default development environment for smp:

  1. The TimmViTEncoder class that I have added requires the latest version of timm, so I have decorated all test functions with a requires_timm_greater_or_equal decorator (similar to requires_torch_greater_or_equal; an illustrative sketch follows after this list). If possible, please bump the timm version in the requirements so that these tests can run, or let me know if you want me to change this behaviour.
  2. The GELU activation function used in the decoder requires torch >= 2.0, but the requirements currently specify a lower torch version. Can the requirements be updated?
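
As an illustration of the version gate mentioned in point 1 above (an editorial sketch; the PR's actual requires_timm_greater_or_equal helper may be implemented differently):

```python
import unittest

import timm
from packaging import version


def requires_timm_greater_or_equal(min_version: str):
    """Skip the decorated test unless the installed timm is new enough."""
    installed = version.parse(timm.__version__)
    required = version.parse(min_version)
    return unittest.skipUnless(
        installed >= required,
        f"test requires timm>={min_version}, found timm=={timm.__version__}",
    )
```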

@vedantdalimkar (Author):

Hi @qubvel. Just a gentle reminder, I think this PR is ready for review.

@qubvel (Collaborator) commented Mar 14, 2025

Hey @vedantdalimkar, thanks for the ping, and sorry for the delay! Will try to do a pass today or on Monday. Thank you for your patience 🤗

@qubvel self-requested a review on March 18, 2025 at 09:50

codecov bot commented Mar 18, 2025

Codecov Report

Attention: Patch coverage is 92.50000% with 18 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| segmentation_models_pytorch/encoders/timm_vit.py | 82.66% | 13 Missing ⚠️ |
| ...egmentation_models_pytorch/decoders/dpt/decoder.py | 95.90% | 5 Missing ⚠️ |

| Files with missing lines | Coverage Δ |
|---|---|
| segmentation_models_pytorch/__init__.py | 93.10% <100.00%> (+0.24%) ⬆️ |
| ...gmentation_models_pytorch/decoders/dpt/__init__.py | 100.00% <100.00%> (ø) |
| segmentation_models_pytorch/decoders/dpt/model.py | 100.00% <100.00%> (ø) |
| segmentation_models_pytorch/encoders/__init__.py | 76.00% <100.00%> (+1.00%) ⬆️ |
| ...egmentation_models_pytorch/decoders/dpt/decoder.py | 95.90% <95.90%> (ø) |
| segmentation_models_pytorch/encoders/timm_vit.py | 82.66% <82.66%> (ø) |

... and 2 files with indirect coverage changes


Comment on lines 209 to 228
```python
def load_state_dict(self, state_dict, **kwargs):
    # for compatibility of weights for
    # timm-ported encoders with TimmUniversalEncoder
    patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"]

    is_deprecated_encoder = any(
        self.name.startswith(pattern) for pattern in patterns
    )

    if is_deprecated_encoder:
        keys = list(state_dict.keys())
        for key in keys:
            new_key = key
            if not key.startswith("model."):
                new_key = "model." + key
            if "gernet" in self.name:
                new_key = new_key.replace(".stages.", ".stages_")
            state_dict[new_key] = state_dict.pop(key)

    return super().load_state_dict(state_dict, **kwargs)
```

qubvel (Collaborator):

This is not needed for this kind of encoder

Comment on lines 231 to 246
```python
def _merge_kwargs_no_duplicates(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]:
    """
    Merge two dictionaries, ensuring no duplicate keys exist.

    Args:
        a (dict): Base dictionary.
        b (dict): Additional parameters to merge.

    Returns:
        dict: A merged dictionary.
    """
    duplicates = a.keys() & b.keys()
    if duplicates:
        raise ValueError(f"'{duplicates}' already specified internally")

    return a | b
```

qubvel (Collaborator):

This can be imported if needed, no need to duplicate code

@qubvel (Collaborator) left a comment

Hi @vedantdalimkar! Great work, and sorry once again for the delay 🤗 Thank you for working on the model; it looks super good for a first iteration. I think we need just a few steps to get it merged 🚀

Here is what's missing:

  1. A conversion script and an integration test to ensure the model produces the same logits as the original one. In addition, it would be great to create a notebook with an inference example, similar to the Segformer one (see the examples/ folder).
  2. We need some docs and a table to clarify which encoders are supported, because it's a bit different from other models, which support convolutional and transformer encoders.
  3. Refine the tests a bit (see the comments below).

Other than that it looks clean, thank you for your hard work 🙌

```python
    intermediates_only=True,
)

cls_tokens = [None] * len(self.out_indices)
```

qubvel (Collaborator):

Why do we need to provide CLS tokens?

@vedantdalimkar (Author) commented Mar 18, 2025:

The DPT architecture requires CLS tokens. The motivation is given in the DPT paper:

> ... the readout token doesn't serve a clear purpose for the task of dense prediction, but could potentially still be useful to capture and distribute global information.

The default setting of the DPT architecture broadcasts the CLS token, concatenates it with the patch features along the feature dimension, and then projects the result back to the original feature dimension.

This is why the CLS token is needed.
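
A minimal, self-contained illustration of that "broadcast, concatenate, project" readout (editorial sketch; class and argument names are hypothetical, not the PR's actual code):

```python
import torch
import torch.nn as nn


class ProjectReadout(nn.Module):
    def __init__(self, embed_dim: int):
        super().__init__()
        # concatenating the CLS token doubles the feature size, so project 2C -> C
        self.project = nn.Sequential(nn.Linear(2 * embed_dim, embed_dim), nn.GELU())

    def forward(self, patches: torch.Tensor, cls_token: torch.Tensor) -> torch.Tensor:
        # patches: (B, N, C), cls_token: (B, 1, C)
        cls_token = cls_token.expand_as(patches)          # broadcast over all patches
        fused = torch.cat([patches, cls_token], dim=-1)   # (B, N, 2C)
        return self.project(fused)                        # back to (B, N, C)


readout = ProjectReadout(embed_dim=768)
out = readout(torch.rand(2, 576, 768), torch.rand(2, 1, 768))
print(out.shape)  # torch.Size([2, 576, 768])
```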

qubvel (Collaborator):

Got it, thanks!

Comment on lines 32 to 51
```python
@property
def model_class(self):
    return smp.MODEL_ARCHITECTURES_MAPPING[self.model_type]

@property
def decoder_channels(self):
    signature = inspect.signature(self.model_class)
    # check if decoder_channels is in the signature
    if "decoder_channels" in signature.parameters:
        return signature.parameters["decoder_channels"].default
    return None

@lru_cache
def _get_sample(self, batch_size=None, num_channels=None, height=None, width=None):
    batch_size = batch_size or self.default_batch_size
    num_channels = num_channels or self.default_num_channels
    height = height or self.default_height
    width = width or self.default_width
    return torch.rand(batch_size, num_channels, height, width)
```

qubvel (Collaborator):

No need to override this, the same as in base class

Comment on lines 60 to 79
```python
def test_forward_backward(self):
    sample = self._get_sample().to(default_device)

    model = self.get_default_model()

    # check default in_channels=3
    output = model(sample)

    # check default output number of classes = 1
    expected_number_of_classes = 1
    result_number_of_classes = output.shape[1]
    self.assertEqual(
        result_number_of_classes,
        expected_number_of_classes,
        f"Default output number of classes should be {expected_number_of_classes}, but got {result_number_of_classes}",
    )

    # check backward pass
    output.mean().backward()
```

qubvel (Collaborator):

No need to override

Suggested change (remove the override):

```diff
-def test_forward_backward(self):
-    sample = self._get_sample().to(default_device)
-    model = self.get_default_model()
-    # check default in_channels=3
-    output = model(sample)
-    # check default output number of classes = 1
-    expected_number_of_classes = 1
-    result_number_of_classes = output.shape[1]
-    self.assertEqual(
-        result_number_of_classes,
-        expected_number_of_classes,
-        f"Default output number of classes should be {expected_number_of_classes}, but got {result_number_of_classes}",
-    )
-    # check backward pass
-    output.mean().backward()
```

Comment on lines 233 to 276
```python
@pytest.mark.torch_export
def test_torch_export(self):
    if not check_run_test_on_diff_or_main(self.files_for_diff):
        self.skipTest("No diff and not on `main`.")

    sample = self._get_sample().to(default_device)
    model = self.get_default_model()
    model.eval()

    exported_model = torch.export.export(
        model,
        args=(sample,),
        strict=True,
    )

    with torch.inference_mode():
        eager_output = model(sample)
        exported_output = exported_model.module().forward(sample)

    self.assertEqual(eager_output.shape, exported_output.shape)
    torch.testing.assert_close(eager_output, exported_output)

@pytest.mark.torch_script
def test_torch_script(self):
    if not check_run_test_on_diff_or_main(self.files_for_diff):
        self.skipTest("No diff and not on `main`.")

    sample = self._get_sample().to(default_device)
    model = self.get_default_model()
    model.eval()

    if not model._is_torch_scriptable:
        with self.assertRaises(RuntimeError):
            scripted_model = torch.jit.script(model)
        return

    scripted_model = torch.jit.script(model)

    with torch.inference_mode():
        scripted_output = scripted_model(sample)
        eager_output = model(sample)

    self.assertEqual(scripted_output.shape, eager_output.shape)
    torch.testing.assert_close(scripted_output, eager_output)
```

qubvel (Collaborator):

No need to override, please leave only modified tests here, other tests will be fetched from base class

Comment on lines 54 to 56
```python
model = smp.create_model(
    self.model_type, self.test_encoder_name, output_stride=None
)
```

qubvel (Collaborator):

Can we avoid passing output_stride=None, to keep it consistent with other models? It could be handled at the create_model level or the encoder level, for example.

@vedantdalimkar (Author) commented Mar 18, 2025

Hey @qubvel, I had a concern regarding the following point:

> 1. A conversion script and an integration test to ensure the model produces the same logits as the original one. In addition, it would be great to create a notebook with an inference example, similar to the Segformer one (see the examples/ folder).

The DPT architecture has a different segmentation head compared to the standard SMP segmentation head. I am guessing I would need to keep the head the same as in the original architecture to ensure the logits match the original model. Is it fine if the smp.DPT model has a different segmentation head?

If so, should I include the new segmentation head class under decoders/dpt/head.py?

Edit: In case the hub repo for the model doesn't have any output/input tensors, what should be done to test consistency between the hub model output and the smp model output?

@qubvel (Collaborator) commented Mar 18, 2025

> Is it fine if the smp.DPT model has a different segmentation head?
> If so, should I include the new segmentation head class under decoders/dpt/head.py?

Yes, it's fine for it to have its own head; you can put it in decoder.py.

> In case the hub repo for the model doesn't have any output/input tensors, what should be done to test consistency between the hub model output and the smp model output?

You can run the HF model on an all-ones tensor to get the expected output, then use it in the test. Please see the Segformer test:

```python
def test_load_pretrained(self):
    hub_checkpoint = "smp-hub/segformer-b0-512x512-ade-160k"
    model = smp.from_pretrained(hub_checkpoint)
    model = model.eval().to(default_device)
    sample = torch.ones([1, 3, 512, 512]).to(default_device)
    with torch.inference_mode():
        output = model(sample)
    self.assertEqual(output.shape, (1, 150, 512, 512))
    expected_logits_slice = torch.tensor(
        [-4.4172, -4.4723, -4.5273, -4.5824, -4.6375, -4.7157]
    )
    resulted_logits_slice = output[0, 0, 256, :6].cpu()
    is_equal = torch.allclose(
        expected_logits_slice, resulted_logits_slice, atol=1e-2
    )
    max_diff = torch.max(torch.abs(expected_logits_slice - resulted_logits_slice))
    self.assertTrue(
        is_equal,
        f"Expected logits slice and resulted logits slice are not equal.\n"
        f"Max diff: {max_diff}\n"
        f"Expected: {expected_logits_slice}\n"
        f"Resulted: {resulted_logits_slice}\n",
    )
```

@vedantdalimkar (Author):

I am still confused regarding a couple of things:

> 1. A conversion script and an integration test to ensure the model produces the same logits as the original one. In addition, it would be great to create a notebook with an inference example, similar to the Segformer one (see the examples/ folder).

i) By "original" model, do you mean the model uploaded to HF by the paper authors (who wrote the DPT paper)?
ii) Right now, smp-hub doesn't have any model for DPT; is it supposed to be uploaded by us?

@qubvel (Collaborator) commented Mar 18, 2025

> i) By "original" model, do you mean the model uploaded to HF by the paper authors (who wrote the DPT paper)?

Any of those; they should be equivalent.

> ii) Right now, smp-hub doesn't have any model for DPT; is it supposed to be uploaded by us?

Yes, you can upload it to your own HF account, and I will transfer it to smp-hub as soon as the PR is ready 🤗

@vedantdalimkar (Author) commented Mar 22, 2025

> Hi @vedantdalimkar! Great work, and sorry once again for the delay 🤗 Thank you for working on the model; it looks super good for a first iteration. I think we need just a few steps to get it merged 🚀
>
> Here is what's missing:
>
> 1. A conversion script and an integration test to ensure the model produces the same logits as the original one. In addition, it would be great to create a notebook with an inference example, similar to the Segformer one (see the examples/ folder).
> 2. We need some docs and a table to clarify which encoders are supported, because it's a bit different from other models, which support convolutional and transformer encoders.
> 3. Refine the tests a bit (see the comments below).
>
> Other than that it looks clean, thank you for your hard work 🙌

Hey @qubvel, I have worked on the points mentioned above; hopefully it should pass the required tests now (except the torch.compile test).

  1. For generating the logits of the original model, I used the segmentation model from the original DPT repository, with the dpt_large-ade20k-b12dca68.pt checkpoint for weight conversion.

Sample code which I used for logit generation from the original model:

```python
import torch
from dpt.models import DPTSegmentationModel

input = torch.ones((1, 3, 384, 384))
model = DPTSegmentationModel(
    num_classes=150,
    path=r"C:\Users\vedan\Downloads\dpt_large-ade20k-b12dca68.pt",
    backbone="vitl16_384",
)

model.eval()
with torch.no_grad():
    output = model(input)
```

I have also provided the model conversion script in the misc folder, and I have uploaded the converted model here: https://huggingface.co/vedantdalimkar/DPT

  2. I have added some logic to the generate_timm_tables.py script so that the table also contains the ViT-like encoders in timm which can be used with DPT. I faced quite a lot of issues while generating the smp encoders table (generate_table.py) since I am working on Windows. If possible, can you please generate the smp encoders table from your end?

  3. I have made the suggested changes to the tests.

@qubvel (Collaborator) left a comment

Thanks for addressing the comments! I made another pass 🤗

qubvel (Collaborator):

We should remove this file

qubvel (Collaborator):

Super nice table! Absolutely love it, thanks for working on it. Can you add it to docs/?

.gitignore Outdated
Comment on lines 114 to 116
```
# model weight folder
dpt_large-ade20k-b12dca68
```

qubvel (Collaborator):

please use .git/info/exclude to ignore it locally

Suggested change (remove these lines):

```diff
-# model weight folder
-dpt_large-ade20k-b12dca68
```

```
@@ -28,6 +28,15 @@ def slow_test(test_case):
    return unittest.skipUnless(RUN_SLOW, "test is slow")(test_case)


def requires_timm_greater_or_equal(version: str):
```

qubvel (Collaborator):

Nice!

Comment on lines 37 to 43
```python
if not model._is_torch_compilable:
    with self.assertRaises(RuntimeError):
        torch.compiler.reset()
        compiled_model = torch.compile(
            model, fullgraph=True, dynamic=True, backend="eager"
        )
    return
```

qubvel (Collaborator) commented Mar 26, 2025:

Let's add this to the base class and not override here

Comment on lines 104 to 109
```python
smp_model.save_pretrained(model_name)

repo_id = HF_HUB_PATH
api = huggingface_hub.HfApi()
api.create_repo(repo_id=repo_id, repo_type="model")
api.upload_folder(folder_path=model_name, repo_id=repo_id)
```

qubvel (Collaborator):

We can use smp_model.push_to_hub(...) instead
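
For illustration, the simplified flow could look like the sketch below (editorial; the repo id is a placeholder, and it assumes the smp model exposes the Hugging Face Hub mixin's push_to_hub):

```python
# push_to_hub saves the model files and uploads them to the Hub repo,
# creating the repo if it does not already exist.
smp_model.push_to_hub("your-username/DPT")
```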

Comment on lines +96 to +102

```python
except Exception:
    continue
try:
    if valid_vit_encoder_for_dpt(name):
        supported_models[name] = dict(supported_only_for_dpt=True)
except Exception:
    continue
```

qubvel (Collaborator):

Should we check only if we got an exception here?
Would it be better to make two independent checks?

vedantdalimkar (Author):

If you check the behaviour of the functions check_features_and_reduction and valid_vit_encoder_for_dpt, their outputs are mutually exclusive. To be more detailed:

  1. check_features_and_reduction returns True only when the reduction scales of a model are equal to [2, 4, 8, 16, 32], whereas
  2. valid_vit_encoder_for_dpt returns False if the encoder has multiple reduction scales.

In short, a model which satisfies the conditions of check_features_and_reduction will never satisfy the conditions of valid_vit_encoder_for_dpt, and vice versa.
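
Purely as an editorial illustration of that mutual exclusivity (simplified stand-ins that take a list of reductions instead of an encoder name, unlike the script's real helpers):

```python
def check_features_and_reduction(reductions: list[int]) -> bool:
    # True only for the standard CNN-style reduction ladder
    return reductions == [2, 4, 8, 16, 32]


def valid_vit_encoder_for_dpt(reductions: list[int]) -> bool:
    # False whenever the encoder has more than one reduction scale
    return len(set(reductions)) == 1


for reductions in ([2, 4, 8, 16, 32], [16, 16, 16, 16]):
    print(reductions, check_features_and_reduction(reductions), valid_vit_encoder_for_dpt(reductions))
# No single input makes both checks return True.
```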

qubvel (Collaborator):

Ok, I suppose this code should be updated as well, because as far as I remember [4, 8, 16, 32] and [1, 2, 4, 8, 16, 32] reductions are also supported

vedantdalimkar (Author):

> Ok, I suppose this code should be updated as well, because as far as I remember [4, 8, 16, 32] and [1, 2, 4, 8, 16, 32] reductions are also supported

Should I update this as well, or will you do it from your end?

Comment on lines +20 to +22
```python
_is_torch_scriptable = False
_is_torch_exportable = True
_is_torch_compilable = False
```

qubvel (Collaborator):

It's a bit strange that the timm encoder is not compilable, but ok.

@vedantdalimkar (Author) commented Mar 26, 2025:

Timm encoders are indeed compilable, as you said, but I am using some conditional logic in the forward method of the encoder which introduces graph breaks. I guess that may be why the encoder is not compilable.

qubvel (Collaborator):

Ok, no problem. I can take a look later, once the model is merged, to see whether the graph break can be avoided.

```python
    **self.default_encoder_kwargs,
)

@requires_timm_greater_or_equal("1.0.15")
```

qubvel (Collaborator):

Can you put a comment on why we should have 1.0.15?

vedantdalimkar (Author):

All the features of timm's VisionTransformer class that are required for SMP were introduced after the 1.0.x releases, but it was tedious to pinpoint the exact version, so I just required the latest timm version to be safe.

qubvel (Collaborator):

Same here, in case the test is the same: let's use inheritance and not duplicate the code. To add a decorator, the following can be used:

```python
@requires_timm_greater_or_equal("1.0.15")
def test_in_channels(self):
    return super().test_in_channels()
```

vedantdalimkar (Author):

Actually, the test_in_channels test is different from the base class; it has an additional line of code (line 91 in test_timm_vit_encoders.py).

qubvel (Collaborator):

Ok, that's an example, just to make sure the code is not duplicated 🤗

vedantdalimkar (Author):

Oh, sorry for the misunderstanding.

@vedantdalimkar requested a review from qubvel on March 27, 2025 at 15:52
@vedantdalimkar (Author):

Hi @qubvel, bumping as a gentle reminder for another review.
