@@ -72,6 +72,62 @@ models.siamese = function()
72
72
return m , input
73
73
end
74
74
75
+ models .siamese_parallel = function ()
76
+ local fSize = {1 , 32 , 64 }
77
+ local featuresOut = 128
78
+
79
+ local desc = nn .Sequential ()
80
+ desc :add (nn .Reshape (1 ,64 ,64 ))
81
+ desc :add (nn .SpatialAveragePooling (2 ,2 ,2 ,2 ))
82
+ desc :add (nn .SpatialConvolution (fSize [1 ], fSize [2 ], 7 ,7 ))
83
+ desc :add (nn .ReLU ())
84
+ desc :add (nn .SpatialMaxPooling (2 ,2 ,2 ,2 ))
85
+ desc :add (nn .SpatialConvolution (fSize [2 ], fSize [3 ], 6 ,6 ))
86
+ desc :add (nn .ReLU ())
87
+ desc :add (nn .View (- 1 ):setNumInputDims (3 ))
88
+ desc :add (nn .Linear (4096 , 128 ))
89
+ desc :add (nn .Contiguous ())
90
+
91
+ local siamese = nn .Parallel (2 ,2 )
92
+ local siam = desc :clone ()
93
+ desc :share (siam , ' weight' , ' bias' , ' gradWeight' , ' gradBias' )
94
+ siamese :add (desc )
95
+ siamese :add (siam )
96
+
97
+ local top = nn .Sequential ()
98
+ top :add (nn .Linear (featuresOut * 2 , featuresOut * 2 ))
99
+ top :add (nn .ReLU ())
100
+ top :add (nn .Linear (featuresOut * 2 , 1 ))
101
+
102
+ local model = nn .Sequential ():add (siamese ):add (top )
103
+
104
+ local input = torch .rand (1 ,2 ,64 ,64 )
105
+
106
+ return model , input
107
+ end
108
+
109
+ models .basic_parallel_middle = function ()
110
+ local model = nn .Sequential ():add (nn .Linear (2 ,2 ))
111
+ local prl = nn .Parallel (2 ,1 )
112
+ prl :add (nn .Linear (2 ,2 ))
113
+ prl :add (nn .Linear (2 ,2 ))
114
+ model :add (prl )
115
+ local input = torch .rand (2 ,2 )
116
+ return model , input
117
+ end
118
+
119
+ models .basic_splitTable = function ()
120
+ local model = nn .Sequential ():add (nn .Linear (2 ,2 ))
121
+ model :add (nn .SplitTable (2 ))
122
+ local prl = nn .ParallelTable ()
123
+ prl :add (nn .ReLU ())
124
+ prl :add (nn .Sigmoid ())
125
+ model :add (prl )
126
+ model :add (nn .JoinTable (1 ))
127
+ local input = torch .rand (2 ,2 )
128
+ return model , input
129
+ end
130
+
75
131
models .basic_concat = function ()
76
132
local m = nn .Sequential ()
77
133
local cat = nn .ConcatTable ()
0 commit comments