[WIP] Adding DPT #1079
base: main
Conversation
Hi @vedantdalimkar. Thanks a lot for the PR 🤗 I am a bit too busy this week, so will review it early next week. Thanks for your patience. Meanwhile, can you please set up tests for the encoder and the DPT model? Please see how other models are tested.
Hi @qubvel. Sure, I will set up the relevant tests.
Hi @qubvel, did some refactoring. This commit should pass the majority of the tests now. Had a few questions -
Some issues I faced with the default environment for smp development.
Hi @qubvel. Just a gentle reminder, I think this PR is ready for review.
Hey @vedantdalimkar, thanks for the ping, and sorry for the delay! Will try to do a pass today or on Monday. Thank you for your patience 🤗
Codecov Report. Attention: Patch coverage is ...
... and 2 files with indirect coverage changes.
def load_state_dict(self, state_dict, **kwargs):
    # for compatibility of weights for
    # timm-ported encoders with TimmUniversalEncoder
    patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"]

    is_deprecated_encoder = any(
        self.name.startswith(pattern) for pattern in patterns
    )

    if is_deprecated_encoder:
        keys = list(state_dict.keys())
        for key in keys:
            new_key = key
            if not key.startswith("model."):
                new_key = "model." + key
            if "gernet" in self.name:
                new_key = new_key.replace(".stages.", ".stages_")
            state_dict[new_key] = state_dict.pop(key)

    return super().load_state_dict(state_dict, **kwargs)
This is not needed for this kind of encoder
def _merge_kwargs_no_duplicates(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]:
    """
    Merge two dictionaries, ensuring no duplicate keys exist.

    Args:
        a (dict): Base dictionary.
        b (dict): Additional parameters to merge.

    Returns:
        dict: A merged dictionary.
    """
    duplicates = a.keys() & b.keys()
    if duplicates:
        raise ValueError(f"'{duplicates}' already specified internally")

    return a | b
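For reference, a quick usage sketch of this helper (the key names below are made up for illustration):

# Hypothetical usage of _merge_kwargs_no_duplicates.
base = {"depth": 5, "use_norm": "batchnorm"}       # internally specified defaults
extra = {"interpolation": "bilinear"}               # caller-provided kwargs
merged = _merge_kwargs_no_duplicates(base, extra)
# merged == {"depth": 5, "use_norm": "batchnorm", "interpolation": "bilinear"}

# _merge_kwargs_no_duplicates(base, {"depth": 3})  # raises ValueError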
This can be imported if needed, no need to duplicate code
Hi @vedantdalimkar! Great work and sorry once again for the delay 🤗 Thank you for working on the model, looks super good for the first iteration, I think we need just a few steps to make it merged 🚀
Here is what's missing:
- Conversion script and integration test to ensure the model produces the same logits as the original one. In addition, it would be great to create a notebook with an inference example, similar to the segformer one (see the examples/ folder).
- We need some docs and a table to clarify which encoders are supported, because it's a bit different from other models, which support convolutional and transformer encoders.
- Refine tests a bit (see comment below).
Other than that it looks clean, thank you for your hard work 🙌
    intermediates_only=True,
)

cls_tokens = [None] * len(self.out_indices)
Why do we need to provide CLS tokens?
The DPT architecture requires CLS tokens. The motivation for this is given in the DPT paper:
... the readout token doesn't serve a clear purpose for the task of dense prediction, but could potentially still be useful to capture and distribute global information.
The default setting of the DPT architecture broadcasts the CLS token and concatenates it with the patch features along the feature dimension before projecting the patch features back to the original feature dimension.
This is why the CLS token is needed.
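For illustration only (not the PR's actual code), the "project" readout described above could be sketched roughly like this, where embed_dim and the tensor shapes are assumptions:

import torch
import torch.nn as nn

embed_dim = 768  # assumed ViT feature dimension

# "Project" readout: broadcast the CLS token, concatenate it with the patch
# tokens along the feature dimension, then project back to embed_dim.
readout_project = nn.Sequential(nn.Linear(2 * embed_dim, embed_dim), nn.GELU())

patch_tokens = torch.randn(1, 576, embed_dim)  # (batch, num_patches, features)
cls_token = torch.randn(1, 1, embed_dim)       # (batch, 1, features)

cls_broadcast = cls_token.expand_as(patch_tokens)
fused = readout_project(torch.cat([patch_tokens, cls_broadcast], dim=-1))
print(fused.shape)  # torch.Size([1, 576, 768])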
Got it, thanks!
tests/models/test_dpt.py
Outdated
@property
def model_class(self):
    return smp.MODEL_ARCHITECTURES_MAPPING[self.model_type]

@property
def decoder_channels(self):
    signature = inspect.signature(self.model_class)
    # check if decoder_channels is in the signature
    if "decoder_channels" in signature.parameters:
        return signature.parameters["decoder_channels"].default
    return None

@lru_cache
def _get_sample(self, batch_size=None, num_channels=None, height=None, width=None):
    batch_size = batch_size or self.default_batch_size
    num_channels = num_channels or self.default_num_channels
    height = height or self.default_height
    width = width or self.default_width
    return torch.rand(batch_size, num_channels, height, width)
No need to override this, the same as in base class
tests/models/test_dpt.py
Outdated
def test_forward_backward(self):
    sample = self._get_sample().to(default_device)

    model = self.get_default_model()

    # check default in_channels=3
    output = model(sample)

    # check default output number of classes = 1
    expected_number_of_classes = 1
    result_number_of_classes = output.shape[1]
    self.assertEqual(
        result_number_of_classes,
        expected_number_of_classes,
        f"Default output number of classes should be {expected_number_of_classes}, but got {result_number_of_classes}",
    )

    # check backward pass
    output.mean().backward()
No need to override
tests/models/test_dpt.py
Outdated
@pytest.mark.torch_export
def test_torch_export(self):
    if not check_run_test_on_diff_or_main(self.files_for_diff):
        self.skipTest("No diff and not on `main`.")

    sample = self._get_sample().to(default_device)
    model = self.get_default_model()
    model.eval()

    exported_model = torch.export.export(
        model,
        args=(sample,),
        strict=True,
    )

    with torch.inference_mode():
        eager_output = model(sample)
        exported_output = exported_model.module().forward(sample)

    self.assertEqual(eager_output.shape, exported_output.shape)
    torch.testing.assert_close(eager_output, exported_output)

@pytest.mark.torch_script
def test_torch_script(self):
    if not check_run_test_on_diff_or_main(self.files_for_diff):
        self.skipTest("No diff and not on `main`.")

    sample = self._get_sample().to(default_device)
    model = self.get_default_model()
    model.eval()

    if not model._is_torch_scriptable:
        with self.assertRaises(RuntimeError):
            scripted_model = torch.jit.script(model)
        return

    scripted_model = torch.jit.script(model)

    with torch.inference_mode():
        scripted_output = scripted_model(sample)
        eager_output = model(sample)

    self.assertEqual(scripted_output.shape, eager_output.shape)
    torch.testing.assert_close(scripted_output, eager_output)
No need to override, please leave only modified tests here, other tests will be fetched from base class
tests/models/test_dpt.py
Outdated
model = smp.create_model(
    self.model_type, self.test_encoder_name, output_stride=None
)
Can we avoid passing output_stride=None to make it consistent with other models? It can be handled on the create_model level or encoder level, for example.
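As a rough sketch of the create_model-level option (the helper name below is hypothetical, not existing smp API):

import inspect

def filter_encoder_kwargs(encoder_cls, kwargs: dict) -> dict:
    # Forward only the kwargs the encoder constructor actually accepts, so
    # callers never have to pass output_stride=None explicitly for ViT-style
    # encoders that do not support it.
    accepted = inspect.signature(encoder_cls.__init__).parameters
    return {key: value for key, value in kwargs.items() if key in accepted}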
Hey @qubvel, I had a concern regarding the following point.
The DPT architecture has a different segmentation head compared to the standard SMP segmentation head. I am guessing I would need to keep the head the same as the original architecture in order to ensure that the logits match the original model. Is it fine if the smp.DPT model has a different segmentation head? If yes, I should include the new segmentation head class under decoders/dpt/head.py, right?
Edit: In the case that the hub repo for the model doesn't have any output/input tensors, what should be done to test consistency between the hub model output and the smp model output?
Yes, it's fine to have its own head, you can put it in
You can run the HF model on an all-ones tensor to get the expected output, then use it in the test; please see the Segformer test: segmentation_models.pytorch/tests/models/test_segformer.py, lines 16 to 43 in 4aa36c6
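For reference, a minimal sketch of that test pattern, assuming smp.from_pretrained is available for hub checkpoints (the repo id and expected values below are placeholders, not real numbers):

import torch
import segmentation_models_pytorch as smp

def test_dpt_logits_match_original():
    # Load the converted checkpoint and compare against logits precomputed
    # from the original model on the same deterministic all-ones input.
    model = smp.from_pretrained("<converted-dpt-repo-id>").eval()  # placeholder repo id
    sample = torch.ones(1, 3, 384, 384)

    with torch.inference_mode():
        output = model(sample)

    expected_slice = torch.tensor([0.0, 0.0, 0.0])  # placeholder values
    torch.testing.assert_close(output[0, 0, 0, :3], expected_slice, rtol=1e-4, atol=1e-4)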
I am still confused regarding a couple of things -
i) By "original" model, do you mean the model uploaded on HF by the paper authors (who wrote the DPT paper)?
Any of those, should be equivalent
Yes, you can upload it to your own HF account, I will transfer it to
Hey @qubvel, I have worked on the points mentioned above; it should hopefully pass the required tests now (except the torch.compile test).
Sample code which I have used for logit generation from the original model - Code
I have also provided the model conversion script in the misc folder. I have uploaded the model after weight conversion here - https://huggingface.co/vedantdalimkar/DPT
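For illustration, generating reference logits from an HF-hosted DPT could look roughly like this (the Intel/dpt-large-ade checkpoint is my assumption, not necessarily the one used for this conversion):

import torch
from transformers import DPTForSemanticSegmentation

# Run the reference model on a deterministic all-ones input and record a
# slice of its logits to hard-code as the expected output in the test.
reference = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade").eval()
sample = torch.ones(1, 3, 384, 384)

with torch.inference_mode():
    logits = reference(pixel_values=sample).logits

print(logits.shape)         # shape depends on the checkpoint's head
print(logits[0, 0, 0, :5])  # values to reuse as the expected slice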
Thanks for addressing the comments! I made another pass 🤗
encoders_table.md
Outdated
We should remove this file
timm_encoders.txt
Outdated
Super nice table! Absolutely love it, thanks for working on it. Can you add it into the docs/?
.gitignore
Outdated
# model weight folder
dpt_large-ade20k-b12dca68
Please use .git/info/exclude to ignore it locally.
@@ -28,6 +28,15 @@ def slow_test(test_case):
    return unittest.skipUnless(RUN_SLOW, "test is slow")(test_case)


def requires_timm_greater_or_equal(version: str):
Nice!
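For context, a plausible sketch of such a guard, assuming packaging is available (the actual PR implementation may differ):

import unittest
from packaging.version import Version
import timm

def requires_timm_greater_or_equal(version: str):
    # Skip the decorated test when the installed timm is older than `version`.
    return unittest.skipUnless(
        Version(timm.__version__) >= Version(version),
        f"test requires timm>={version}, found {timm.__version__}",
    )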
tests/models/test_dpt.py
Outdated
if not model._is_torch_compilable:
    with self.assertRaises(RuntimeError):
        torch.compiler.reset()
        compiled_model = torch.compile(
            model, fullgraph=True, dynamic=True, backend="eager"
        )
    return
Let's add this to the base class and not override here
smp_model.save_pretrained(model_name)

repo_id = HF_HUB_PATH
api = huggingface_hub.HfApi()
api.create_repo(repo_id=repo_id, repo_type="model")
api.upload_folder(folder_path=model_name, repo_id=repo_id)
We can use smp_model.push_to_hub(...) instead
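For example, something along these lines (sketch; assumes the model's hub mixin exposes push_to_hub as in other smp models):

# Create the repo (if needed) and upload the saved weights/config in one call,
# replacing the manual HfApi create_repo/upload_folder steps above.
smp_model.push_to_hub(HF_HUB_PATH)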
except Exception:
    continue
try:
    if valid_vit_encoder_for_dpt(name):
        supported_models[name] = dict(supported_only_for_dpt=True)
except Exception:
    continue
Should we check only if we got an exception here?
Would it be better to make two independent checks?
If you check the behaviour of the functions check_features_and_reduction and valid_vit_encoder_for_dpt, their outputs are mutually exclusive. To be more detailed: check_features_and_reduction returns true only when the reduction scales of a model are equal to [2, 4, 8, 16, 32], whereas valid_vit_encoder_for_dpt returns false if the encoder has multiple reduction scales.
In short, a model which satisfies the conditions specified by check_features_and_reduction will never satisfy the conditions set by valid_vit_encoder_for_dpt, and vice versa.
Ok, I suppose this code should be updated as well, because as far as I remember [4, 8, 16, 32] and [1, 2, 4, 8, 16, 32] reductions are also supported
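A hedged sketch of what the updated check could look like (the reduction sets are taken from the comment above, and the helper name is hypothetical):

SUPPORTED_REDUCTIONS = (
    [2, 4, 8, 16, 32],
    [4, 8, 16, 32],
    [1, 2, 4, 8, 16, 32],
)

def has_supported_reductions(reductions: list[int]) -> bool:
    # Accept any of the reduction-scale sets mentioned above,
    # rather than only [2, 4, 8, 16, 32].
    return reductions in SUPPORTED_REDUCTIONS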
Ok, I suppose this code should be updated as well, because as far as I remember [4, 8, 16, 32] and [1, 2, 4, 8, 16, 32] reductions are also supported
Should I update this as well or will you do it from your end?
_is_torch_scriptable = False
_is_torch_exportable = True
_is_torch_compilable = False
That's a bit strange that timm encoder is not compilable, but ok..
Timm encoders are indeed compilable, as you said, but I am using some conditional logic in the forward method of the encoder which introduces graph breaks. I guess that may be the reason the encoder is not compilable.
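To illustrate the kind of pattern that breaks fullgraph compilation (a generic example, not the PR's actual forward code):

import torch

def forward_with_branch(x: torch.Tensor) -> torch.Tensor:
    # Python control flow that depends on tensor values introduces a graph
    # break, so compiling with fullgraph=True raises instead of compiling.
    if x.sum() > 0:
        return x * 2
    return x - 1

compiled = torch.compile(forward_with_branch, fullgraph=True, backend="eager")
# compiled(torch.ones(2, 3))  # raises due to the data-dependent branch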
Ok, no problem, I can take a look later, when the model is merged, if it's possible to avoid graph break
    **self.default_encoder_kwargs,
)

@requires_timm_greater_or_equal("1.0.15")
Can you put a comment on why we should have 1.0.15?
All the latest features of the VisionTransformer class in timm that are required for SMP were introduced after the timm 1.0.x version, but it was tedious to pinpoint the exact version number, so I just kept the latest timm version to be safe.
Same here, in case the test is the same, let's use inheritance and not duplicate the code. To add a decorator, the following can be used:
@requires_timm_greater_or_equal("1.0.15")
def test_in_channels(self):
return super().test_in_channels()
Actually, the test test_in_channels is different from the base class. It has an additional line of code on line 91 in test_timm_vit_encoders.py.
Ok, that's an example, just to make sure the code is not duplicated 🤗
Oh, sorry for the misunderstanding.
Hi @qubvel, bumping as a gentle reminder for another review.
Tried to address issue #1073
I have one concern: right now the model only works for images having the same resolution as the original input resolution on which a particular ViT encoder was trained. Is this behaviour okay, or should I change the functionality so that dynamic image resolutions are supported?
I have tested various ViT encoders at different encoder depths and the model seems to run correctly.
@qubvel Please let me know if you feel I should make any changes.
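Regarding the fixed-resolution concern above, one possible direction (my assumption, not something this PR currently does) is timm's dynamic_img_size option, which interpolates ViT position embeddings to match the input:

import timm
import torch

# With dynamic_img_size=True, timm ViTs resize their position embeddings at
# runtime, so inputs at resolutions other than the pretraining one work.
vit = timm.create_model("vit_base_patch16_224", pretrained=False, dynamic_img_size=True)
tokens = vit.forward_features(torch.rand(1, 3, 384, 384))
print(tokens.shape)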