Skip to content

Commit 74e6469

Browse files
committed
copy_param, adjustment (output_padding), groups for ConvTranspose{2,3}d
1 parent 4083bb0 commit 74e6469

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

convert_torch.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,12 @@ def lua_recursive_model(module,seq):
108108
n = Lambda(lambda x: x) # do nothing
109109
add_submodule(seq,n)
110110
elif name == 'SpatialFullConvolution':
111-
n = nn.ConvTranspose2d(m.nInputPlane,m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH))
111+
n = nn.ConvTranspose2d(m.nInputPlane,m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),(m.adjW,m.adjH))
112+
copy_param(m,n)
112113
add_submodule(seq,n)
113114
elif name == 'VolumetricFullConvolution':
114-
n = nn.ConvTranspose3d(m.nInputPlane,m.nOutputPlane,(m.kT,m.kW,m.kH),(m.dT,m.dW,m.dH),(m.padT,m.padW,m.padH))
115+
n = nn.ConvTranspose3d(m.nInputPlane,m.nOutputPlane,(m.kT,m.kW,m.kH),(m.dT,m.dW,m.dH),(m.padT,m.padW,m.padH),(m.adjT,m.adjW,m.adjH),m.groups)
116+
copy_param(m,n)
115117
add_submodule(seq, n)
116118
elif name == 'SpatialReplicationPadding':
117119
n = nn.ReplicationPad2d((m.pad_l,m.pad_r,m.pad_t,m.pad_b))
@@ -193,11 +195,11 @@ def lua_recursive_source(module):
193195
elif name == 'Identity':
194196
s += ['Lambda(lambda x: x), # Identity']
195197
elif name == 'SpatialFullConvolution':
196-
s += ['nn.ConvTranspose2d({},{},{},{},{})'.format(m.nInputPlane,
197-
m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH))]
198+
s += ['nn.ConvTranspose2d({},{},{},{},{},{})'.format(m.nInputPlane,
199+
m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),(m.adjW,m.adjH))]
198200
elif name == 'VolumetricFullConvolution':
199-
s += ['nn.ConvTranspose3d({},{},{},{},{})'.format(m.nInputPlane,
200-
m.nOutputPlane,(m.kT,m.kW,m.kH),(m.dT,m.dW,m.dH),(m.padT,m.padW,m.padH))]
201+
s += ['nn.ConvTranspose3d({},{},{},{},{},{},{})'.format(m.nInputPlane,
202+
m.nOutputPlane,(m.kT,m.kW,m.kH),(m.dT,m.dW,m.dH),(m.padT,m.padW,m.padH),(m.adjT,m.adjW,m.adjH),m.groups)]
201203
elif name == 'SpatialReplicationPadding':
202204
s += ['nn.ReplicationPad2d({})'.format((m.pad_l,m.pad_r,m.pad_t,m.pad_b))]
203205
elif name == 'SpatialReflectionPadding':
@@ -245,7 +247,7 @@ def simplify_source(s):
245247
s = map(lambda x: x.replace(',ceil_mode=False),#AvgPool2d',')'),s)
246248
s = map(lambda x: x.replace(',bias=True)),#Linear',')), # Linear'),s)
247249
s = map(lambda x: x.replace(')),#Linear',')), # Linear'),s)
248-
250+
249251
s = map(lambda x: '{},\n'.format(x),s)
250252
s = map(lambda x: x[1:],s)
251253
s = reduce(lambda x,y: x+y, s)

0 commit comments

Comments
 (0)