@@ -1780,22 +1780,37 @@ struct to_ir {
1780
1780
// If the RHS is a tensor, return the corresponding ATen in-place op
1781
1781
// If it's a list of scalars, then return the corresponding list augment op
1782
1782
Symbol getAugOp (const AugAssign& stmt, const TypePtr& type) {
1783
- if (type->cast <ListType>()) { // Lists also have in-place ops.
1784
- switch (stmt.aug_op ()) {
1785
- case ' +' :
1786
- return aten::add_;
1787
- }
1783
+ bool use_inplace_op = type->isSubtypeOf (TensorType::get ()) ||
1784
+ type->kind () == TypeKind::ListType;
1785
+ switch (stmt.aug_op ()) {
1786
+ case ' +' :
1787
+ return use_inplace_op ? aten::add_ : aten::add;
1788
+ case ' -' :
1789
+ return use_inplace_op ? aten::sub_ : aten::sub;
1790
+ case ' /' :
1791
+ return use_inplace_op ? aten::div_ : aten::div ;
1792
+ case ' *' :
1793
+ return use_inplace_op ? aten::mul_ : aten::mul;
1794
+ default :
1795
+ throw ErrorReport (stmt)
1796
+ << " Unknown augmented assignment: " << kindToString (stmt.aug_op ());
1788
1797
}
1789
- bool isTensor = type->isSubtypeOf (TensorType::get ());
1798
+ }
1799
+
1800
+ // Get a pair of <in place magic method name, out of place magic method name>
1801
+ // since the out of place method is called if the in place method is not
1802
+ // present
1803
+ std::pair<std::string, std::string> getAugMagicMethod (const AugAssign& stmt) {
1790
1804
switch (stmt.aug_op ()) {
1791
1805
case ' +' :
1792
- return isTensor ? aten::add_ : aten::add ;
1806
+ return std::make_pair ( std::string ( " __iadd__ " ), std::string ( " __add__ " )) ;
1793
1807
case ' -' :
1794
- return isTensor ? aten::sub_ : aten::sub ;
1808
+ return std::make_pair ( std::string ( " __isub__ " ), std::string ( " __sub__ " )) ;
1795
1809
case ' /' :
1796
- return isTensor ? aten::div_ : aten::div ;
1810
+ return std::make_pair (
1811
+ std::string (" __itruediv__" ), std::string (" __truediv__" ));
1797
1812
case ' *' :
1798
- return isTensor ? aten::mul_ : aten::mul ;
1813
+ return std::make_pair ( std::string ( " __imul__ " ), std::string ( " __mul__ " )) ;
1799
1814
default :
1800
1815
throw ErrorReport (stmt)
1801
1816
<< " Unknown augmented assignment: " << kindToString (stmt.aug_op ());
@@ -1831,34 +1846,58 @@ struct to_ir {
1831
1846
//
1832
1847
// def forward():
1833
1848
// self.num_batches += 1
1834
- //
1835
- // In this case we will only consider the scenario that the module
1836
- // buffer type is a tensor, and we emit the corresponding tensor
1837
- // in place op, and throw error for other unsupported types
1838
1849
void emitAugAssignmentToSelectVar (const AugAssign& stmt) {
1839
1850
const auto lhs = Select (stmt.lhs ());
1840
- const auto lhsSugaredVar =
1841
- environment_stack->getSugaredVar (Var (lhs.value ()).name ());
1851
+ auto lhsSugaredVar = emitSugaredExpr (lhs.value (), 1 );
1842
1852
const auto lhsValue =
1843
1853
lhsSugaredVar->attr (lhs.range (), method, lhs.selector ().name ())
1844
1854
->asValue (lhs.range (), method);
1845
- if (lhsValue->type ()->isSubtypeOf (TensorType::get ())) {
1846
- // for module parameter/buffer assignment, only consider tensor types,
1847
- // emit the corresponding in-place op
1848
- const auto rhs = NamedValue (stmt.rhs ().range (), emitExpr (stmt.rhs ()));
1849
- const auto self = NamedValue (stmt.lhs ().range (), " self" , lhsValue);
1850
- emitBuiltinCall (
1855
+ if (lhsValue->type ()->kind () == TypeKind::ClassType) {
1856
+ // Call `__iadd__` so updates happen in place on class types
1857
+ // https://docs.python.org/3/reference/datamodel.html#object.__iadd__
1858
+ std::string in_place_method_name;
1859
+ std::string out_of_place_method_name;
1860
+ std::tie (in_place_method_name, out_of_place_method_name) =
1861
+ getAugMagicMethod (stmt);
1862
+ const auto rhs = emitExpr (stmt.rhs ());
1863
+
1864
+ // Determine whether to use __iadd__ or __add__ (use __add__ only if
1865
+ // __iadd__ is not present)
1866
+ auto type = lhsValue->type ()->expect <ClassType>();
1867
+ std::string magic_method_name;
1868
+ if (type->getMethod (in_place_method_name)) {
1869
+ magic_method_name = in_place_method_name;
1870
+ } else if (type->getMethod (out_of_place_method_name)) {
1871
+ magic_method_name = out_of_place_method_name;
1872
+ } else {
1873
+ throw ErrorReport (stmt.range ())
1874
+ << " Cannot emit inplace op on " << type->python_str ()
1875
+ << " since it does not define an " << in_place_method_name << " or "
1876
+ << out_of_place_method_name << " method" ;
1877
+ }
1878
+
1879
+ // Insert call to the magic method
1880
+ MethodValue method_value (lhsValue, magic_method_name);
1881
+ auto result = method_value.call (stmt.range (), method, {rhs}, {}, 0 )
1882
+ ->asValue (stmt.range (), method);
1883
+
1884
+ // x += y is equivalent to x = x.__iadd__(y) or x = x.__add__(y) if
1885
+ // __iadd__ is not present, so set the value to the function's return
1886
+ // value
1887
+ lhsSugaredVar->setAttr (
1888
+ stmt.range (), method, lhs.selector ().name (), result);
1889
+ } else {
1890
+ const auto rhs = NamedValue (stmt.rhs ().range (), emitExpr (stmt.rhs ()))
1891
+ .value (*method.graph ());
1892
+ auto rhsValue = emitBuiltinCall (
1851
1893
stmt.range (),
1852
1894
*method.graph (),
1853
1895
getAugOp (stmt, lhsValue->type ()),
1854
- {rhs},
1896
+ {lhsValue, rhs},
1855
1897
{},
1856
- self);
1857
-
1858
- } else {
1859
- throw ErrorReport (stmt.lhs ())
1860
- << " left-hand side of augmented assignment to module "
1861
- << " parameters/buffers can only be tensor types" ;
1898
+ /* self=*/ c10::nullopt);
1899
+ lhsSugaredVar->setAttr (
1900
+ stmt.range (), method, lhs.selector ().name (), rhsValue);
1862
1901
}
1863
1902
}
1864
1903
0 commit comments