Skip to content

fix: support masked_scatter by lowering path and corner case of maske… #3476

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 22, 2025

Conversation

chohk88
Copy link
Collaborator

@chohk88 chohk88 commented Apr 15, 2025

Description

Implemented support for masked_scatter in the lowering path, referring to this implementation in PyTorch Inductor.

Fixes # (issue)

Type of change

Please delete options that are not relevant and/or add your own.

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

@github-actions github-actions bot added component: tests Issues re: Tests component: lowering Issues re: The lowering / preprocessing passes component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Apr 15, 2025
@github-actions github-actions bot requested a review from apbose April 15, 2025 07:40
@chohk88 chohk88 requested a review from peri044 April 15, 2025 08:34
@chohk88 chohk88 self-assigned this Apr 15, 2025
@chohk88 chohk88 force-pushed the lowering_masked_scatter_ branch from c6be3e7 to bb3ff16 Compare April 18, 2025 08:56
# 6) Reshape the result to match the original broadcasted shape
return replaced.view(input_b.shape)


def get_decompositions(
Copy link
Collaborator

@apbose apbose Apr 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@chohk88 the implementation looks perfect as such.
Just a slight detail I was wondering. Torch mentions that it supports broadcasting between mask and self tensor, but I see this example working too, basically broadcasting between source and mask. is this supported?

input = torch.zeros(3, 3)
mask = torch.tensor([[1, 0, 1], [0, 0, 1], [0, 0, 0]], dtype=torch.bool)
source = torch.tensor([2, 3, 4])
out = input.masked_scatter_(mask, source)
print("out is", out)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @apbose, great question! I checked PyTorch’s behavior and it:

  1. Broadcasts only input and mask (not source)
  2. Flattens source as-is
  3. Uses only the first mask.sum() elements of source, ignoring any extras

For example, with source=[2,3,4,5] and three True mask positions, PyTorch applies [2,3,4] and drops the 5 without error.

Our current decomposition does exactly that (broadcasts input/mask, flattens source, then cumsum→gather→where), so no code changes are needed. Let me know if you spot any other edge cases!

image
image
image

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great! Thanks for the detailed explanation

@github-actions github-actions bot requested a review from narendasan April 21, 2025 07:35
@zewenli98 zewenli98 added the ciflow/binaries/all Build for all Python Versions label Apr 22, 2025
Copy link

pytorch-bot bot commented Apr 22, 2025

No ciflow labels are configured for this repo.
For information on how to enable CIFlow bot see this wiki

@github-actions github-actions bot requested a review from gs-olive April 22, 2025 03:22
@zewenli98 zewenli98 removed the request for review from gs-olive April 22, 2025 03:25
@chohk88 chohk88 merged commit 291b833 into main Apr 22, 2025
170 of 182 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/binaries/all Build for all Python Versions cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: lowering Issues re: The lowering / preprocessing passes component: tests Issues re: Tests needs-release-cherrypick
Projects
None yet
Development

Successfully merging this pull request may close these issues.

🐛 [Bug] Unsupported ops : torch.ops.aten.masked_scatter.default (Paligemma2)
4 participants