nn: Fix reuse tensors #911

Merged · 1 commit · Jan 2, 2025
20 changes: 12 additions & 8 deletions common/nn/functions.go
@@ -19,16 +19,20 @@ func Neg(x *Tensor) *Tensor {
 }
 
 // Add returns the element-wise sum of two tensors. The shape of the second tensor must be a suffix sequence of the shape of the first tensor.
-func Add(x0, x1 *Tensor) *Tensor {
-	if len(x0.shape) < len(x1.shape) {
-		x0, x1 = x1, x0
-	}
-	for i := 0; i < len(x1.shape); i++ {
-		if x0.shape[len(x0.shape)-len(x1.shape)+i] != x1.shape[i] {
-			panic("the shape of the second tensor must be a suffix sequence of the shape of the first tensor")
+func Add(x0 *Tensor, x ...*Tensor) *Tensor {
+	output := x0
+	for _, x1 := range x {
+		if len(x0.shape) < len(x1.shape) {
+			x0, x1 = x1, x0
+		}
+		for i := 0; i < len(x1.shape); i++ {
+			if x0.shape[len(x0.shape)-len(x1.shape)+i] != x1.shape[i] {
+				panic("the shape of the second tensor must be a suffix sequence of the shape of the first tensor")
+			}
 		}
+		output = apply(&add{}, output, x1)
 	}
-	return apply(&add{}, x0, x1)
+	return output
 }
 
 // Sub returns the element-wise difference of two tensors. The shape of the second tensor must be a suffix sequence of the shape of the first tensor.
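The variadic signature is backward compatible: existing two-argument call sites compile unchanged, and extra operands are folded in left to right, one add node per operand, so gradients flow through every intermediate node. A minimal usage sketch, with values and constructors borrowed from the tests in this PR:

```go
// Add(a, b, c) builds the same graph as Add(Add(a, b), c).
a := NewVariable([]float32{1, 2}, 2)
b := NewVariable([]float32{3, 4}, 2)
c := NewVariable([]float32{5, 6}, 2)
sum := Add(a, b, c) // data: [9, 12]
sum.Backward()      // a, b and c each receive gradient [1, 1]
```

The TestGoldsteinPrice test below relies on this form to sum six polynomial terms in a single call.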
125 changes: 124 additions & 1 deletion common/nn/op_test.go
@@ -533,7 +533,7 @@ func TestReshape(t *testing.T) {
 	assert.Equal(t, []float32{1, 1, 1, 1, 1, 1}, x.grad.data)
 }
 
-func TestReuse(t *testing.T) {
+func TestReuseLeaf(t *testing.T) {
 	// x + x
 	x := NewVariable([]float32{1, 2, 3, 4, 5, 6}, 2, 3)
 	y := Add(x, x)
@@ -544,3 +544,126 @@ func TestReuse(t *testing.T) {
 	dx := numericalDiff(func(x *Tensor) *Tensor { return Add(x, x) }, x)
 	allClose(t, x.grad, dx)
 }
+
+func TestReuseNode(t *testing.T) {
+	// x^2 + x^2
+	x := NewVariable([]float32{1, 2, 3, 4, 5, 6}, 2, 3)
+	temp := Pow(x, NewVariable([]float32{2}))
+	y := Add(temp, temp)
+	assert.Equal(t, []float32{2, 8, 18, 32, 50, 72}, y.data)
+
+	// Test gradient
+	y.Backward()
+	dx := numericalDiff(func(x *Tensor) *Tensor {
+		temp := Pow(x, NewVariable([]float32{2}))
+		return Add(temp, temp)
+	}, x)
+	allClose(t, x.grad, dx)
+}
+
+func TestSphere(t *testing.T) {
+	// x^2 + y^2
+	x := NewScalar(1)
+	y := NewScalar(1)
+	z := Add(Mul(x, x), Mul(y, y))
+	assert.Equal(t, []float32{2}, z.data)
+
+	// Test gradient
+	z.Backward()
+	dx := numericalDiff(func(x *Tensor) *Tensor { return Add(Mul(x, x), Mul(y, y)) }, x)
+	dy := numericalDiff(func(y *Tensor) *Tensor { return Add(Mul(x, x), Mul(y, y)) }, y)
+	allClose(t, x.grad, dx)
+	allClose(t, y.grad, dy)
+}
+
+func TestMatyas(t *testing.T) {
+	// 0.26 * (x^2 + y^2) - 0.48 * x * y
+	x := NewScalar(1)
+	y := NewScalar(1)
+	z := Sub(Mul(NewScalar(0.26), Add(Mul(x, x), Mul(y, y))), Mul(NewScalar(0.48), Mul(x, y)))
+	assert.InDeltaSlice(t, []float32{0.04}, z.data, 1e-6)
+
+	// Test gradient
+	z.Backward()
+	dx := numericalDiff(func(x *Tensor) *Tensor {
+		return Sub(Mul(NewScalar(0.26), Add(Mul(x, x), Mul(y, y))), Mul(NewScalar(0.48), Mul(x, y)))
+	}, x)
+	dy := numericalDiff(func(y *Tensor) *Tensor {
+		return Sub(Mul(NewScalar(0.26), Add(Mul(x, x), Mul(y, y))), Mul(NewScalar(0.48), Mul(x, y)))
+	}, y)
+	allClose(t, x.grad, dx)
+	allClose(t, y.grad, dy)
+}
+
+func TestGoldsteinPrice(t *testing.T) {
+	// (1 + (x + y + 1)^2 * (19 - 14x + 3x^2 - 14y + 6xy + 3y^2)) * (30 + (2x - 3y)^2 * (18 - 32x + 12x^2 + 48y - 36xy + 27y^2))
+	x := NewScalar(1)
+	y := NewScalar(1)
+	z := Mul(
+		Add(NewScalar(1), Mul(
+			Pow(Add(x, y, NewScalar(1)), NewScalar(2)), // (x + y + 1)^2
+			Add(
+				NewScalar(19),                           // 19
+				Mul(NewScalar(-14), x),                  // -14x
+				Mul(NewScalar(3), Pow(x, NewScalar(2))), // 3x^2
+				Mul(NewScalar(-14), y),                  // -14y
+				Mul(NewScalar(6), Mul(x, y)),            // 6xy
+				Mul(NewScalar(3), Pow(y, NewScalar(2)))))), // 3y^2
+		Add(NewScalar(30), Mul(
+			Pow(Sub(Mul(NewScalar(2), x), Mul(NewScalar(3), y)), NewScalar(2)), // (2x - 3y)^2
+			Add(
+				NewScalar(18),                            // 18
+				Mul(NewScalar(-32), x),                   // -32x
+				Mul(NewScalar(12), Pow(x, NewScalar(2))), // 12x^2
+				Mul(NewScalar(48), y),                    // 48y
+				Mul(NewScalar(-36), Mul(x, y)),           // -36xy
+				Mul(NewScalar(27), Pow(y, NewScalar(2))))))) // 27y^2
+	assert.InDeltaSlice(t, []float32{1876}, z.data, 1e-6)
+
+	// Test gradient
+	z.Backward()
+	dx := numericalDiff(func(x *Tensor) *Tensor {
+		return Mul(
+			Add(NewScalar(1), Mul(
+				Pow(Add(x, y, NewScalar(1)), NewScalar(2)), // (x + y + 1)^2
+				Add(
+					NewScalar(19),                           // 19
+					Mul(NewScalar(-14), x),                  // -14x
+					Mul(NewScalar(3), Pow(x, NewScalar(2))), // 3x^2
+					Mul(NewScalar(-14), y),                  // -14y
+					Mul(NewScalar(6), Mul(x, y)),            // 6xy
+					Mul(NewScalar(3), Pow(y, NewScalar(2)))))), // 3y^2
+			Add(NewScalar(30), Mul(
+				Pow(Sub(Mul(NewScalar(2), x), Mul(NewScalar(3), y)), NewScalar(2)), // (2x - 3y)^2
+				Add(
+					NewScalar(18),                            // 18
+					Mul(NewScalar(-32), x),                   // -32x
+					Mul(NewScalar(12), Pow(x, NewScalar(2))), // 12x^2
+					Mul(NewScalar(48), y),                    // 48y
+					Mul(NewScalar(-36), Mul(x, y)),           // -36xy
+					Mul(NewScalar(27), Pow(y, NewScalar(2))))))) // 27y^2
+	}, x)
+	dy := numericalDiff(func(y *Tensor) *Tensor {
+		return Mul(
+			Add(NewScalar(1), Mul(
+				Pow(Add(x, y, NewScalar(1)), NewScalar(2)), // (x + y + 1)^2
+				Add(
+					NewScalar(19),                           // 19
+					Mul(NewScalar(-14), x),                  // -14x
+					Mul(NewScalar(3), Pow(x, NewScalar(2))), // 3x^2
+					Mul(NewScalar(-14), y),                  // -14y
+					Mul(NewScalar(6), Mul(x, y)),            // 6xy
+					Mul(NewScalar(3), Pow(y, NewScalar(2)))))), // 3y^2
+			Add(NewScalar(30), Mul(
+				Pow(Sub(Mul(NewScalar(2), x), Mul(NewScalar(3), y)), NewScalar(2)), // (2x - 3y)^2
+				Add(
+					NewScalar(18),                            // 18
+					Mul(NewScalar(-32), x),                   // -32x
+					Mul(NewScalar(12), Pow(x, NewScalar(2))), // 12x^2
+					Mul(NewScalar(48), y),                    // 48y
+					Mul(NewScalar(-36), Mul(x, y)),           // -36xy
+					Mul(NewScalar(27), Pow(y, NewScalar(2))))))) // 27y^2
+	}, y)
+	allClose(t, x.grad, dx)
+	allClose(t, y.grad, dy)
+}
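All of these tests cross-check Backward against numericalDiff, a helper defined elsewhere in op_test.go. For TestReuseNode, for instance, y = x^2 + x^2 = 2x^2, so both sides must agree on dy/dx = 4x = [4, 8, 12, 16, 20, 24]. A sketch of the central-difference estimate such a helper conventionally computes (the body below is an assumption, not the package's actual code):

```go
// Central-difference gradient check (sketch). Summing the output elements
// mirrors Backward, which seeds the output gradient with ones.
func numericalDiffSketch(f func(*Tensor) *Tensor, x *Tensor) *Tensor {
	const eps = 1e-4
	grad := make([]float32, len(x.data))
	for i := range x.data {
		orig := x.data[i]
		x.data[i] = orig + eps
		plus := f(x)
		x.data[i] = orig - eps
		minus := f(x)
		x.data[i] = orig // restore the perturbed element
		var diff float32
		for j := range plus.data {
			diff += plus.data[j] - minus.data[j]
		}
		grad[i] = diff / (2 * eps)
	}
	return NewVariable(grad, x.shape...)
}
```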
5 changes: 4 additions & 1 deletion common/nn/tensor.go
@@ -17,6 +17,7 @@
 import (
 	"fmt"
 	"github.com/chewxy/math32"
+	mapset "github.com/deckarep/golang-set/v2"
 	"github.com/google/uuid"
 	"github.com/zhenghaoz/gorse/base/floats"
 	"golang.org/x/exp/slices"
@@ -209,6 +210,7 @@
 func (t *Tensor) Backward() {
 	t.grad = Ones(t.shape...)
 	ops := []op{t.op}
+	seen := mapset.NewSet[op](t.op)
 	for len(ops) > 0 {
 		op := ops[0]
 		ops = ops[1:]
@@ -225,8 +227,9 @@ func (t *Tensor) Backward() {
 			} else {
 				inputs[i].grad.add(grads[i])
 			}
-			if inputs[i].op != nil {
+			if inputs[i].op != nil && !seen.Contains(inputs[i].op) {
 				ops = append(ops, inputs[i].op)
+				seen.Add(inputs[i].op)
 			} else if !inputs[i].requireGrad {
 				// Clear gradient if the leaf tensor does not require gradient
 				//inputs[i].grad = nil
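This is the heart of the fix. Backward walks the graph breadth-first, and before this change an op reachable through several references (Add(temp, temp) above, or x feeding both factors of Mul(x, x)) was enqueued once per reference, so gradients were propagated into its inputs more than once. The mapset-backed seen set lets each op enter the queue exactly once, while inputs[i].grad.add still accumulates one contribution per consuming reference. The dedup pattern in isolation, on a toy DAG (node is a stand-in type, not something from this package):

```go
package main

import (
	"fmt"

	mapset "github.com/deckarep/golang-set/v2"
)

// node stands in for an op in the autograd graph.
type node struct {
	id     int
	inputs []*node
}

// visitOnce dequeues every reachable node exactly once, even when a node is
// referenced by several consumers, mirroring the dedup Backward now performs.
func visitOnce(root *node) []int {
	queue := []*node{root}
	seen := mapset.NewSet(root)
	var order []int
	for len(queue) > 0 {
		n := queue[0]
		queue = queue[1:]
		order = append(order, n.id)
		for _, in := range n.inputs {
			if !seen.Contains(in) {
				queue = append(queue, in)
				seen.Add(in)
			}
		}
	}
	return order
}

func main() {
	shared := &node{id: 2} // a reused subexpression, like temp in Add(temp, temp)
	root := &node{id: 1, inputs: []*node{shared, shared}}
	fmt.Println(visitOnce(root)) // [1 2]; without the set it would be [1 2 2]
}
```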
@@ -542,7 +545,7 @@
 	}
 }
 
 func (t *Tensor) transpose() *Tensor {
[GitHub Actions / lint check failure on line 548 in common/nn/tensor.go: func `(*Tensor).transpose` is unused (unused)]
 	if len(t.shape) < 2 {
 		panic("transpose requires at least 2-D tensor")
 	}