Skip to content

Commit 5d2d239

Browse files
authored
Fix inconsistent random transform in instruct pix2pix (#10698)
* Update train_instruct_pix2pix.py
  Fix inconsistent random transform in instruct_pix2pix
* Update train_instruct_pix2pix_sdxl.py
1 parent 1ae9b05 commit 5d2d239

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

examples/instruct_pix2pix/train_instruct_pix2pix.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -695,7 +695,7 @@ def preprocess_images(examples):
695695
)
696696
# We need to ensure that the original and the edited images undergo the same
697697
# augmentation transforms.
698-
images = np.concatenate([original_images, edited_images])
698+
images = np.stack([original_images, edited_images])
699699
images = torch.tensor(images)
700700
images = 2 * (images / 255) - 1
701701
return train_transforms(images)
@@ -706,7 +706,7 @@ def preprocess_train(examples):
706706
# Since the original and edited images were stacked along a new leading axis before
707707
# applying the transformations, we need to separate them and reshape
708708
# them accordingly.
709-
original_images, edited_images = preprocessed_images.chunk(2)
709+
original_images, edited_images = preprocessed_images
710710
original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
711711
edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)
712712

examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -766,7 +766,7 @@ def preprocess_images(examples):
766766
)
767767
# We need to ensure that the original and the edited images undergo the same
768768
# augmentation transforms.
769-
images = np.concatenate([original_images, edited_images])
769+
images = np.stack([original_images, edited_images])
770770
images = torch.tensor(images)
771771
images = 2 * (images / 255) - 1
772772
return train_transforms(images)
@@ -906,7 +906,7 @@ def preprocess_train(examples):
906906
# Since the original and edited images were stacked along a new leading axis before
907907
# applying the transformations, we need to separate them and reshape
908908
# them accordingly.
909-
original_images, edited_images = preprocessed_images.chunk(2)
909+
original_images, edited_images = preprocessed_images
910910
original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
911911
edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)
912912

0 commit comments

Comments
 (0)