
[FEAT] Model loading refactor #10604


Merged
43 commits merged into main from model-loading-refactor on Feb 19, 2025

Conversation

@SunMarc SunMarc (Member) commented Jan 17, 2025

What does this PR do?

Fixes #10013. This PR refactors model loading in diffusers. Here's a list of the major changes in this PR.

  • only two loading paths (low_cpu_mem_usage=True and low_cpu_mem_usage=False). We no longer rely on load_checkpoint_and_dispatch, and we no longer merge sharded checkpoints either.
  • support for sharded checkpoints for both loading paths
  • keep_module_in_fp32 support for sharded checkpoints
  • better support for displaying warnings for error/unexpected/missing/mismatched keys

For low_cpu_mem_usage = False:

  • Faster initialization (thanks to skipping the init + assign_to_params_buffers). I haven't benchmarked it, but it should be as fast as low_cpu_mem_usage=True, or maybe even faster. We did a similar PR in transformers thanks to @muellerzr.
  • Better torch_dtype support: we no longer initialize the model in fp32 and then cast it to the requested dtype after the weights are loaded (see the sketch below).
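
To illustrate, a minimal usage sketch of this path (the repo id and arguments here are illustrative, not taken from this PR):

import torch
from diffusers import UNet2DConditionModel

# Hypothetical example: with low_cpu_mem_usage=False the model is no longer
# built in fp32 and cast afterwards; weights end up directly in the requested dtype.
unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    subfolder="unet",
    low_cpu_mem_usage=False,
    torch_dtype=torch.float16,
)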

For low_cpu_mem_usage = True or device_map!=None:

  • one path; we no longer rely on load_checkpoint_and_dispatch
  • device_map support for quantization
  • non-persistent buffer support through dispatch_model (the test you added is passing, cc @hlky); a sketch of this path follows below
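
A rough sketch of what this path looks like from the user side (repo id and kwargs are illustrative; 4-bit loading assumes bitsandbytes is installed):

import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
# Passing device_map uses the low_cpu_mem_usage path; placement is handled by
# dispatch_model at the end, which also covers non-persistent buffers.
transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=nf4_config,
    device_map="cuda",
)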

Single format file:

  • Simplified single-file loading by routing it through from_pretrained, so it gets the same features as that function (device_map, quantization, ...). Feel free to share your opinion @DN6; I didn't expect to touch this, but I felt we could simplify it a bit. A sketch follows below.
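
For reference, a sketch of the single-file entry point this change affects (the checkpoint URL and dtype are placeholders):

import torch
from diffusers import StableDiffusionPipeline

# from_single_file now shares the from_pretrained machinery, so options such as
# torch_dtype (and, going forward, device_map / quantization) behave the same way.
pipe = StableDiffusionPipeline.from_single_file(
    "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
    torch_dtype=torch.float16,
)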

TODO (some items can be done in follow-up PRs):

  • Check if we have any regression / tests issues
  • Add more tests
  • Deal with missing keys in the model for both paths (before, this only worked when low_cpu_mem_usage=False, since the whole model is initialized in that case)
  • Fix typing
  • Better support for offload with safetensors (like in transformers)

Please let me know your thoughts on the PR!

cc @sayakpaul, @DN6, @yiyixuxu, @hlky, @a-r-r-o-w

@SunMarc SunMarc changed the title [FEAT ] Model loading refactor [FEAT] Model loading refactor Jan 17, 2025
@SunMarc SunMarc (Member Author) commented Jan 18, 2025

The failing Flax CPU test is unrelated; it's failing in other PRs too.

@sayakpaul sayakpaul (Member) left a comment

Thanks for starting this! Left some comments from a first pass.

I think we will also need to add tests to check whether device_map works as expected with quantization. It's okay to test that a bit later, once there is consensus about the design changes; maybe we could add it as a TODO.

Other tests could include checking if we can do low_cpu_mem_usage=True along with some changed config values. This will ensure we're well tested for cases like #9343.
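
Not part of this PR, but roughly what such a device_map-with-quantization check could look like (repo id, device_map value, and the assertion are illustrative):

import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

def test_device_map_with_4bit_quantization():
    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    model = SD3Transformer2DModel.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers",
        subfolder="transformer",
        quantization_config=nf4_config,
        device_map="auto",
    )
    # every parameter should have been dispatched onto a real device, not left on meta
    assert all(p.device.type != "meta" for p in model.parameters())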

@sayakpaul sayakpaul (Member) commented

@SunMarc,

Additionally, I ran some tests on audace (two RTX 4090s). Some tests are failing (they fail on main too):

Failures
FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_load_sharded_checkpoint_from_hub_0_hf_internal_testing_unet2d_sharded_dummy - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argumen...
FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_load_sharded_checkpoint_from_hub_1_hf_internal_testing_tiny_sd_unet_sharded_latest_format - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argumen...
FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_load_sharded_checkpoint_from_hub_local - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argumen...
FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_load_sharded_checkpoint_from_hub_local_subfolder - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argumen...
FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_load_sharded_checkpoint_from_hub_subfolder_0_hf_internal_testing_unet2d_sharded_dummy_subfolder - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argumen...
FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_load_sharded_checkpoint_from_hub_subfolder_1_hf_internal_testing_tiny_sd_unet_sharded_latest_format_subfolder - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argumen...
FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_sharded_checkpoints - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argumen...
FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_sharded_checkpoints_device_map - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argumen...
FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_sharded_checkpoints_with_variant - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument...

^^ These pass when run with CUDA_VISIBLE_DEVICES=0 (same on main). Expected?

Same for the following:

FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_sharded_checkpoints_device_map - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

And then I also ran:

RUN_SLOW=1 pytest tests/pipelines/stable_diffusion/test_stable_diffusion.py::StableDiffusionPipelineDeviceMapTests

Everything passes.


for param_name, param in named_buffers:
Contributor

We need to keep this or equivalent elsewhere, context: #10523

Member Author

The changes I made should also cover this use case. The test you added should pass with my PR. This is mainly due to adding the dispatch_model call at the end.

Member

@hlky are we cool here?

Contributor

The tests are passing, so all good.

@SunMarc SunMarc (Member Author) commented Feb 14, 2025

I am currently running the 4-bit quantization tests, and so far things are looking nice! Some tests that might be worth including/considering:

  • Device map with quantization
  • Effectiveness of keep_modules_in_fp32 when not using quantization (a sketch of this check follows below).

Done! Please check.
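
For reference, a minimal sketch of the keep-in-fp32 check being discussed (the module name, repo id, and matching logic here are hypothetical; the real test lives in the diffusers test suite):

import torch
from diffusers import SD3Transformer2DModel

# Hypothetically keep the final projection in fp32 while the rest of the model
# is loaded in fp16.
SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"]
model = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    torch_dtype=torch.float16,
)
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        expected = torch.float32 if "proj_out" in name else torch.float16
        assert module.weight.dtype == expected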

@DN6 DN6 (Collaborator) left a comment

LGTM 👍🏽 Thanks @SunMarc. Just need to replace the model repo id in the keep_in_fp32 tests so they pass.

@sayakpaul sayakpaul (Member) left a comment

Thanks a bunch, @SunMarc! I left some mixed comments across the board. But I don't think any of them are major.

Really great work!

@@ -1593,18 +1593,9 @@ def create_diffusers_clip_model_from_ldm(
raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
Member

@DN6 would you like to run the slow single-file tests to ensure we're not breaking anything here (if not already)?

Comment on lines 351 to 357
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
if name in model._keep_in_fp32_modules:
self.assertTrue(module.weight.dtype == torch.float32)
else:
self.assertTrue(module.weight.dtype == torch_dtype)
SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules
Member

Two things:

  • Would it make sense to make this test a part of ModelTesterMixin? Not strongly opinionated about it.
  • Let's make sure we also perform inference -- this helps us validate the effectiveness of _keep_modules_in_fp32 even better. WDYT?

Member Author

Would it make sense to make this test a part of ModelTesterMixin? Not strongly opinionated about it.

It can make sense to add a test in ModelTesterMixin for the models that do have _keep_in_fp32_modules specified. However, for now, none of the models use this arg. Maybe in a follow-up PR when a model actually needs this?

Member Author

Let's make sure we also perform inference -- this helps us validate the effectiveness of _keep_modules_in_fp32 even better. WDYT?

fixed

@@ -136,7 +136,7 @@ def setUp(self):
bnb_4bit_compute_dtype=torch.float16,
)
self.model_4bit = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=nf4_config
+            self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
Member

Would it make sense to test with "auto" device_map instead?

Member Author

It would also work, but I kept torch_device for simplicity, as we do in transformers. Do you have a multi-GPU runner for quantization tests? We can create multi-GPU tests if needed.

Member

Okay, let's do that in a separate PR. I can do that. Do you have a reference?

@SunMarc SunMarc (Member Author) commented Feb 18, 2025

@SunMarc SunMarc requested a review from sayakpaul February 17, 2025 15:18
@sayakpaul sayakpaul (Member) left a comment

Just a final set of questions and we should be good to go.



Comment on lines +302 to +317
expected_slice_auto = np.array(
[
0.34179688,
-0.03613281,
0.01428223,
-0.22949219,
-0.49609375,
0.4375,
-0.1640625,
-0.66015625,
0.43164062,
]
)
expected_slice_offload = np.array(
[0.34375, -0.03515625, 0.0123291, -0.22753906, -0.49414062, 0.4375, -0.16308594, -0.66015625, 0.43554688]
)
Member

Are these changing because of device changes? Cc: @a-r-r-o-w for a double-check.

@yiyixuxu yiyixuxu (Collaborator) left a comment

thanks!

@sayakpaul sayakpaul (Member) commented Feb 19, 2025

The failing test is unrelated and is being fixed by @hlky!

Thanks a lot @SunMarc!

@sayakpaul sayakpaul merged commit f5929e0 into main Feb 19, 2025
14 of 15 checks passed
@sayakpaul sayakpaul deleted the model-loading-refactor branch February 19, 2025 12:05

Successfully merging this pull request may close these issues.

[Core] refactor model loading
7 participants