Skip to content

Commit 28d70f1

Browse files
authored
fix exception around NSFW filter on flax stable diffusion (huggingface#5675)
1 parent 1081cd0 commit 28d70f1

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -410,13 +410,13 @@ def __call__(
410410

411411
images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
412412
images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
413-
images = np.asarray(images)
413+
images = np.asarray(images).copy()
414414

415415
# block images
416416
if any(has_nsfw_concept):
417417
for i, is_nsfw in enumerate(has_nsfw_concept):
418418
if is_nsfw:
419-
images[i] = np.asarray(images_uint8_casted[i])
419+
images[i, 0] = np.asarray(images_uint8_casted[i])
420420

421421
images = images.reshape(num_devices, batch_size, height, width, 3)
422422
else:

0 commit comments

Comments
 (0)