Skip to content

Commit 58daa35

Browse files
authored
Merge pull request #20 from mschrimpf/master
copy_param, adjustment (output_padding), groups for ConvTranspose{2,3}d
2 parents 3efe68a + 74e6469 commit 58daa35

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

convert_torch.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,12 @@ def lua_recursive_model(module,seq):
110110
n = Lambda(lambda x: x) # do nothing
111111
add_submodule(seq,n)
112112
elif name == 'SpatialFullConvolution':
113-
n = nn.ConvTranspose2d(m.nInputPlane,m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH))
113+
n = nn.ConvTranspose2d(m.nInputPlane,m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),(m.adjW,m.adjH))
114+
copy_param(m,n)
114115
add_submodule(seq,n)
115116
elif name == 'VolumetricFullConvolution':
116-
n = nn.ConvTranspose3d(m.nInputPlane,m.nOutputPlane,(m.kT,m.kW,m.kH),(m.dT,m.dW,m.dH),(m.padT,m.padW,m.padH))
117+
n = nn.ConvTranspose3d(m.nInputPlane,m.nOutputPlane,(m.kT,m.kW,m.kH),(m.dT,m.dW,m.dH),(m.padT,m.padW,m.padH),(m.adjT,m.adjW,m.adjH),m.groups)
118+
copy_param(m,n)
117119
add_submodule(seq, n)
118120
elif name == 'SpatialReplicationPadding':
119121
n = nn.ReplicationPad2d((m.pad_l,m.pad_r,m.pad_t,m.pad_b))
@@ -195,11 +197,11 @@ def lua_recursive_source(module):
195197
elif name == 'Identity':
196198
s += ['Lambda(lambda x: x), # Identity']
197199
elif name == 'SpatialFullConvolution':
198-
s += ['nn.ConvTranspose2d({},{},{},{},{})'.format(m.nInputPlane,
199-
m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH))]
200+
s += ['nn.ConvTranspose2d({},{},{},{},{},{})'.format(m.nInputPlane,
201+
m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),(m.adjW,m.adjH))]
200202
elif name == 'VolumetricFullConvolution':
201-
s += ['nn.ConvTranspose3d({},{},{},{},{})'.format(m.nInputPlane,
202-
m.nOutputPlane,(m.kT,m.kW,m.kH),(m.dT,m.dW,m.dH),(m.padT,m.padW,m.padH))]
203+
s += ['nn.ConvTranspose3d({},{},{},{},{},{},{})'.format(m.nInputPlane,
204+
m.nOutputPlane,(m.kT,m.kW,m.kH),(m.dT,m.dW,m.dH),(m.padT,m.padW,m.padH),(m.adjT,m.adjW,m.adjH),m.groups)]
203205
elif name == 'SpatialReplicationPadding':
204206
s += ['nn.ReplicationPad2d({})'.format((m.pad_l,m.pad_r,m.pad_t,m.pad_b))]
205207
elif name == 'SpatialReflectionPadding':
@@ -247,7 +249,7 @@ def simplify_source(s):
247249
s = map(lambda x: x.replace(',ceil_mode=False),#AvgPool2d',')'),s)
248250
s = map(lambda x: x.replace(',bias=True)),#Linear',')), # Linear'),s)
249251
s = map(lambda x: x.replace(')),#Linear',')), # Linear'),s)
250-
252+
251253
s = map(lambda x: '{},\n'.format(x),s)
252254
s = map(lambda x: x[1:],s)
253255
s = reduce(lambda x,y: x+y, s)

0 commit comments

Comments
 (0)