|
1 | 1 | from __future__ import print_function
|
2 |
| -import argparse |
| 2 | + |
| 3 | +import os |
| 4 | +import math |
3 | 5 | import torch
|
| 6 | +import argparse |
| 7 | +import numpy as np |
4 | 8 | import torch.nn as nn
|
5 |
| -import torch.nn.functional as F |
6 | 9 | import torch.optim as optim
|
7 |
| -from torch.autograd import Variable |
8 |
| -from torch.utils.serialization import load_lua |
| 10 | +import torch.legacy.nn as lnn |
| 11 | +import torch.nn.functional as F |
9 | 12 |
|
10 |
| -import numpy as np |
11 |
| -import os |
12 |
| -import math |
13 | 13 | from functools import reduce
|
| 14 | +from torch.autograd import Variable |
| 15 | +from torch.utils.serialization import load_lua |
14 | 16 |
|
15 | 17 | class LambdaBase(nn.Sequential):
|
16 | 18 | def __init__(self, fn, *args):
|
@@ -126,7 +128,7 @@ def lua_recursive_model(module,seq):
|
126 | 128 | n = Lambda(lambda x,a=(m.dimension,m.index,m.length): x.narrow(*a))
|
127 | 129 | add_submodule(seq,n)
|
128 | 130 | elif name == 'SpatialCrossMapLRN':
|
129 |
| - lrn = torch.legacy.nn.SpatialCrossMapLRN(m.size,m.alpha,m.beta,m.k) |
| 131 | + lrn = lnn.SpatialCrossMapLRN(m.size,m.alpha,m.beta,m.k) |
130 | 132 | n = Lambda(lambda x,lrn=lrn: Variable(lrn.forward(x.data)))
|
131 | 133 | add_submodule(seq,n)
|
132 | 134 | elif name == 'Sequential':
|
@@ -207,7 +209,7 @@ def lua_recursive_source(module):
|
207 | 209 | elif name == 'Narrow':
|
208 | 210 | s += ['Lambda(lambda x,a={}: x.narrow(*a))'.format((m.dimension,m.index,m.length))]
|
209 | 211 | elif name == 'SpatialCrossMapLRN':
|
210 |
| - lrn = 'torch.legacy.nn.SpatialCrossMapLRN(*{})'.format((m.size,m.alpha,m.beta,m.k)) |
| 212 | + lrn = 'lnn.SpatialCrossMapLRN(*{})'.format((m.size,m.alpha,m.beta,m.k)) |
211 | 213 | s += ['Lambda(lambda x,lrn={}: Variable(lrn.forward(x.data)))'.format(lrn)]
|
212 | 214 |
|
213 | 215 | elif name == 'Sequential':
|
@@ -255,13 +257,15 @@ def torch_to_pytorch(t7_filename,outputname=None):
|
255 | 257 | model = load_lua(t7_filename,unknown_classes=True)
|
256 | 258 | if type(model).__name__=='hashable_uniq_dict': model=model.model
|
257 | 259 | model.gradInput = None
|
258 |
| - slist = lua_recursive_source(torch.legacy.nn.Sequential().add(model)) |
| 260 | + slist = lua_recursive_source(lnn.Sequential().add(model)) |
259 | 261 | s = simplify_source(slist)
|
260 | 262 | header = '''
|
261 | 263 | import torch
|
262 | 264 | import torch.nn as nn
|
263 |
| -from torch.autograd import Variable |
| 265 | +import torch.legacy.nn as lnn |
| 266 | +
|
264 | 267 | from functools import reduce
|
| 268 | +from torch.autograd import Variable |
265 | 269 |
|
266 | 270 | class LambdaBase(nn.Sequential):
|
267 | 271 | def __init__(self, fn, *args):
|
|
0 commit comments