Skip to content

Commit bd96a08

Browse files
authored
[train_dreambooth_lora.py] Set LANCZOS as default interpolation mode for resizing (#11421)
* Set LANCZOS as default interpolation mode for resizing * [train_dreambooth_lora.py] Set LANCZOS as default interpolation mode for resizing
1 parent f00a995 commit bd96a08

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

examples/dreambooth/train_dreambooth_lora.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,15 @@ def parse_args(input_args=None):
524524
default=4,
525525
help=("The dimension of the LoRA update matrices."),
526526
)
527+
parser.add_argument(
528+
"--image_interpolation_mode",
529+
type=str,
530+
default="lanczos",
531+
choices=[
532+
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
533+
],
534+
help="The image interpolation method to use for resizing images.",
535+
)
527536

528537
if input_args is not None:
529538
args = parser.parse_args(input_args)
@@ -601,9 +610,13 @@ def __init__(
601610
else:
602611
self.class_data_root = None
603612

613+
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
614+
if interpolation is None:
615+
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
616+
604617
self.image_transforms = transforms.Compose(
605618
[
606-
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
619+
transforms.Resize(size, interpolation=interpolation),
607620
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
608621
transforms.ToTensor(),
609622
transforms.Normalize([0.5], [0.5]),

0 commit comments

Comments
 (0)