Commit ee28831

davidriazati authored and facebook-github-bot committed
[jit] Fix aug assign for non-tensor attributes (pytorch#32993)
Summary: Instead of erroring out, this de-sugars augmented assignments to class members from `self.a += 1` to `self.a = self.a + 1`. Fixes pytorch#32973

Pull Request resolved: pytorch#32993
Pulled By: driazati
Differential Revision: D19737636
fbshipit-source-id: 07307cde88d8c348a7affdafe26db21c74e28ec0
1 parent fa80299 commit ee28831
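
As a quick illustration of the behavior this enables (a minimal sketch, not part of the diff; the `Counter` module and its `count` attribute are made up for this example), an augmented assignment to a plain-int module attribute now scripts instead of raising:

import torch
import torch.nn as nn

class Counter(nn.Module):
    def __init__(self):
        super().__init__()
        self.count = 0  # plain int attribute, not a tensor or buffer

    def forward(self):
        # previously rejected; now de-sugared to `self.count = self.count + 1`
        self.count += 1
        return self.count

scripted = torch.jit.script(Counter())
scripted()
scripted()
assert scripted.count == 2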

File tree

4 files changed, +175 -29 lines changed

test/test_jit.py

+82
@@ -5294,6 +5294,88 @@ def foo(z):
             return y[0][1]
         self.checkScript(foo, ((1, [[1, 2], [3, 4]]),))

+    def test_nested_aug_assign(self):
+        @torch.jit.script
+        class SomeClass(object):
+            def __init__(self):
+                self.num = 99
+
+            def __iadd__(self, x):
+                # type: (int)
+                self.num += x
+                return self
+
+            def __eq__(self, other):
+                # type: (SomeClass) -> bool
+                return self.num == other.num
+
+        @torch.jit.script
+        class SomeOutOfPlaceClass(object):
+            def __init__(self):
+                self.num = 99
+
+            def __add__(self, x):
+                # type: (int)
+                self.num = x
+                return self
+
+            def __eq__(self, other):
+                # type: (SomeClass) -> bool
+                return self.num == other.num
+
+        class Child(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.x = 2
+                self.o = SomeClass()
+                self.oop = SomeOutOfPlaceClass()
+                self.list = [1, 2, 3]
+
+        class A(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.child = Child()
+
+            def forward(self):
+                self.child.x += 1
+                self.child.o += 5
+                self.child.oop += 5
+                some_list = [1, 2]
+                self.child.list += some_list
+                self.child.list *= 2
+                return self.child.x, self.child.o, self.child.list, self.child.oop
+
+        a = A()
+        sa = torch.jit.script(A())
+        eager_result = a()
+        script_result = sa()
+        self.assertEqual(eager_result, script_result)
+        self.assertEqual(a.child.x, sa.child.x)
+        self.assertEqual(a.child.o, sa.child.o)
+        self.assertEqual(a.child.list, sa.child.list)
+
+        @torch.jit.script
+        class SomeNonAddableClass(object):
+            def __init__(self):
+                self.num = 99
+
+            def __eq__(self, other):
+                # type: (SomeClass) -> bool
+                return self.num == other.num
+
+        # with self.assertRaisesRegex(RuntimeError, "")
+        class A(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.x = SomeNonAddableClass()
+
+            def forward(self):
+                self.x += SomeNonAddableClass()
+                return self.x
+
+        with self.assertRaisesRegex(RuntimeError, "Cannot emit inplace op"):
+            torch.jit.script(A())
+
     def test_nested_list_construct(self):
         def foo():
             return [[4]] + [[4, 5]]
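
For reference, the eager-Python semantics this test mirrors: `x += y` calls `__iadd__` if it exists and otherwise falls back to `__add__`, rebinding the name to the returned object either way. A minimal sketch outside TorchScript, using the same two class shapes as the test:

class InPlace:
    def __init__(self):
        self.num = 99

    def __iadd__(self, x):
        self.num += x  # mutates in place
        return self

class OutOfPlace:
    def __init__(self):
        self.num = 99

    def __add__(self, x):
        self.num = x  # no __iadd__, so += falls back to __add__
        return self

a = InPlace()
a += 5
assert a.num == 104

b = OutOfPlace()
b += 5
assert b.num == 5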

torch/csrc/jit/register_prim_ops.cpp

+24
@@ -1680,6 +1680,26 @@ int listInplaceAdd(Stack& stack) {
   return 0;
 }

+template <class T>
+int listMulIntLeftInPlace(Stack& stack) {
+  int64_t n = pop(stack).to<int64_t>();
+  c10::List<T> list = pop(stack).to<c10::List<T>>();
+
+  if (n <= 0) {
+    list.clear();
+  } else if (n > 1) {
+    size_t list_size = list.size();
+    for (auto i = 1; i < n; i++) {
+      for (size_t j = 0; j < list_size; j++) {
+        list.push_back(list.get(j));
+      }
+    }
+  }
+
+  push(stack, std::move(list));
+  return 0;
+}
+
 template <class T>
 int listMulIntLeft(Stack& stack) {
   int64_t n = pop(stack).to<int64_t>();

@@ -2295,6 +2315,10 @@ RegisterOperators reg2({
       Operator(                                                              \
           "aten::mul(int n, " decl_type "[] l) -> " decl_type "[]",          \
           listMulIntRight<c_type::value_type>,                               \
+          aliasAnalysisFromSchema()),                                        \
+      Operator(                                                              \
+          "aten::mul_(" decl_type "[](a!) l, int n) -> " decl_type "[](a!)", \
+          listMulIntLeftInPlace<c_type::value_type>,                         \
           aliasAnalysisFromSchema())

   CREATE_LIST_OPS("int", c10::List<int64_t>),
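
The new `aten::mul_` kernel gives `list *= n` semantics matching Python's: `n <= 0` clears the list, and `n > 1` appends `n - 1` copies of the original elements in place. A rough pure-Python rendering of `listMulIntLeftInPlace` (the helper name below is illustrative, not an API):

def mul_int_left_in_place(lst, n):
    # mutate lst instead of allocating a new list, as the C++ kernel does
    if n <= 0:
        lst.clear()
    elif n > 1:
        original_size = len(lst)
        for _ in range(n - 1):
            for j in range(original_size):
                lst.append(lst[j])
    return lst

xs = [1, 2, 3]
assert mul_int_left_in_place(xs, 2) is xs  # same object, now [1, 2, 3, 1, 2, 3]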

torch/csrc/jit/script/ir_emitter.cpp

+68 -29
@@ -1780,22 +1780,37 @@ struct to_ir {
   // If the RHS is a tensor, return the corresponding ATen in-place op
   // If it's a list of scalars, then return the corresponding list augment op
   Symbol getAugOp(const AugAssign& stmt, const TypePtr& type) {
-    if (type->cast<ListType>()) { // Lists also have in-place ops.
-      switch (stmt.aug_op()) {
-        case '+':
-          return aten::add_;
-      }
+    bool use_inplace_op = type->isSubtypeOf(TensorType::get()) ||
+        type->kind() == TypeKind::ListType;
+    switch (stmt.aug_op()) {
+      case '+':
+        return use_inplace_op ? aten::add_ : aten::add;
+      case '-':
+        return use_inplace_op ? aten::sub_ : aten::sub;
+      case '/':
+        return use_inplace_op ? aten::div_ : aten::div;
+      case '*':
+        return use_inplace_op ? aten::mul_ : aten::mul;
+      default:
+        throw ErrorReport(stmt)
+            << "Unknown augmented assignment: " << kindToString(stmt.aug_op());
     }
-    bool isTensor = type->isSubtypeOf(TensorType::get());
+  }
+
+  // Get a pair of <in place magic method name, out of place magic method name>
+  // since the out of place method is called if the in place method is not
+  // present
+  std::pair<std::string, std::string> getAugMagicMethod(const AugAssign& stmt) {
     switch (stmt.aug_op()) {
       case '+':
-        return isTensor ? aten::add_ : aten::add;
+        return std::make_pair(std::string("__iadd__"), std::string("__add__"));
       case '-':
-        return isTensor ? aten::sub_ : aten::sub;
+        return std::make_pair(std::string("__isub__"), std::string("__sub__"));
       case '/':
-        return isTensor ? aten::div_ : aten::div;
+        return std::make_pair(
+            std::string("__itruediv__"), std::string("__truediv__"));
       case '*':
-        return isTensor ? aten::mul_ : aten::mul;
+        return std::make_pair(std::string("__imul__"), std::string("__mul__"));
       default:
         throw ErrorReport(stmt)
             << "Unknown augmented assignment: " << kindToString(stmt.aug_op());
@@ -1831,34 +1846,58 @@ struct to_ir {
   //
   // def forward():
   //   self.num_batches += 1
-  //
-  // In this case we will only consider the scenario that the module
-  // buffer type is a tensor, and we emit the corresponding tensor
-  // in place op, and throw error for other unsupported types
   void emitAugAssignmentToSelectVar(const AugAssign& stmt) {
     const auto lhs = Select(stmt.lhs());
-    const auto lhsSugaredVar =
-        environment_stack->getSugaredVar(Var(lhs.value()).name());
+    auto lhsSugaredVar = emitSugaredExpr(lhs.value(), 1);
     const auto lhsValue =
         lhsSugaredVar->attr(lhs.range(), method, lhs.selector().name())
             ->asValue(lhs.range(), method);
-    if (lhsValue->type()->isSubtypeOf(TensorType::get())) {
-      // for module parameter/buffer assignment, only consider tensor types,
-      // emit the corresponding in-place op
-      const auto rhs = NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs()));
-      const auto self = NamedValue(stmt.lhs().range(), "self", lhsValue);
-      emitBuiltinCall(
+    if (lhsValue->type()->kind() == TypeKind::ClassType) {
+      // Call `__iadd__` so updates happen in place on class types
+      // https://docs.python.org/3/reference/datamodel.html#object.__iadd__
+      std::string in_place_method_name;
+      std::string out_of_place_method_name;
+      std::tie(in_place_method_name, out_of_place_method_name) =
+          getAugMagicMethod(stmt);
+      const auto rhs = emitExpr(stmt.rhs());
+
+      // Determine whether to use __iadd__ or __add__ (use __add__ only if
+      // __iadd__ is not present)
+      auto type = lhsValue->type()->expect<ClassType>();
+      std::string magic_method_name;
+      if (type->getMethod(in_place_method_name)) {
+        magic_method_name = in_place_method_name;
+      } else if (type->getMethod(out_of_place_method_name)) {
+        magic_method_name = out_of_place_method_name;
+      } else {
+        throw ErrorReport(stmt.range())
+            << "Cannot emit inplace op on " << type->python_str()
+            << " since it does not define an " << in_place_method_name << " or "
+            << out_of_place_method_name << " method";
+      }
+
+      // Insert call to the magic method
+      MethodValue method_value(lhsValue, magic_method_name);
+      auto result = method_value.call(stmt.range(), method, {rhs}, {}, 0)
+                        ->asValue(stmt.range(), method);
+
+      // x += y is equivalent to x = x.__iadd__(y) or x = x.__add__(y) if
+      // __iadd__ is not present, so set the value to the function's return
+      // value
+      lhsSugaredVar->setAttr(
+          stmt.range(), method, lhs.selector().name(), result);
+    } else {
+      const auto rhs = NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs()))
+                           .value(*method.graph());
+      auto rhsValue = emitBuiltinCall(
           stmt.range(),
           *method.graph(),
           getAugOp(stmt, lhsValue->type()),
-          {rhs},
+          {lhsValue, rhs},
           {},
-          self);
-
-    } else {
-      throw ErrorReport(stmt.lhs())
-          << "left-hand side of augmented assignment to module "
-          << "parameters/buffers can only be tensor types";
+          /*self=*/c10::nullopt);
+      lhsSugaredVar->setAttr(
+          stmt.range(), method, lhs.selector().name(), rhsValue);
     }
   }
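
Putting the branches together: for a class-typed attribute, the emitter now rewrites `self.attr += rhs` roughly as the following Python does (a sketch of the dispatch, not the emitter's actual output):

# self.attr += rhs is emitted as:
if hasattr(type(self.attr), "__iadd__"):
    self.attr = self.attr.__iadd__(rhs)
elif hasattr(type(self.attr), "__add__"):
    self.attr = self.attr.__add__(rhs)
else:
    raise RuntimeError("Cannot emit inplace op ...")  # a compile-time error in TorchScript

For tensors and lists the builtin in-place op from `getAugOp` (`aten::add_`, `aten::mul_`, ...) is emitted, and for other builtin types the out-of-place result is written back with `setAttr`, which is exactly the `self.a += 1` to `self.a = self.a + 1` de-sugaring described in the summary.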

torch/csrc/jit/script/sugared_value.cpp

+1
@@ -525,6 +525,7 @@ std::shared_ptr<SugaredValue> MagicMethod::call(
           ->call(loc, m, inputs.slice(1), attributes, n_binders);
     }
   }
+  TORCH_INTERNAL_ASSERT(base_value_);
   return base_value_->call(loc, m, inputs, attributes, n_binders);
 }
