Skip to content

Commit 3efe68a

Browse files
authored
Merge pull request #19 from MicaelCarvalho/master
Legacy fix for newer versions
2 parents dd6b8f4 + ce70d1a commit 3efe68a

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

convert_torch.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
from __future__ import print_function
2-
import argparse
2+
3+
import os
4+
import math
35
import torch
6+
import argparse
7+
import numpy as np
48
import torch.nn as nn
5-
import torch.nn.functional as F
69
import torch.optim as optim
7-
from torch.autograd import Variable
8-
from torch.utils.serialization import load_lua
10+
import torch.legacy.nn as lnn
11+
import torch.nn.functional as F
912

10-
import numpy as np
11-
import os
12-
import math
1313
from functools import reduce
14+
from torch.autograd import Variable
15+
from torch.utils.serialization import load_lua
1416

1517
class LambdaBase(nn.Sequential):
1618
def __init__(self, fn, *args):
@@ -126,7 +128,7 @@ def lua_recursive_model(module,seq):
126128
n = Lambda(lambda x,a=(m.dimension,m.index,m.length): x.narrow(*a))
127129
add_submodule(seq,n)
128130
elif name == 'SpatialCrossMapLRN':
129-
lrn = torch.legacy.nn.SpatialCrossMapLRN(m.size,m.alpha,m.beta,m.k)
131+
lrn = lnn.SpatialCrossMapLRN(m.size,m.alpha,m.beta,m.k)
130132
n = Lambda(lambda x,lrn=lrn: Variable(lrn.forward(x.data)))
131133
add_submodule(seq,n)
132134
elif name == 'Sequential':
@@ -207,7 +209,7 @@ def lua_recursive_source(module):
207209
elif name == 'Narrow':
208210
s += ['Lambda(lambda x,a={}: x.narrow(*a))'.format((m.dimension,m.index,m.length))]
209211
elif name == 'SpatialCrossMapLRN':
210-
lrn = 'torch.legacy.nn.SpatialCrossMapLRN(*{})'.format((m.size,m.alpha,m.beta,m.k))
212+
lrn = 'lnn.SpatialCrossMapLRN(*{})'.format((m.size,m.alpha,m.beta,m.k))
211213
s += ['Lambda(lambda x,lrn={}: Variable(lrn.forward(x.data)))'.format(lrn)]
212214

213215
elif name == 'Sequential':
@@ -255,13 +257,15 @@ def torch_to_pytorch(t7_filename,outputname=None):
255257
model = load_lua(t7_filename,unknown_classes=True)
256258
if type(model).__name__=='hashable_uniq_dict': model=model.model
257259
model.gradInput = None
258-
slist = lua_recursive_source(torch.legacy.nn.Sequential().add(model))
260+
slist = lua_recursive_source(lnn.Sequential().add(model))
259261
s = simplify_source(slist)
260262
header = '''
261263
import torch
262264
import torch.nn as nn
263-
from torch.autograd import Variable
265+
import torch.legacy.nn as lnn
266+
264267
from functools import reduce
268+
from torch.autograd import Variable
265269
266270
class LambdaBase(nn.Sequential):
267271
def __init__(self, fn, *args):

0 commit comments

Comments
 (0)