@@ -36,11 +36,29 @@ def add_submodule(seq, *args):
3636
3737class Convertor (object ):
3838
39- def __init__ (self ):
39+ def __init__ (self , model ):
4040 self .prefix_code = []
4141 self .t2pt_names = dict ()
4242 self .t2pt_layers = dict ()
4343
44+ def search_max_unpool (model ):
45+ modules = []
46+ modules .extend (model .modules )
47+ containers = ['Sequential' , 'Concat' ]
48+
49+ while modules :
50+ m = modules .pop ()
51+ name = type (m ).__name__
52+ if name in containers :
53+ modules .extend (m .modules )
54+
55+ if name == 'SpatialMaxUnpooling' :
56+ return True
57+
58+ return False
59+
60+ self .have_max_unpool = search_max_unpool (model )
61+
4462 def lua_recursive_model (self , module , seq ):
4563 for m in module .modules :
4664 name = type (m ).__name__
@@ -69,9 +87,11 @@ def lua_recursive_model(self, module, seq):
6987 n = nn .Sigmoid ()
7088 add_submodule (seq , n )
7189 elif name == 'SpatialMaxPooling' :
72- # n = nn.MaxPool2d((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), ceil_mode=m.ceil_mode)
73- n = StatefulMaxPool2d ((m .kH , m .kW ), (m .dH , m .dW ), (m .padH , m .padW ), ceil_mode = m .ceil_mode )
74- self .t2pt_layers [m ] = n
90+ if not self .have_max_unpool :
91+ n = nn .MaxPool2d ((m .kW , m .kH ), (m .dW , m .dH ), (m .padW , m .padH ), ceil_mode = m .ceil_mode )
92+ else :
93+ n = StatefulMaxPool2d ((m .kH , m .kW ), (m .dH , m .dW ), (m .padH , m .padW ), ceil_mode = m .ceil_mode )
94+ self .t2pt_layers [m ] = n
7595 add_submodule (seq , n )
7696 elif name == 'SpatialMaxUnpooling' :
7797 if m .pooling in self .t2pt_layers :
@@ -164,30 +184,33 @@ def lua_recursive_source(self, module):
164184
165185 if name == 'SpatialConvolution' or name == 'nn.SpatialConvolutionMM' :
166186 if not hasattr (m , 'groups' ) or m .groups is None : m .groups = 1
167- s += ['nn.Conv2d({},{},{},{},{},{},{},bias={}),#Conv2d' .format (m .nInputPlane ,
168- m .nOutputPlane , (m .kW , m .kH ), (m .dW , m .dH ), (m .padW , m .padH ), 1 , m .groups , m .bias is not None )]
187+ s += ['nn.Conv2d({}, {}, {}, {}, {}, {}, {},bias={}), #Conv2d' .format (m .nInputPlane ,
188+ m .nOutputPlane , (m .kW , m .kH ), (m .dW , m .dH ), (m .padW , m .padH ), 1 , m .groups ,
189+ m .bias is not None )]
169190 elif name == 'SpatialBatchNormalization' :
170- s += ['nn.BatchNorm2d({},{},{},{}),#BatchNorm2d' .format (m .running_mean .size (0 ), m .eps , m .momentum , m .affine )]
191+ s += ['nn.BatchNorm2d({}, {}, {}, {}), #BatchNorm2d' .format (m .running_mean .size (0 ), m .eps , m .momentum , m .affine )]
171192 elif name == 'VolumetricBatchNormalization' :
172193 s += ['nn.BatchNorm3d({},{},{},{}),#BatchNorm3d' .format (m .running_mean .size (0 ), m .eps , m .momentum , m .affine )]
173194 elif name == 'ReLU' :
174195 s += ['nn.ReLU()' ]
175196 elif name == 'Sigmoid' :
176197 s += ['nn.Sigmoid()' ]
177198 elif name == 'SpatialMaxPooling' :
178- # s += ['nn.MaxPool2d({},{},{},ceil_mode={}),#MaxPool2d'.format((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)]
179- suffixes = sorted (int (re .match ('pooling_(\d*)' , v ).group (1 )) for v in self .t2pt_names .values ())
180- name = 'pooling_{}' .format (suffixes [- 1 ] + 1 if suffixes else 1 )
181- s += [name ]
182- self .t2pt_names [m ] = name
183- self .prefix_code += ['{} = StatefulMaxPool2d({}, {}, {}, ceil_mode={})' .format (name , (m .kH , m .kW ), (m .dH , m .dW ), (m .padH , m .padW ), m .ceil_mode )]
199+ if not self .have_max_unpool :
200+ s += ['nn.MaxPool2d({}, {}, {}, ceil_mode={}), #MaxPool2d' .format ((m .kW , m .kH ), (m .dW , m .dH ), (m .padW , m .padH ), m .ceil_mode )]
201+ else :
202+ suffixes = sorted (int (re .match ('pooling_(\d*)' , v ).group (1 )) for v in self .t2pt_names .values ())
203+ name = 'pooling_{}' .format (suffixes [- 1 ] + 1 if suffixes else 1 )
204+ s += [name ]
205+ self .t2pt_names [m ] = name
206+ self .prefix_code += ['{} = StatefulMaxPool2d({}, {}, {}, ceil_mode={})' .format (name , (m .kH , m .kW ), (m .dH , m .dW ), (m .padH , m .padW ), m .ceil_mode )]
184207 elif name == 'SpatialMaxUnpooling' :
185208 if m .pooling in self .t2pt_names :
186209 s += ['StatefulMaxUnpool2d({}), #SpatialMaxUnpooling' .format (self .t2pt_names [m .pooling ])]
187210 else :
188211 s += ['# ' + name + ' Not Implement (can\' t find corresponding SpatialMaxUnpooling,\n ' ]
189212 elif name == 'SpatialAveragePooling' :
190- s += ['nn.AvgPool2d({},{},{},ceil_mode={}),#AvgPool2d' .format ((m .kW , m .kH ), (m .dW , m .dH ), (m .padW , m .padH ), m .ceil_mode )]
213+ s += ['nn.AvgPool2d({}, {}, {}, ceil_mode={}), #AvgPool2d' .format ((m .kW , m .kH ), (m .dW , m .dH ), (m .padW , m .padH ), m .ceil_mode )]
191214 elif name == 'SpatialUpSamplingNearest' :
192215 s += ['nn.UpsamplingNearest2d(scale_factor={})' .format (m .scale_factor )]
193216 elif name == 'View' :
@@ -197,7 +220,7 @@ def lua_recursive_source(self, module):
197220 elif name == 'Linear' :
198221 s1 = 'Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x )'
199222 s2 = 'nn.Linear({},{},bias={})' .format (m .weight .size (1 ), m .weight .size (0 ), (m .bias is not None ))
200- s += ['nn.Sequential({},{}),#Linear' .format (s1 , s2 )]
223+ s += ['nn.Sequential({}, {}), #Linear' .format (s1 , s2 )]
201224 elif name == 'Dropout' :
202225 s += ['nn.Dropout({})' .format (m .p )]
203226 elif name == 'SoftMax' :
@@ -245,20 +268,20 @@ def lua_recursive_source(self, module):
245268
246269 @staticmethod
247270 def simplify_source (s ):
248- s = map (lambda x : x .replace (',(1, 1),(0, 0),1,1, bias=True),#Conv2d' , ')' ), s )
249- s = map (lambda x : x .replace (',(0, 0),1,1, bias=True),#Conv2d' , ')' ), s )
250- s = map (lambda x : x .replace (',1,1, bias=True),#Conv2d' , ')' ), s )
251- s = map (lambda x : x .replace (',bias=True),#Conv2d' , ')' ), s )
252- s = map (lambda x : x .replace ('),#Conv2d' , ')' ), s )
253- s = map (lambda x : x .replace (',1e-05,0.1,True),#BatchNorm2d' , ')' ), s )
254- s = map (lambda x : x .replace ('),#BatchNorm2d' , ')' ), s )
255- s = map (lambda x : x .replace (',(0, 0),ceil_mode=False),#MaxPool2d' , ')' ), s )
256- s = map (lambda x : x .replace (',ceil_mode=False),#MaxPool2d' , ')' ), s )
257- s = map (lambda x : x .replace ('),#MaxPool2d' , ')' ), s )
258- s = map (lambda x : x .replace (',(0, 0),ceil_mode=False),#AvgPool2d' , ')' ), s )
259- s = map (lambda x : x .replace (',ceil_mode=False),#AvgPool2d' , ')' ), s )
260- s = map (lambda x : x .replace (',bias=True)),#Linear' , ')), # Linear' ), s )
261- s = map (lambda x : x .replace (')),#Linear' , ')), # Linear' ), s )
271+ s = map (lambda x : x .replace (', (1, 1), (0, 0), 1, 1, bias=True), #Conv2d' , ')' ), s )
272+ s = map (lambda x : x .replace (', (0, 0), 1, 1, bias=True), #Conv2d' , ')' ), s )
273+ s = map (lambda x : x .replace (', 1, 1, bias=True), #Conv2d' , ')' ), s )
274+ s = map (lambda x : x .replace (', bias=True), #Conv2d' , ')' ), s )
275+ s = map (lambda x : x .replace ('), #Conv2d' , ')' ), s )
276+ s = map (lambda x : x .replace (', 1e-05, 0.1, True), #BatchNorm2d' , ')' ), s )
277+ s = map (lambda x : x .replace ('), #BatchNorm2d' , ')' ), s )
278+ s = map (lambda x : x .replace (', (0, 0), ceil_mode=False), #MaxPool2d' , ')' ), s )
279+ s = map (lambda x : x .replace (', ceil_mode=False), #MaxPool2d' , ')' ), s )
280+ s = map (lambda x : x .replace ('), #MaxPool2d' , ')' ), s )
281+ s = map (lambda x : x .replace (', (0, 0), ceil_mode=False), #AvgPool2d' , ')' ), s )
282+ s = map (lambda x : x .replace (', ceil_mode=False), #AvgPool2d' , ')' ), s )
283+ s = map (lambda x : x .replace (', bias=True)), #Linear' , ')), # Linear' ), s )
284+ s = map (lambda x : x .replace (')), #Linear' , ')), # Linear' ), s )
262285
263286 s = map (lambda x : '{},\n ' .format (x ), s )
264287 s = map (lambda x : x [1 :], s )
@@ -272,17 +295,19 @@ def torch_to_pytorch(t7_filename, outputname=None):
272295 model = model .model
273296 model .gradInput = None
274297
275- cvt = Convertor ()
276- slist = cvt .lua_recursive_source (lnn .Sequential ().add (model ))
277- s = cvt .simplify_source (slist )
298+ cvt = Convertor (model )
299+ s = cvt .lua_recursive_source (lnn .Sequential ().add (model ))
300+ s = cvt .simplify_source (s )
278301
279302 varname = os .path .basename (t7_filename ).replace ('.t7' , '' ).replace ('.' , '_' ).replace ('-' , '_' )
280303
281304 with open ("header.py" ) as f :
282305 header = f .read ()
283306 s = '{}\n {}\n \n {} = {}' .format (header , '\n ' .join (cvt .prefix_code ), varname , s [:- 2 ])
284307
285- if outputname is None : outputname = varname
308+ if outputname is None :
309+ outputname = varname
310+
286311 with open (outputname + '.py' , "w" ) as pyfile :
287312 pyfile .write (s )
288313
@@ -294,7 +319,7 @@ def torch_to_pytorch(t7_filename, outputname=None):
294319if __name__ == '__main__' :
295320 parser = argparse .ArgumentParser (description = 'Convert torch t7 model to pytorch' )
296321 parser .add_argument ('--model' , '-m' , type = str , required = True , help = 'torch model file in t7 format' )
297- parser .add_argument ('--output' , '-o' , type = str , default = None , help = 'output file name prefix, xxx.py xxx.pth' )
322+ parser .add_argument ('--output' , '-o' , type = str , default = '/tmp/model' , help = 'output file name prefix, xxx.py xxx.pth' )
298323 args = parser .parse_args ()
299324
300325 torch_to_pytorch (args .model , args .output )
0 commit comments