Skip to content

Commit 4083bb0

Browse files
committed
add Sigmoid, Volumetric{FullConvolution, BatchNormalization}
1 parent 56e7630 commit 4083bb0

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

convert_torch.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,16 @@ def lua_recursive_model(module,seq):
6565
n = nn.BatchNorm2d(m.running_mean.size(0), m.eps, m.momentum, m.affine)
6666
copy_param(m,n)
6767
add_submodule(seq,n)
68+
elif name == 'VolumetricBatchNormalization':
69+
n = nn.BatchNorm3d(m.running_mean.size(0), m.eps, m.momentum, m.affine)
70+
copy_param(m, n)
71+
add_submodule(seq, n)
6872
elif name == 'ReLU':
6973
n = nn.ReLU()
7074
add_submodule(seq,n)
75+
elif name == 'Sigmoid':
76+
n = nn.Sigmoid()
77+
add_submodule(seq,n)
7178
elif name == 'SpatialMaxPooling':
7279
n = nn.MaxPool2d((m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),ceil_mode=m.ceil_mode)
7380
add_submodule(seq,n)
@@ -103,6 +110,9 @@ def lua_recursive_model(module,seq):
103110
elif name == 'SpatialFullConvolution':
104111
n = nn.ConvTranspose2d(m.nInputPlane,m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH))
105112
add_submodule(seq,n)
113+
elif name == 'VolumetricFullConvolution':
114+
n = nn.ConvTranspose3d(m.nInputPlane,m.nOutputPlane,(m.kT,m.kW,m.kH),(m.dT,m.dW,m.dH),(m.padT,m.padW,m.padH))
115+
add_submodule(seq, n)
106116
elif name == 'SpatialReplicationPadding':
107117
n = nn.ReplicationPad2d((m.pad_l,m.pad_r,m.pad_t,m.pad_b))
108118
add_submodule(seq,n)
@@ -156,8 +166,12 @@ def lua_recursive_source(module):
156166
m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),1,m.groups,m.bias is not None)]
157167
elif name == 'SpatialBatchNormalization':
158168
s += ['nn.BatchNorm2d({},{},{},{}),#BatchNorm2d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)]
169+
elif name == 'VolumetricBatchNormalization':
170+
s += ['nn.BatchNorm3d({},{},{},{}),#BatchNorm3d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)]
159171
elif name == 'ReLU':
160172
s += ['nn.ReLU()']
173+
elif name == 'Sigmoid':
174+
s += ['nn.Sigmoid()']
161175
elif name == 'SpatialMaxPooling':
162176
s += ['nn.MaxPool2d({},{},{},ceil_mode={}),#MaxPool2d'.format((m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),m.ceil_mode)]
163177
elif name == 'SpatialAveragePooling':
@@ -181,6 +195,9 @@ def lua_recursive_source(module):
181195
elif name == 'SpatialFullConvolution':
182196
s += ['nn.ConvTranspose2d({},{},{},{},{})'.format(m.nInputPlane,
183197
m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH))]
198+
elif name == 'VolumetricFullConvolution':
199+
s += ['nn.ConvTranspose3d({},{},{},{},{})'.format(m.nInputPlane,
200+
m.nOutputPlane,(m.kT,m.kW,m.kH),(m.dT,m.dW,m.dH),(m.padT,m.padW,m.padH))]
184201
elif name == 'SpatialReplicationPadding':
185202
s += ['nn.ReplicationPad2d({})'.format((m.pad_l,m.pad_r,m.pad_t,m.pad_b))]
186203
elif name == 'SpatialReflectionPadding':

0 commit comments

Comments
 (0)