Skip to content
This repository was archived by the owner on Jul 1, 2025. It is now read-only.

Commit 00896cf

Browse files
jackm321 (facebook-github-bot)
authored and committed
aten::masked_fill (#4708)
Summary: Pull Request resolved: #4708 Add torch_glow support for aten::masked_fill operator Reviewed By: mortzur Differential Revision: D22610504 fbshipit-source-id: 60fb272a95f7fb4758b4258a3b88625e9df49970
1 parent a57918b commit 00896cf

File tree

3 files changed

+100
-0
lines changed

3 files changed

+100
-0
lines changed

torch_glow/src/PyTorchModelLoader.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,15 @@ struct FlattenInputs {
642642
};
643643
};
644644

645+
/// Indexes of aten::masked_fill inputs.
struct MaskedFillInputs {
  enum {
    input = 0, // tensor to fill
    mask = 1,  // boolean mask; true positions are replaced with `value`
    value = 2, // scalar fill value
  };
};
653+
645654
/// Indexes of aten::topk inputs.
646655
struct TopKInputs {
647656
enum {
@@ -803,6 +812,8 @@ PyTorchModelLoader::buildSymbolsMapping() {
803812
{{"aten::bmm"}, &PyTorchModelLoader::loadBmm},
804813
{{"aten::addmm"}, &PyTorchModelLoader::loadAddMM},
805814
{{"aten::flatten"}, &PyTorchModelLoader::loadFlatten},
815+
{{"aten::masked_fill", "aten::masked_fill_"},
816+
&PyTorchModelLoader::loadMaskedFill},
806817
{{"aten::prelu"}, &PyTorchModelLoader::loadPRelu},
807818
{{"aten::slice"}, &PyTorchModelLoader::loadSlice},
808819
{{"aten::softmax"}, &PyTorchModelLoader::loadSoftMax},
@@ -3522,6 +3533,48 @@ Error PyTorchModelLoader::loadTo(const torch::jit::Node *ptNode) {
35223533
return addValueMapping(outputs[0], in);
35233534
}
35243535

3536+
/// Load a PyTorch aten::masked_fill (or in-place aten::masked_fill_) node.
/// Lowered as Select(mask, Splat(value), input): positions where the mask is
/// true take `value`, all others keep the input's element. A mask with fewer
/// dimensions than the input is broadcast across the leading input dims.
/// \returns error on failure.
Error PyTorchModelLoader::loadMaskedFill(const torch::jit::Node *ptNode) {
  auto inputs = ptNode->inputs();
  auto outputs = ptNode->outputs();
  RETURN_IF_ERR(checkInputAndOutputSizes(inputs, 3, outputs, 1));

  glow::NodeValue in;
  ASSIGN_VALUE_OR_RETURN_ERR(
      in, getGlowNodeValueForValue(inputs[MaskedFillInputs::input]));

  glow::NodeValue mask;
  ASSIGN_VALUE_OR_RETURN_ERR(
      mask, getGlowNodeValueForValue(inputs[MaskedFillInputs::mask]));

  size_t inSize = in.dims().size();
  size_t maskSize = mask.dims().size();

  // The mask may have at most as many dims as the input; PyTorch-style
  // right-aligned broadcasting is applied below.
  RETURN_ERR_IF_NOT(
      inSize >= maskSize,
      strFormat("masked_fill must have inputs at least as large as mask got "
                "input of size %zu and mask of size %zu",
                inSize, maskSize));

  // Broadcast the mask up to the input's rank by prepending leading dims.
  // NOTE(review): this assumes the mask's dims match the input's trailing
  // dims; presumably createBroadcast validates that — confirm.
  size_t maskBroadcastAxis = inSize - maskSize;
  if (maskBroadcastAxis > 0) {
    mask = F_.createBroadcast("broadcast", mask, in.dims(), maskBroadcastAxis)
               ->getNthResult(0);
  }

  float value;
  ASSIGN_VALUE_OR_RETURN_ERR(value, iValToDouble(getGlowIValueForValue(
                                        inputs[MaskedFillInputs::value])));

  // Splat the fill value with the *input's* element type (not hardcoded
  // FloatTy) so Select's two data operands have matching types even when the
  // input is a non-fp32 float tensor (e.g. fp16).
  auto valueSplat =
      F_.createSplat(
            "masked_fill_value",
            F_.getParent()->uniqueType(in.getElementType(), in.dims()), value)
          ->getResult();

  auto out = F_.createSelect("masked_fill", mask, valueSplat, in);
  return addValueMapping(outputs[0], out);
}
3577+
35253578
Error PyTorchModelLoader::loadFlatten(const torch::jit::Node *ptNode) {
35263579
auto inputs = ptNode->inputs();
35273580
auto outputs = ptNode->outputs();

torch_glow/src/PyTorchModelLoader.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,10 @@ class PyTorchModelLoader {
584584
/// \returns error on failure.
585585
Error loadFlatten(const torch::jit::Node *ptNode);
586586

587+
/// Load a PyTorch aten::masked_fill node (the in-place variant
/// aten::masked_fill_ is mapped to this loader as well).
/// \returns error on failure.
Error loadMaskedFill(const torch::jit::Node *ptNode);
590+
587591
/// Load a PyTorch topK node.
588592
/// \returns error on failure.
589593
Error loadTopK(const torch::jit::Node *ptNode);
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from __future__ import absolute_import, division, print_function, unicode_literals
2+
3+
import torch
4+
from tests.utils import jitVsGlow
5+
import unittest
6+
7+
8+
class TestMaskedFill(unittest.TestCase):
    """Tests for running aten::masked_fill / aten::masked_fill_ on Glow."""

    def test_masked_fill_basic(self):
        """aten::masked_fill with a mask of the same shape as the input."""

        def fill_fn(a, mask):
            return torch.masked_fill(a + a, mask, 42.0)

        inp = torch.randn([3])
        bool_mask = torch.tensor([True, False, True], dtype=torch.bool)

        jitVsGlow(fill_fn, inp, bool_mask,
                  expected_fused_ops={"aten::masked_fill"})

    def test_masked_fill_broadcasted(self):
        """aten::masked_fill with a mask broadcast over the input's leading
        dimensions."""

        def fill_fn(a, mask):
            return torch.masked_fill(a + a, mask, 42.0)

        inp = torch.randn([4, 1, 3])
        bool_mask = torch.tensor([True, False, True], dtype=torch.bool)

        jitVsGlow(fill_fn, inp, bool_mask,
                  expected_fused_ops={"aten::masked_fill"})

    def test_masked_fill_inplace(self):
        """In-place aten::masked_fill_ applied to an intermediate tensor."""

        def fill_fn(a, mask):
            out = a + a
            out.masked_fill_(mask, 42.0)
            return out

        inp = torch.randn([3])
        bool_mask = torch.tensor([True, False, True], dtype=torch.bool)

        jitVsGlow(fill_fn, inp, bool_mask,
                  expected_fused_ops={"aten::masked_fill_"})

0 commit comments

Comments
 (0)