forked from gorgonia/gorgonia
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstabilization_test.go
106 lines (86 loc) · 2.26 KB
/
stabilization_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
package gorgonia
import (
"io/ioutil"
"testing"
)
// TestLogStabilization checks that log expressions are rewritten when they are
// built:
//   - log(a+1) and log(1+a) become a node whose input is x itself;
//   - log(1-a) becomes a node whose input is neg(x);
//   - log(a-1) is left unrewritten.
//
// The rewritten op is presumably log1p — the assertions only check the input
// structure, not the op type.
//
// Each expression runs as its own subtest, so a failure is reported for the
// expression that actually failed. Its Graphviz dump ("<expr>.dot", in the
// working directory) is written only in that case.
func TestLogStabilization(t *testing.T) {
	g := NewGraph()
	x := NewVector(g, Float64, WithName("x"), WithShape(2))

	// dumpDot writes a .dot dump to aid debugging. It is best-effort: a write
	// error is logged rather than turned into a test failure.
	dumpDot := func(t *testing.T, filename, dot string) {
		if err := ioutil.WriteFile(filename, []byte(dot), 0644); err != nil {
			t.Logf("could not write %s: %v", filename, err)
		}
	}

	t.Run("log(a+1)", func(t *testing.T) {
		p := Must(Add(x, onef64))
		lp := Must(Log(p))
		if lp.children[0] != x {
			t.Errorf("log(a+1): expected input to be x, got %v", lp.children[0].Name())
			dumpDot(t, "log(a+1).dot", lp.ToDot())
		}
	})

	t.Run("log(1+a)", func(t *testing.T) {
		p := Must(Add(onef64, x))
		lp := Must(Log(p))
		if lp.children[0] != x {
			t.Errorf("log(1+a): expected input to be x, got %v", lp.children[0].Name())
			dumpDot(t, "log(1+a).dot", lp.ToDot())
		}
	})

	t.Run("log(1-a)", func(t *testing.T) {
		m := Must(Sub(onef64, x))
		lp := Must(Log(m))
		in := lp.children[0]
		if euo, ok := in.op.(elemUnaryOp); !ok {
			t.Errorf("log(1-a): expected an elementwise unary op as input, got %T", in.op)
		} else {
			if euo.unaryOpType() != negOpType {
				t.Error("Expected Neg Op")
			}
			// Guard the descent so an unexpected graph shape fails instead of panicking.
			if len(in.children) == 0 || in.children[0] != x {
				t.Error("log(1-a): expected the negation's input to be x")
			}
		}
		if t.Failed() {
			dumpDot(t, "log(1-a).dot", lp.ToDot())
		}
	})

	t.Run("log(a-1)", func(t *testing.T) {
		m := Must(Sub(x, onef64))
		lp := Must(Log(m))
		// TODO: surely there is a better way to test? For now this only asserts
		// that log(a-1) was not rewritten to take x directly as its input.
		if lp.children[0] == x {
			t.Error("log(a-1): expression should not have been rewritten")
			dumpDot(t, "log(a-1).dot", lp.ToDot())
		}
	})
}
// TestExpStabilization checks that exp(x) - 1 is rewritten at construction
// time into a single expm1 elementwise unary op whose input is x itself.
// On failure the resulting expression is written to "exp(a)-1.dot".
func TestExpStabilization(t *testing.T) {
	g := NewGraph()
	x := NewVector(g, Float64, WithName("x"), WithShape(2))
	expMinusOne := Must(Sub(Must(Exp(x)), onef64))

	if expMinusOne.children[0] != x {
		t.Error("oops")
	}

	// `ok` being true is implied once the first operand of || is false.
	euo, isUnary := expMinusOne.op.(elemUnaryOp)
	if !isUnary || euo.unaryOpType() != expm1OpType {
		t.Error("oops")
	}

	if !t.Failed() {
		return
	}
	ioutil.WriteFile("exp(a)-1.dot", []byte(expMinusOne.ToDot()), 0644)
}
// TestLogSigmoidStabilization checks that, with stabilization enabled,
// log(sigmoid(x)) is rewritten at construction time into the equivalent
// -softplus(-x), i.e. the node chain neg(softplus(neg(x))).
// On failure the whole graph and the log(sigmoid(x)) subgraph are dumped to
// fullGraph.dot and logY.dot in the working directory.
func TestLogSigmoidStabilization(t *testing.T) {
	g := NewGraph()

	// stabilization is package-level state: restore its previous value so this
	// test cannot change the behaviour of tests that run after it.
	prevStabilization := stabilization
	defer func() { stabilization = prevStabilization }()
	stabilization = true

	x := NewVector(g, Float64, WithName("x"), WithShape(2))
	y := Must(Sigmoid(x))
	WithName("y")(y)
	logY := Must(Log(y))
	WithName("log(sigmoid(x))")(logY)

	// dumpDot is best-effort debugging output: write errors are only logged.
	dumpDot := func(filename, dot string) {
		if err := ioutil.WriteFile(filename, []byte(dot), 0644); err != nil {
			t.Logf("could not write %s: %v", filename, err)
		}
	}
	// Deferred so the dumps are also written when a check below stops the
	// test with t.Fatal (FailNow runs deferred calls).
	defer func() {
		if !t.Failed() {
			return
		}
		dumpDot("fullGraph.dot", g.ToDot())
		dumpDot("logY.dot", logY.ToDot())
	}()

	if euo, ok := logY.op.(elemUnaryOp); !ok || euo.unaryOpType() != negOpType {
		t.Errorf("Oops: expected outermost op to be neg, got %v", logY.op)
	}

	// Each descent into children is guarded so an unexpected graph shape
	// fails the test instead of panicking with an index out of range.
	if len(logY.children) == 0 {
		t.Fatal("Oops2: log(sigmoid(x)) node has no inputs")
	}
	sp := logY.children[0]
	if euo, ok := sp.op.(elemUnaryOp); !ok || euo.unaryOpType() != softplusOpType {
		t.Errorf("Oops2: expected softplus, got %v", sp.op)
	}

	if len(sp.children) == 0 {
		t.Fatal("Oops3: softplus node has no inputs")
	}
	negX := sp.children[0]
	if euo, ok := negX.op.(elemUnaryOp); !ok || euo.unaryOpType() != negOpType {
		t.Errorf("Oops3: expected inner neg, got %v", negX.op)
	}

	if len(negX.children) == 0 {
		t.Fatal("Oops4: inner neg node has no inputs")
	}
	if negX.children[0] != x {
		t.Errorf("Oops4: %v", negX.children[0].Name())
	}
}