@@ -72,6 +72,62 @@ models.siamese = function()
7272 return m , input
7373end
7474
75+ models .siamese_parallel = function ()
76+ local fSize = {1 , 32 , 64 }
77+ local featuresOut = 128
78+
79+ local desc = nn .Sequential ()
80+ desc :add (nn .Reshape (1 ,64 ,64 ))
81+ desc :add (nn .SpatialAveragePooling (2 ,2 ,2 ,2 ))
82+ desc :add (nn .SpatialConvolution (fSize [1 ], fSize [2 ], 7 ,7 ))
83+ desc :add (nn .ReLU ())
84+ desc :add (nn .SpatialMaxPooling (2 ,2 ,2 ,2 ))
85+ desc :add (nn .SpatialConvolution (fSize [2 ], fSize [3 ], 6 ,6 ))
86+ desc :add (nn .ReLU ())
87+ desc :add (nn .View (- 1 ):setNumInputDims (3 ))
88+ desc :add (nn .Linear (4096 , 128 ))
89+ desc :add (nn .Contiguous ())
90+
91+ local siamese = nn .Parallel (2 ,2 )
92+ local siam = desc :clone ()
93+ desc :share (siam , ' weight' , ' bias' , ' gradWeight' , ' gradBias' )
94+ siamese :add (desc )
95+ siamese :add (siam )
96+
97+ local top = nn .Sequential ()
98+ top :add (nn .Linear (featuresOut * 2 , featuresOut * 2 ))
99+ top :add (nn .ReLU ())
100+ top :add (nn .Linear (featuresOut * 2 , 1 ))
101+
102+ local model = nn .Sequential ():add (siamese ):add (top )
103+
104+ local input = torch .rand (1 ,2 ,64 ,64 )
105+
106+ return model , input
107+ end
108+
109+ models .basic_parallel_middle = function ()
110+ local model = nn .Sequential ():add (nn .Linear (2 ,2 ))
111+ local prl = nn .Parallel (2 ,1 )
112+ prl :add (nn .Linear (2 ,2 ))
113+ prl :add (nn .Linear (2 ,2 ))
114+ model :add (prl )
115+ local input = torch .rand (2 ,2 )
116+ return model , input
117+ end
118+
119+ models .basic_splitTable = function ()
120+ local model = nn .Sequential ():add (nn .Linear (2 ,2 ))
121+ model :add (nn .SplitTable (2 ))
122+ local prl = nn .ParallelTable ()
123+ prl :add (nn .ReLU ())
124+ prl :add (nn .Sigmoid ())
125+ model :add (prl )
126+ model :add (nn .JoinTable (1 ))
127+ local input = torch .rand (2 ,2 )
128+ return model , input
129+ end
130+
75131models .basic_concat = function ()
76132 local m = nn .Sequential ()
77133 local cat = nn .ConcatTable ()
0 commit comments