forked from gorgonia/gorgonia
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtypeSystem.go
144 lines (125 loc) · 2.98 KB
/
typeSystem.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
package gorgonia
import (
"github.com/chewxy/hm"
"github.com/pkg/errors"
)
// inferType infers the hm.Type of expr.
//
// For a *Node: if the node is an input or a constant (a Var / Let const in
// HM terms), or if it already carries a type, that stored type is returned
// as-is. Otherwise the type is computed from the node's op and children via
// inferNodeType. An Op yields op.Type(). The Go scalar kinds float32, float64,
// int, int64, int32 and bool map to Float32, Float64, Int, Int64, Int32 and
// Bool respectively. Any other value produces an error.
func inferType(expr interface{}) (retVal hm.Type, err error) {
	switch v := expr.(type) {
	case *Node:
		// Stop the recursive inference early: leaves and already-typed nodes
		// report their stored type directly.
		if v.isInput() || v.isConstant() || v.t != nil {
			return v.t, nil
		}
		return inferNodeType(v.op, v.children...)
	case Op:
		return v.Type(), nil
	case float32:
		return Float32, nil
	case float64:
		return Float64, nil
	case int:
		return Int, nil
	case int64:
		return Int64, nil
	case int32:
		return Int32, nil
	case bool:
		return Bool, nil
	}
	return nil, errors.Errorf(nyiTypeFail, "inferType", expr)
}
// inferNodeType computes the type produced by applying op to children.
//
// Every node is essentially an hm.Apply, so rather than going through hm's
// general Infer function this builds the function type
//
//	typeof(children[0]) -> ... -> typeof(children[n-1]) -> b
//
// where b is a type variable standing in for the unknown result. It unifies
// that type with op.Type() and returns whatever b was bound to. An error is
// returned in three cases:
//   - a child's type cannot be inferred;
//   - unification fails;
//   - the resulting substitution has no entry for b.
//
// Both function types involved are passed to hm.ReturnFnType on exit
// (presumably handing them back to hm's pool; verify against the hm version in use).
func inferNodeType(op Op, children ...*Node) (retVal hm.Type, err error) {
	opType := op.Type()
	if pooled, isFn := opType.(*hm.FunctionType); isFn {
		defer hm.ReturnFnType(pooled)
	}

	arity := len(children)
	params := make(hm.Types, arity+1)
	for idx := 0; idx < arity; idx++ {
		childType, childErr := inferType(children[idx])
		if childErr != nil {
			return nil, errors.Wrapf(childErr, "Failed to infer type of %v", children[idx])
		}
		params[idx] = childType
	}

	// The last slot holds the placeholder for the (unknown) result type.
	result := hm.TypeVariable('b')
	params[arity] = result

	applied := hm.NewFnType(params...)
	defer hm.ReturnFnType(applied)

	subs, unifyErr := hm.Unify(applied, opType)
	if unifyErr != nil {
		return nil, errors.Wrapf(unifyErr, "Unable to unify while inferring type of %v", op)
	}

	resolved, found := subs.Get(result)
	if !found {
		return nil, errors.Errorf("Expected a replacement for %v", result)
	}
	return resolved, nil
}
// isScalarType reports whether t describes a scalar: either a bare Dtype, or
// a TensorType whose dimension count is zero.
//
// It panics if t is an hm.TypeVariable (a type not yet known) or any other
// type it does not know how to classify.
func isScalarType(t hm.Type) bool {
	switch typ := t.(type) {
	case Dtype:
		return true
	case TensorType:
		// A zero-dimensional tensor counts as a scalar.
		return typ.d == 0
	case hm.TypeVariable:
		panic("Type Variable is a type that is not yet known.")
	}
	panic("Unhandled type")
}
// dtypeOf extracts the element Dtype described by t.
//
// A Dtype is returned unchanged. A TensorType is unwrapped (repeatedly, if
// its element type is itself a TensorType) until a Dtype is reached. An
// hm.TypeVariable has no dtype and yields an error, as does any other type.
func dtypeOf(t hm.Type) (retVal Dtype, err error) {
	for {
		switch p := t.(type) {
		case Dtype:
			return p, nil
		case TensorType:
			// Peel one tensor layer and look at its element type.
			t = p.of
		case hm.TypeVariable:
			return retVal, errors.Errorf("instance %v does not have a dtype", p)
		default:
			return retVal, errors.Errorf(nyiFail, "dtypeOf", p)
		}
	}
}
// DEPRECATED: runtimeTypeCheck below is retained only as commented-out code and is not compiled.
/*
func runtimeTypeCheck(expected, got hm.Types) (of Dtype, err error) {
if len(expected) != len(got) {
err = NewError(RuntimeError, "Input length mismatch")
return
}
if of, err = dtypeOf(expected[0]); err != nil {
return
}
for i, e := range expected {
g := got[i]
if !e.Eq(g) {
err = NewError(RuntimeError, "Expected input[%d] to be %v. Got %v instead", i, e, got[i])
return
}
if i > 0 {
var gdt Dtype
if gdt, err = dtypeOf(g); err == nil {
if gdt != of {
err = NewError(RuntimeError, "Different dtypes encountered... Expected %v. Got %v instead", of, gdt)
return
}
} else {
return
}
}
}
return
}
*/