[FEAT] Model loading refactor #10604
Conversation
The failing FLAX CPU test is unrelated; it is failing in other PRs too.
Thanks for starting this! Left some comments from a first pass.
I think we will also need to add tests to check that device_map works as expected with quantization. It's okay to test that a bit later, once there is consensus about the design changes. Maybe we could add that as a TODO.
Other tests could include checking if we can do low_cpu_mem_usage=True
along with some changed config values. This will ensure we're well tested for cases like #9343.
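To make the idea concrete, here is a minimal sketch of such a test, assuming a tiny test repo and an arbitrary config override (both are illustrative, not the actual test that would be added):

```python
from diffusers import UNet2DConditionModel

# Load with low_cpu_mem_usage=True while overriding a config value, then check
# that the override is reflected on the loaded model.
unet = UNet2DConditionModel.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-torch",  # assumed tiny repo id
    subfolder="unet",
    low_cpu_mem_usage=True,
    sample_size=32,  # illustrative changed config value
)
assert unet.config.sample_size == 32
```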
Additionally, I ran some tests. Failures:

FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_load_sharded_checkpoint_from_hub_0_hf_internal_testing_unet2d_sharded_dummy - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argumen...
FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_load_sharded_checkpoint_from_hub_1_hf_internal_testing_tiny_sd_unet_sharded_latest_format - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argumen...
FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_load_sharded_checkpoint_from_hub_local - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argumen...
FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_load_sharded_checkpoint_from_hub_local_subfolder - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argumen...
FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_load_sharded_checkpoint_from_hub_subfolder_0_hf_internal_testing_unet2d_sharded_dummy_subfolder - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argumen...
FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_load_sharded_checkpoint_from_hub_subfolder_1_hf_internal_testing_tiny_sd_unet_sharded_latest_format_subfolder - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argumen...
FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_sharded_checkpoints - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argumen...
FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_sharded_checkpoints_device_map - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argumen...
FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_sharded_checkpoints_with_variant - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument...

^^ passes when using with …. Same for the following:

FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_sharded_checkpoints_device_map - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

And then I also ran: RUN_SLOW=1 pytest tests/pipelines/stable_diffusion/test_stable_diffusion.py::StableDiffusionPipelineDeviceMapTests

Everything passes.
for param_name, param in named_buffers:
We need to keep this or equivalent elsewhere, context: #10523
The changes I made should also cover this use case. The test you added should pass with my PR. This is mainly due to adding the dispatch_model call at the end.
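For context, here is a minimal sketch (not the PR's actual code) of what the final dispatch step looks like with accelerate; the toy model and device_map are made up:

```python
import torch
from accelerate import dispatch_model

# dispatch_model moves each submodule, parameters and buffers alike, to the device
# listed for it in device_map, which is why a separate buffer-moving loop becomes redundant.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.BatchNorm1d(8))
device_map = {"0": "cpu", "1": "cpu"}  # toy map; real maps mix "cuda:0", "cpu", "disk"
model = dispatch_model(model, device_map=device_map)
print(model[1].running_mean.device)  # buffers follow the map too
```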
@hlky are we cool here?
The tests are passing so all good
Co-authored-by: Sayak Paul <[email protected]>
…odel-loading-refactor
Co-authored-by: YiYi Xu <[email protected]>
…odel-loading-refactor
Done! Please check.
LGTM 👍🏽 Thanks @SunMarc. Just need to replace the model repo id in the keep_in_fp32 tests so they pass.
Thanks a bunch, @SunMarc! I left some mixed comments across the board. But I don't think any of them are major.
Really great work!
@@ -1593,18 +1593,9 @@ def create_diffusers_clip_model_from_ldm(
        raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
@DN6 would you like to run the slow single-file tests to ensure we're not breaking anything here (if not already)?
tests/models/test_modeling_common.py
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        if name in model._keep_in_fp32_modules:
            self.assertTrue(module.weight.dtype == torch.float32)
        else:
            self.assertTrue(module.weight.dtype == torch_dtype)
SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules
Two things:
- Would it make sense to make this test a part of ModelTesterMixin? Not strongly opinionated about it.
- Let's make sure we also perform inference -- this helps us validate the effectiveness of _keep_modules_in_fp32 even better. WDYT?
Would it make sense to make this test a part of ModelTesterMixin? Not strongly opinionated about it.
It can make sense to add a test in ModelTesterMixin for the models that do have _keep_in_fp32_modules specified. However, for now, none of the models use this arg. Maybe in a follow-up PR when a model actually needs this?
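Hypothetically, such a shared test could boil down to a small helper along these lines (names here are illustrative, not existing test code):

```python
import torch

def assert_keep_in_fp32_respected(model, torch_dtype=torch.float16):
    # Linear layers listed in _keep_in_fp32_modules should stay in fp32;
    # everything else should be in the requested dtype.
    keep = getattr(model, "_keep_in_fp32_modules", None) or []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            expected = torch.float32 if any(k in name for k in keep) else torch_dtype
            assert module.weight.dtype == expected, f"{name} has dtype {module.weight.dtype}"
```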
Let's make sure we also perform inference -- this helps us validate the effectiveness of _keep_modules_in_fp32 even better. WDYT?
fixed
@@ -136,7 +136,7 @@ def setUp(self):
            bnb_4bit_compute_dtype=torch.float16,
        )
        self.model_4bit = SD3Transformer2DModel.from_pretrained(
-           self.model_name, subfolder="transformer", quantization_config=nf4_config
+           self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
Would it make sense to test with "auto" device_map instead?
It would also work, but I kept torch_device for simplicity, as we do in transformers. Do you have a multi-GPU runner for quantization tests? We can create multi-GPU tests if needed.
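For reference, a hedged sketch of the two variants discussed here; the repo id is an assumption, and the "auto" path relies on the device_map support added in this PR:

```python
import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"  # assumed repo id

# What the test does now: put the whole quantized model on the single test device.
model_single = SD3Transformer2DModel.from_pretrained(
    repo_id, subfolder="transformer", quantization_config=nf4_config, device_map="cuda:0"
)

# The "auto" alternative: let accelerate split the model across the available devices.
model_auto = SD3Transformer2DModel.from_pretrained(
    repo_id, subfolder="transformer", quantization_config=nf4_config, device_map="auto"
)
```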
Okay, let's do that in a separate PR. I can do that. Do you have a reference?
Just a final set of questions and we should be good to go.
expected_slice_auto = np.array(
    [
        0.34179688,
        -0.03613281,
        0.01428223,
        -0.22949219,
        -0.49609375,
        0.4375,
        -0.1640625,
        -0.66015625,
        0.43164062,
    ]
)
expected_slice_offload = np.array(
    [0.34375, -0.03515625, 0.0123291, -0.22753906, -0.49414062, 0.4375, -0.16308594, -0.66015625, 0.43554688]
)
Are these changing because of device changes? Cc: @a-r-r-o-w for a double-check.
Co-authored-by: Sayak Paul <[email protected]>
…odel-loading-refactor
thanks!
What does this PR do?
Fixes #10013. This PR refactors model loading in diffusers. Here's a list of the major changes in this PR.

- A single loading path now handles both low_cpu_mem_usage=True and low_cpu_mem_usage=False. We don't rely on load_checkpoint_and_dispatch anymore and we don't merge sharded checkpoints either.
- keep_module_in_fp32 support for sharded checkpoints.

For low_cpu_mem_usage=False:
- Faster weight loading (assign_to_params_buffers). I didn't benchmark it but it should be as fast as low_cpu_mem_usage=True or maybe even faster. We did a similar PR in transformers thanks to @muellerzr.
- torch_dtype support: we no longer initialize the model in fp32 and then cast it to the requested dtype after the weights are loaded.

For low_cpu_mem_usage=True or device_map != None:
- We no longer go through load_checkpoint_and_dispatch.
- device_map support for quantization.
- The model is dispatched with dispatch_model at the end (the test you added is passing, cc @hlky).

Single file format:
- Single-file loading now goes through from_pretrained, so it gets the same features as that function (device_map, quantization, ...). Feel free to share your opinion @DN6, I didn't expect to touch this but I felt that we could simplify it a bit.

TODO (some items can be done in follow-up PRs):
- … low_cpu_mem_usage=False since we are initializing the whole model.

Please let me know your thoughts on the PR!
cc @sayakpaul, @DN6, @yiyixuxu, @hlky, @a-r-r-o-w
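To make the description concrete, here is a hedged usage sketch of the two loading paths above; the repo id is illustrative and the comments summarize the behavior described in this PR:

```python
import torch
from diffusers import UNet2DConditionModel

repo_id = "stabilityai/stable-diffusion-2-1"  # illustrative repo id

# Default path (low_cpu_mem_usage=True): weights are materialized directly in the
# requested dtype without going through load_checkpoint_and_dispatch.
unet = UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet", torch_dtype=torch.float16)

# low_cpu_mem_usage=False: the model is fully initialized first and the checkpoint is
# then assigned onto the existing parameters/buffers (assign_to_params_buffers).
unet_full_init = UNet2DConditionModel.from_pretrained(
    repo_id, subfolder="unet", torch_dtype=torch.float16, low_cpu_mem_usage=False
)
```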