@@ -1780,22 +1780,37 @@ struct to_ir {
17801780 // If the RHS is a tensor, return the corresponding ATen in-place op
17811781 // If it's a list of scalars, then return the corresponding list augment op
17821782 Symbol getAugOp (const AugAssign& stmt, const TypePtr& type) {
1783- if (type->cast <ListType>()) { // Lists also have in-place ops.
1784- switch (stmt.aug_op ()) {
1785- case ' +' :
1786- return aten::add_;
1787- }
1783+ bool use_inplace_op = type->isSubtypeOf (TensorType::get ()) ||
1784+ type->kind () == TypeKind::ListType;
1785+ switch (stmt.aug_op ()) {
1786+ case ' +' :
1787+ return use_inplace_op ? aten::add_ : aten::add;
1788+ case ' -' :
1789+ return use_inplace_op ? aten::sub_ : aten::sub;
1790+ case ' /' :
1791+ return use_inplace_op ? aten::div_ : aten::div;
1792+ case ' *' :
1793+ return use_inplace_op ? aten::mul_ : aten::mul;
1794+ default :
1795+ throw ErrorReport (stmt)
1796+ << " Unknown augmented assignment: " << kindToString (stmt.aug_op ());
17881797 }
1789- bool isTensor = type->isSubtypeOf (TensorType::get ());
1798+ }
1799+
1800+ // Get a pair of <in place magic method name, out of place magic method name>
1801+ // since the out of place method is called if the in place method is not
1802+ // present
1803+ std::pair<std::string, std::string> getAugMagicMethod (const AugAssign& stmt) {
17901804 switch (stmt.aug_op ()) {
17911805 case ' +' :
1792- return isTensor ? aten::add_ : aten::add ;
1806+ return std::make_pair ( std::string ( " __iadd__ " ), std::string ( " __add__ " )) ;
17931807 case ' -' :
1794- return isTensor ? aten::sub_ : aten::sub ;
1808+ return std::make_pair ( std::string ( " __isub__ " ), std::string ( " __sub__ " )) ;
17951809 case ' /' :
1796- return isTensor ? aten::div_ : aten::div;
1810+ return std::make_pair (
1811+ std::string (" __itruediv__" ), std::string (" __truediv__" ));
17971812 case ' *' :
1798- return isTensor ? aten::mul_ : aten::mul ;
1813+ return std::make_pair ( std::string ( " __imul__ " ), std::string ( " __mul__ " )) ;
17991814 default :
18001815 throw ErrorReport (stmt)
18011816 << " Unknown augmented assignment: " << kindToString (stmt.aug_op ());
@@ -1831,34 +1846,58 @@ struct to_ir {
18311846 //
18321847 // def forward():
18331848 // self.num_batches += 1
1834- //
1835- // In this case we will only consider the scenario that the module
1836- // buffer type is a tensor, and we emit the corresponding tensor
1837- // in place op, and throw error for other unsupported types
18381849 void emitAugAssignmentToSelectVar (const AugAssign& stmt) {
18391850 const auto lhs = Select (stmt.lhs ());
1840- const auto lhsSugaredVar =
1841- environment_stack->getSugaredVar (Var (lhs.value ()).name ());
1851+ auto lhsSugaredVar = emitSugaredExpr (lhs.value (), 1 );
18421852 const auto lhsValue =
18431853 lhsSugaredVar->attr (lhs.range (), method, lhs.selector ().name ())
18441854 ->asValue (lhs.range (), method);
1845- if (lhsValue->type ()->isSubtypeOf (TensorType::get ())) {
1846- // for module parameter/buffer assignment, only consider tensor types,
1847- // emit the corresponding in-place op
1848- const auto rhs = NamedValue (stmt.rhs ().range (), emitExpr (stmt.rhs ()));
1849- const auto self = NamedValue (stmt.lhs ().range (), " self" , lhsValue);
1850- emitBuiltinCall (
1855+ if (lhsValue->type ()->kind () == TypeKind::ClassType) {
1856+ // Call `__iadd__` so updates happen in place on class types
1857+ // https://docs.python.org/3/reference/datamodel.html#object.__iadd__
1858+ std::string in_place_method_name;
1859+ std::string out_of_place_method_name;
1860+ std::tie (in_place_method_name, out_of_place_method_name) =
1861+ getAugMagicMethod (stmt);
1862+ const auto rhs = emitExpr (stmt.rhs ());
1863+
1864+ // Determine whether to use __iadd__ or __add__ (use __add__ only if
1865+ // __iadd__ is not present)
1866+ auto type = lhsValue->type ()->expect <ClassType>();
1867+ std::string magic_method_name;
1868+ if (type->getMethod (in_place_method_name)) {
1869+ magic_method_name = in_place_method_name;
1870+ } else if (type->getMethod (out_of_place_method_name)) {
1871+ magic_method_name = out_of_place_method_name;
1872+ } else {
1873+ throw ErrorReport (stmt.range ())
1874+ << " Cannot emit inplace op on " << type->python_str ()
1875+ << " since it does not define an " << in_place_method_name << " or "
1876+ << out_of_place_method_name << " method" ;
1877+ }
1878+
1879+ // Insert call to the magic method
1880+ MethodValue method_value (lhsValue, magic_method_name);
1881+ auto result = method_value.call (stmt.range (), method, {rhs}, {}, 0 )
1882+ ->asValue (stmt.range (), method);
1883+
1884+ // x += y is equivalent to x = x.__iadd__(y) or x = x.__add__(y) if
1885+ // __iadd__ is not present, so set the value to the function's return
1886+ // value
1887+ lhsSugaredVar->setAttr (
1888+ stmt.range (), method, lhs.selector ().name (), result);
1889+ } else {
1890+ const auto rhs = NamedValue (stmt.rhs ().range (), emitExpr (stmt.rhs ()))
1891+ .value (*method.graph ());
1892+ auto rhsValue = emitBuiltinCall (
18511893 stmt.range (),
18521894 *method.graph (),
18531895 getAugOp (stmt, lhsValue->type ()),
1854- {rhs},
1896+ {lhsValue, rhs},
18551897 {},
1856- self);
1857-
1858- } else {
1859- throw ErrorReport (stmt.lhs ())
1860- << " left-hand side of augmented assignment to module "
1861- << " parameters/buffers can only be tensor types" ;
1898+ /* self=*/ c10::nullopt );
1899+ lhsSugaredVar->setAttr (
1900+ stmt.range (), method, lhs.selector ().name (), rhsValue);
18621901 }
18631902 }
18641903
0 commit comments