Skip to content

Commit 2bc82d6

Browse files
authored
DiffusionPipeline mixin to+FromOriginalModelMixin/FromSingleFileMixin from_single_file type hint (#10811)
* DiffusionPipeline mixin `to` type hint * FromOriginalModelMixin from_single_file * FromSingleFileMixin from_single_file
1 parent 924f880 commit 2bc82d6

File tree

3 files changed

+5
-3
lines changed

3 files changed

+5
-3
lines changed

src/diffusers/loaders/single_file.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from huggingface_hub import snapshot_download
2020
from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args
2121
from packaging import version
22+
from typing_extensions import Self
2223

2324
from ..utils import deprecate, is_transformers_available, logging
2425
from .single_file_utils import (
@@ -269,7 +270,7 @@ class FromSingleFileMixin:
269270

270271
@classmethod
271272
@validate_hf_hub_args
272-
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
273+
def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
273274
r"""
274275
Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
275276
format. The pipeline is set in evaluation mode (`model.eval()`) by default.

src/diffusers/loaders/single_file_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import torch
2121
from huggingface_hub.utils import validate_hf_hub_args
22+
from typing_extensions import Self
2223

2324
from ..quantizers import DiffusersAutoQuantizer
2425
from ..utils import deprecate, is_accelerate_available, logging
@@ -148,7 +149,7 @@ class FromOriginalModelMixin:
148149

149150
@classmethod
150151
@validate_hf_hub_args
151-
def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs):
152+
def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs) -> Self:
152153
r"""
153154
Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model
154155
is set in evaluation mode (`model.eval()`) by default.

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def is_saveable_module(name, value):
324324
create_pr=create_pr,
325325
)
326326

327-
def to(self, *args, **kwargs):
327+
def to(self, *args, **kwargs) -> Self:
328328
r"""
329329
Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the
330330
arguments of `self.to(*args, **kwargs).`

0 commit comments

Comments
 (0)