// nn.go (forked from gorgonia/gorgonia)
package gorgonia

import "github.com/pkg/errors"

// BinaryXent is a convenience function for computing the binary cross-entropy
// between output and target. The formula, with y the target and prob the
// output, is:
//
//	-(y * log(prob) + (1-y) * log(1-prob))
func BinaryXent(output, target *Node) (retVal *Node, err error) {
	var one *Node
	var logO, omt, omo, tLogO *Node

	// which constant one to use?
	var dt Dtype
	if dt, err = dtypeOf(output.t); err != nil {
		return nil, errors.Wrapf(err, dtypeExtractionFail, output.t)
	}

	switch dt {
	case Float64:
		one = onef64
	case Float32:
		one = onef32
	default:
		return nil, errors.Errorf(nyiFail, "BinaryXent", dt)
	}

	// logO = log(output)
	if logO, err = Log(output); err != nil {
		return nil, errors.Wrap(err, operationError)
	}

	// omt = (1 - target)
	if omt, err = Sub(one, target); err != nil {
		return nil, errors.Wrap(err, operationError)
	}

	// omo = (1 - output)
	if omo, err = Sub(one, output); err != nil {
		return nil, errors.Wrap(err, operationError)
	}

	// tLogO = target * log(output)
	if tLogO, err = HadamardProd(target, logO); err != nil {
		return nil, errors.Wrap(err, operationError)
	}

	// retVal = (1 - target) * log(1 - output)
	if retVal, err = Log(omo); err != nil {
		return nil, errors.Wrap(err, operationError)
	}
	if retVal, err = HadamardProd(omt, retVal); err != nil {
		return nil, errors.Wrap(err, operationError)
	}

	// sum the two terms, then negate
	if retVal, err = Add(tLogO, retVal); err != nil {
		return nil, errors.Wrap(err, operationError)
	}
	return Neg(retVal)
}
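
// A minimal usage sketch (hedged: the graph setup below and the node names
// logits and y are illustrative assumptions, not part of this file):
//
//	g := NewGraph()
//	logits := NewMatrix(g, Float64, WithShape(32, 1), WithName("logits"))
//	y := NewMatrix(g, Float64, WithShape(32, 1), WithName("y"))
//	prob := Must(Sigmoid(logits))       // squash into (0, 1) first
//	losses := Must(BinaryXent(prob, y)) // elementwise cross-entropy
//	cost := Must(Mean(losses))          // scalar cost for backprop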

// Dropout is a convenience function to implement dropout.
// It randomly zeroes out elements of the input with probability prob, using a
// mask drawn from a uniform distribution, and rescales the surviving elements
// by 1/(1-prob) (inverted dropout).
func Dropout(x *Node, prob float64) (retVal *Node, err error) {
	if prob == 0.0 {
		return x, nil
	}

	var dt Dtype
	if dt, err = dtypeOf(x.t); err != nil {
		return nil, errors.Wrap(err, dtypeOfFail)
	}

	// Elements survive with probability (1 - prob); inverted dropout divides
	// the survivors by that keep probability so the expected activation is
	// unchanged.
	var kp, pr Value // kp = keep probability = 1 - prob
	switch dt {
	case Float64:
		kp, _ = anyToScalar(1.0 - prob)
		pr, _ = anyToScalar(prob)
	case Float32:
		kp, _ = anyToScalar(float32(1.0 - prob))
		pr, _ = anyToScalar(float32(prob))
	default:
		return nil, errors.Errorf(nyiTypeFail, "Dropout()", dt)
	}

	p := NewConstant(pr)
	c := NewConstant(kp)

	// mask: 1 where a uniform sample exceeds prob, 0 elsewhere
	m := UniformRandomNode(x.g, dt, 0, 1, x.shape...)
	if retVal, err = Gt(m, p, true); err != nil {
		return nil, errors.Wrap(err, "Greater Than failed")
	}

	if retVal, err = HadamardProd(x, retVal); err != nil {
		return nil, errors.Wrap(err, mulFail)
	}

	return HadamardDiv(retVal, c)
}
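
// A minimal usage sketch (hedged: `hidden` and the 0.5 rate are illustrative):
//
//	dropped := Must(Dropout(hidden, 0.5)) // train time; use prob = 0 at inference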

// Rectify is a convenience function for creating rectified linear unit (ReLU)
// activations. This function uses >=, which is the canonical version. If you
// want to use >, you can create your own by following this implementation.
func Rectify(x *Node) (retVal *Node, err error) {
	var zero *Node
	var dt Dtype

	// which zero to use?
	if dt, err = dtypeOf(x.t); err != nil {
		return nil, errors.Wrap(err, dtypeOfFail)
	}
	switch dt {
	case Float64:
		zero = zerof64
	case Float32:
		zero = zerof32
	default:
		return nil, errors.Errorf(nyiFail, "ReLu", dt)
	}

	// mask: 1 where x >= 0, 0 elsewhere; retSame keeps the input's dtype
	// instead of returning booleans
	cmp := newElemBinOp(gteOpType, x, zero)
	cmp.retSame = true
	if retVal, err = applyOp(cmp, x, zero); err != nil {
		return nil, errors.Wrap(err, applyOpFail)
	}

	// relu(x) = x * (x >= 0)
	return HadamardProd(x, retVal)
}
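
// A minimal usage sketch (hedged: `x` and `w` are illustrative nodes):
//
//	pre := Must(Mul(x, w))    // pre-activation
//	act := Must(Rectify(pre)) // elementwise max(pre, 0)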