
Commit 24b8b5c

pcuenca, patrickvonplaten, and patil-suraj authored
mps: Alternative implementation for repeat_interleave (#766)
* mps: alt. implementation for repeat_interleave
* style
* Bump mps version of PyTorch in the documentation.
* Apply suggestions from code review
  Co-authored-by: Suraj Patil <[email protected]>
* Simplify: do not check for device.
* style
* Fix repeat dimensions:
  - The unconditional embeddings are always created from a single prompt.
  - I was shadowing the batch_size var.
* Split long lines as suggested by Suraj.

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Suraj Patil <[email protected]>
1 parent 757babf commit 24b8b5c

File tree

2 files changed (+9 −5 lines)


docs/source/optimization/mps.mdx (+1 −1)

```diff
@@ -19,7 +19,7 @@ specific language governing permissions and limitations under the License.
 - Mac computer with Apple silicon (M1/M2) hardware.
 - macOS 12.3 or later.
 - arm64 version of Python.
-- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.13.0.dev20220830` or later.
+- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.14.0.dev20221007` or later.

 ## Inference Pipeline
```

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py (+8 −4)

```diff
@@ -218,8 +218,10 @@ def __call__(
         text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

-        # duplicate text embeddings for each generation per prompt
-        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        bs_embed, seq_len, _ = text_embeddings.shape
+        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -256,8 +258,10 @@ def __call__(
         )
         uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

-        # duplicate unconditional embeddings for each generation per prompt
-        uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
+        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+        seq_len = uncond_embeddings.shape[1]
+        uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
+        uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
```
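The change above replaces `repeat_interleave` with `repeat` followed by `view`. A minimal standalone sketch (the shapes and values here are illustrative, not from the pipeline) showing that the two produce identical tensors, so the swap is a pure workaround for the `mps` backend rather than a behavior change:

```python
import torch

num_images_per_prompt = 3

# Toy stand-in for text embeddings: (batch, sequence length, hidden dim).
text_embeddings = torch.arange(2 * 4 * 5, dtype=torch.float32).reshape(2, 4, 5)

# Original approach: interleave copies along the batch dimension.
interleaved = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

# mps-friendly approach from the commit: tile along dim 1, then reshape
# back to (batch * num_images_per_prompt, seq_len, hidden).
bs_embed, seq_len, _ = text_embeddings.shape
tiled = text_embeddings.repeat(1, num_images_per_prompt, 1)
tiled = tiled.view(bs_embed * num_images_per_prompt, seq_len, -1)

assert torch.equal(interleaved, tiled)
```

The trick works because tiling along the sequence dimension places the copies of each prompt's embedding contiguously, which is exactly the layout `repeat_interleave(..., dim=0)` produces after the `view`.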
