Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 38 additions & 12 deletions neural_style.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ require 'nn'
require 'image'
require 'optim'

require 'loadcaffe'

--------------------------------------------------------------------------------

local cmd = torch.CmdLine()
Expand Down Expand Up @@ -69,8 +67,20 @@ local function main(params)
end

local loadcaffe_backend = params.backend
local cnn
if params.backend == 'clnn' then loadcaffe_backend = 'nn' end
local cnn = loadcaffe.load(params.proto_file, params.model_file, loadcaffe_backend):float()
local is_caffemodel = params.model_file:find'caffemodel'
if is_caffemodel then
require 'loadcaffe'
cnn = loadcaffe.load(params.proto_file, params.model_file, loadcaffe_backend)
else
cnn = torch.load(params.model_file)
if cnn.unpack then cnn = cnn:unpack() end
if params.backend == 'cudnn' then
cudnn.convert(cnn, cudnn)
end
end
cnn:float()
if params.gpu >= 0 then
if params.backend ~= 'clnn' then
cnn:cuda()
Expand All @@ -81,15 +91,20 @@ local function main(params)

local content_image = image.load(params.content_image, 3)
content_image = image.scale(content_image, params.image_size, 'bilinear')
local content_image_caffe = preprocess(content_image):float()

-- make these global: select the Caffe or Torch pre/deprocess pair that matches the model file
preprocess = is_caffemodel and preprocess_caffe or preprocess_torch
deprocess = is_caffemodel and deprocess_caffe or deprocess_torch
print(cnn.transform)
local content_image_caffe = preprocess(content_image, cnn.transform):float()

local style_size = math.ceil(params.style_scale * params.image_size)
local style_image_list = params.style_image:split(',')
local style_images_caffe = {}
for _, img_path in ipairs(style_image_list) do
local img = image.load(img_path, 3)
img = image.scale(img, style_size, 'bilinear')
local img_caffe = preprocess(img):float()
local img_caffe = preprocess(img, cnn.transform):float()
table.insert(style_images_caffe, img_caffe)
end

Expand Down Expand Up @@ -152,7 +167,7 @@ local function main(params)
for i = 1, #cnn do
if next_content_idx <= #content_layers or next_style_idx <= #style_layers then
local layer = cnn:get(i)
local name = layer.name
local name = is_caffemodel and layer.name or tostring(i)
local layer_type = torch.type(layer)
local is_pooling = (layer_type == 'cudnn.SpatialMaxPooling' or layer_type == 'nn.SpatialMaxPooling')
if is_pooling and params.pooling == 'avg' then
Expand Down Expand Up @@ -228,7 +243,7 @@ local function main(params)
end

-- We don't need the base CNN anymore, so clean it up to save memory.
cnn = nil
-- cnn = nil
for i=1,#net.modules do
local module = net.modules[i]
if torch.type(module) == 'nn.SpatialConvolutionMM' then
Expand Down Expand Up @@ -294,11 +309,11 @@ local function main(params)
end
end

local function maybe_save(t)
local function maybe_save(t, img_mean)
local should_save = params.save_iter > 0 and t % params.save_iter == 0
should_save = should_save or t == params.num_iterations
if should_save then
local disp = deprocess(img:double())
local disp = deprocess(img:double(), img_mean)
disp = image.minmax{tensor=disp, min=0, max=1}
local filename = build_filename(params.output_image, t)
if t == params.num_iterations then
Expand Down Expand Up @@ -326,7 +341,7 @@ local function main(params)
loss = loss + mod.loss
end
maybe_print(num_calls, loss)
maybe_save(num_calls)
maybe_save(num_calls, cnn.transform)

collectgarbage()
-- optim.lbfgs expects a vector for gradients
Expand Down Expand Up @@ -357,7 +372,7 @@ end
-- Preprocess an image before passing it to a Caffe model.
-- We need to rescale from [0, 1] to [0, 255], convert from RGB to BGR,
-- and subtract the mean pixel.
function preprocess(img)
function preprocess_caffe(img)
local mean_pixel = torch.DoubleTensor({103.939, 116.779, 123.68})
local perm = torch.LongTensor{3, 2, 1}
img = img:index(1, perm):mul(256.0)
Expand All @@ -366,9 +381,15 @@ function preprocess(img)
return img
end

--- Normalize an image for a Torch-format model.
-- Operates on `img:clone()`; for i = 1..3 it subtracts `img_mean.mean[i]`
-- from slice `i` and then divides that slice by `img_mean.std[i]`.
-- @param img image indexable as img[1..3] and supporting :clone()
--   (presumably a 3xHxW tensor from image.load -- confirm against caller)
-- @param img_mean table with `mean` and `std` fields, each indexable 1..3
-- @return the normalized clone
function preprocess_torch(img, img_mean)
  -- Fail early with a readable message instead of an opaque
  -- "attempt to index a nil value" when no statistics were supplied.
  assert(img_mean ~= nil and img_mean.mean ~= nil and img_mean.std ~= nil,
    'preprocess_torch: img_mean must provide .mean and .std')
  local im = img:clone()
  for i = 1, 3 do
    im[i]:add(-img_mean.mean[i]):div(img_mean.std[i])
  end
  return im
end


-- Undo the above preprocessing.
function deprocess(img)
function deprocess_caffe(img)
local mean_pixel = torch.DoubleTensor({103.939, 116.779, 123.68})
mean_pixel = mean_pixel:view(3, 1, 1):expandAs(img)
img = img + mean_pixel
Expand All @@ -377,6 +398,11 @@ function deprocess(img)
return img
end

--- Undo the normalization applied by preprocess_torch.
-- Operates on `img:clone()`; for i = 1..3 it multiplies slice `i` by
-- `img_mean.std[i]` and then adds `img_mean.mean[i]`.
-- @param img image indexable as img[1..3] and supporting :clone()
--   (presumably a 3xHxW tensor -- confirm against caller)
-- @param img_mean table with `mean` and `std` fields, each indexable 1..3
-- @return the de-normalized clone
function deprocess_torch(img, img_mean)
  -- Fail early with a readable message instead of an opaque
  -- "attempt to index a nil value" when no statistics were supplied.
  assert(img_mean ~= nil and img_mean.mean ~= nil and img_mean.std ~= nil,
    'deprocess_torch: img_mean must provide .mean and .std')
  local im = img:clone()
  for i = 1, 3 do
    im[i]:mul(img_mean.std[i]):add(img_mean.mean[i])
  end
  return im
end

-- Define an nn Module to compute content loss in-place
local ContentLoss, parent = torch.class('nn.ContentLoss', 'nn.Module')
Expand Down