Skip to content

Commit 0a9f541

Browse files
committed
fix backward for max min
1 parent e6821e3 commit 0a9f541

File tree

4 files changed

+12
-6
lines changed

4 files changed

+12
-6
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
['cuda/scatter.cpp', 'cuda/scatter_kernel.cu'])
2121
]
2222

23-
__version__ = '1.3.0'
23+
__version__ = '1.3.1'
2424
url = 'https://github.com/rusty1s/pytorch_scatter'
2525

2626
install_requires = []

torch_scatter/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .max import scatter_max
88
from .min import scatter_min
99

10-
__version__ = '1.3.0'
10+
__version__ = '1.3.1'
1111

1212
__all__ = [
1313
'scatter_add',

torch_scatter/max.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,11 @@ def backward(ctx, grad_out, grad_arg):
2424

2525
grad_src = None
2626
if ctx.needs_input_grad[1]:
27-
grad_src = grad_out.new_zeros(index.size())
28-
grad_src.scatter_(ctx.dim, arg.detach(), grad_out)
27+
size = list(index.size())
28+
size[ctx.dim] += 1
29+
grad_src = grad_out.new_zeros(size)
30+
grad_src.scatter_(ctx.dim, arg.detach() + 1, grad_out)
31+
grad_src = grad_src.narrow(ctx.dim, 1, index.size(ctx.dim))
2932

3033
return None, grad_src, None, None
3134

torch_scatter/min.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,11 @@ def backward(ctx, grad_out, grad_arg):
2424

2525
grad_src = None
2626
if ctx.needs_input_grad[1]:
27-
grad_src = grad_out.new_zeros(index.size())
28-
grad_src.scatter_(ctx.dim, arg.detach(), grad_out)
27+
size = list(index.size())
28+
size[ctx.dim] += 1
29+
grad_src = grad_out.new_zeros(size)
30+
grad_src.scatter_(ctx.dim, arg.detach() + 1, grad_out)
31+
grad_src = grad_src.narrow(ctx.dim, 1, index.size(ctx.dim))
2932

3033
return None, grad_src, None, None
3134

0 commit comments

Comments
 (0)