From d6064749e5f21f0d150bccd2709d30d64a5fdd63 Mon Sep 17 00:00:00 2001
From: Sergey Zagoruyko
Date: Tue, 8 Mar 2016 17:43:55 +0100
Subject: [PATCH 1/2] torch adaptation start

---
 neural_style.lua | 43 +++++++++++++++++++++++++++++++++----------
 1 file changed, 33 insertions(+), 10 deletions(-)

diff --git a/neural_style.lua b/neural_style.lua
index 2d7611a..6e0ef05 100644
--- a/neural_style.lua
+++ b/neural_style.lua
@@ -3,8 +3,6 @@ require 'nn'
 require 'image'
 require 'optim'
 
-require 'loadcaffe'
-
 
 --------------------------------------------------------------------------------
 local cmd = torch.CmdLine()
@@ -69,8 +67,16 @@ local function main(params)
   end
 
   local loadcaffe_backend = params.backend
+  local cnn
   if params.backend == 'clnn' then loadcaffe_backend = 'nn' end
-  local cnn = loadcaffe.load(params.proto_file, params.model_file, loadcaffe_backend):float()
+  local is_caffemodel = params.model_file:find'caffemodel'
+  if is_caffemodel then
+    require 'loadcaffe'
+    cnn = loadcaffe.load(params.proto_file, params.model_file, loadcaffe_backend)
+  else
+    cnn = torch.load(params.model_file):unpack()
+  end
+  cnn:float()
   if params.gpu >= 0 then
     if params.backend ~= 'clnn' then
       cnn:cuda()
@@ -81,7 +87,12 @@ local function main(params)
 
   local content_image = image.load(params.content_image, 3)
   content_image = image.scale(content_image, params.image_size, 'bilinear')
-  local content_image_caffe = preprocess(content_image):float()
+
+  -- make these global
+  preprocess = is_caffemodel and preprocess_caffe or preprocess_torch
+  deprocess = is_caffemodel and deprocess_caffe or deprocess_torch
+  print(cnn.transform)
+  local content_image_caffe = preprocess(content_image, cnn.transform):float()
 
   local style_size = math.ceil(params.style_scale * params.image_size)
   local style_image_list = params.style_image:split(',')
@@ -89,7 +100,7 @@ local function main(params)
   for _, img_path in ipairs(style_image_list) do
     local img = image.load(img_path, 3)
     img = image.scale(img, style_size, 'bilinear')
-    local img_caffe = preprocess(img):float()
+    local img_caffe = preprocess(img, cnn.transform):float()
     table.insert(style_images_caffe, img_caffe)
   end
 
@@ -262,6 +273,7 @@ local function main(params)
   -- Run it through the network once to get the proper size for the gradient
   -- All the gradients will come from the extra loss modules, so we just pass
   -- zeros into the top of the net on the backward pass.
+  require'fb.debugger'.enter()
   local y = net:forward(img)
   local dy = img.new(#y):zero()
 
@@ -294,11 +306,11 @@ local function main(params)
     end
   end
 
-  local function maybe_save(t)
+  local function maybe_save(t, img_mean)
     local should_save = params.save_iter > 0 and t % params.save_iter == 0
     should_save = should_save or t == params.num_iterations
     if should_save then
-      local disp = deprocess(img:double())
+      local disp = deprocess(img:double(), img_mean)
       disp = image.minmax{tensor=disp, min=0, max=1}
       local filename = build_filename(params.output_image, t)
       if t == params.num_iterations then
@@ -326,7 +338,7 @@ local function main(params)
       loss = loss + mod.loss
     end
     maybe_print(num_calls, loss)
-    maybe_save(num_calls)
+    maybe_save(num_calls, cnn.transform)
 
     collectgarbage()
     -- optim.lbfgs expects a vector for gradients
@@ -357,7 +369,7 @@ end
 -- Preprocess an image before passing it to a Caffe model.
 -- We need to rescale from [0, 1] to [0, 255], convert from RGB to BGR,
 -- and subtract the mean pixel.
-function preprocess(img)
+function preprocess_caffe(img)
   local mean_pixel = torch.DoubleTensor({103.939, 116.779, 123.68})
   local perm = torch.LongTensor{3, 2, 1}
   img = img:index(1, perm):mul(256.0)
@@ -366,9 +378,15 @@ function preprocess(img)
   return img
 end
 
+function preprocess_torch(img, img_mean)
+  local im = img:clone()
+  for i=1,3 do im[i]:add(-img_mean.mean[i]):div(img_mean.std[i]) end
+  return im
+end
+
 
 -- Undo the above preprocessing.
-function deprocess(img)
+function deprocess_caffe(img)
   local mean_pixel = torch.DoubleTensor({103.939, 116.779, 123.68})
   mean_pixel = mean_pixel:view(3, 1, 1):expandAs(img)
   img = img + mean_pixel
@@ -377,6 +395,11 @@ function deprocess(img)
   return img
 end
 
+function deprocess_torch(img, img_mean)
+  local im = img:clone()
+  for i=1,3 do im[i]:mul(img_mean.std[i]):add(img_mean.mean[i]) end
+  return im
+end
 
 -- Define an nn Module to compute content loss in-place
 local ContentLoss, parent = torch.class('nn.ContentLoss', 'nn.Module')

From bb9cfcb6d73d3c94296f53fd866e8f1d4918a572 Mon Sep 17 00:00:00 2001
From: Sergey Zagoruyko
Date: Sun, 13 Mar 2016 00:23:46 +0100
Subject: [PATCH 2/2] fixes

---
 neural_style.lua | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/neural_style.lua b/neural_style.lua
index 6e0ef05..371cec4 100644
--- a/neural_style.lua
+++ b/neural_style.lua
@@ -74,7 +74,11 @@ local function main(params)
     require 'loadcaffe'
     cnn = loadcaffe.load(params.proto_file, params.model_file, loadcaffe_backend)
   else
-    cnn = torch.load(params.model_file):unpack()
+    cnn = torch.load(params.model_file)
+    if cnn.unpack then cnn = cnn:unpack() end
+    if params.backend == 'cudnn' then
+      cudnn.convert(cnn, cudnn)
+    end
   end
   cnn:float()
   if params.gpu >= 0 then
@@ -163,7 +167,7 @@ local function main(params)
   for i = 1, #cnn do
     if next_content_idx <= #content_layers or next_style_idx <= #style_layers then
       local layer = cnn:get(i)
-      local name = layer.name
+      local name = is_caffemodel and layer.name or tostring(i)
       local layer_type = torch.type(layer)
       local is_pooling = (layer_type == 'cudnn.SpatialMaxPooling' or layer_type == 'nn.SpatialMaxPooling')
       if is_pooling and params.pooling == 'avg' then
@@ -239,7 +243,7 @@ local function main(params)
 
 
   -- We don't need the base CNN anymore, so clean it up to save memory.
-  cnn = nil
+  -- cnn = nil
   for i=1,#net.modules do
     local module = net.modules[i]
     if torch.type(module) == 'nn.SpatialConvolutionMM' then
@@ -273,7 +277,6 @@ local function main(params)
   -- Run it through the network once to get the proper size for the gradient
   -- All the gradients will come from the extra loss modules, so we just pass
   -- zeros into the top of the net on the backward pass.
-  require'fb.debugger'.enter()
   local y = net:forward(img)
   local dy = img.new(#y):zero()
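
For reference, a minimal standalone sketch of the torch-style normalization added in PATCH 1/2 (not part of the patches themselves): the `transform` table here is a stand-in for the `cnn.transform` field a torch checkpoint is assumed to carry, and its mean/std numbers are placeholder ImageNet-style statistics rather than values taken from this patch series.

-- Standalone sketch: round-trips a tensor through the torch-style
-- per-channel normalization used by preprocess_torch/deprocess_torch.
require 'torch'

-- Stand-in for cnn.transform (assumed layout: per-channel mean and std).
local transform = {
  mean = {0.485, 0.456, 0.406},  -- placeholder values
  std  = {0.229, 0.224, 0.225},
}

local function preprocess_torch(img, img_mean)
  local im = img:clone()
  for i = 1, 3 do im[i]:add(-img_mean.mean[i]):div(img_mean.std[i]) end
  return im
end

local function deprocess_torch(img, img_mean)
  local im = img:clone()
  for i = 1, 3 do im[i]:mul(img_mean.std[i]):add(img_mean.mean[i]) end
  return im
end

local img = torch.rand(3, 64, 64)                  -- stand-in for an RGB image in [0, 1]
local normalized = preprocess_torch(img, transform)
local restored = deprocess_torch(normalized, transform)
print((restored - img):abs():max())                -- ~0: the round trip is lossless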