diff --git a/fast_neural_style/HistoLoss.lua b/fast_neural_style/HistoLoss.lua
new file mode 100644
index 0000000..06fcf21
--- /dev/null
+++ b/fast_neural_style/HistoLoss.lua
@@ -0,0 +1,180 @@
+require 'torch'
+require 'nn'
+
+local threads = require 'threads'
+
+local function makeCdfInv(img, bins)
+  -- things we'll need later.
+  local imgmin = img:min()
+  local imgmax = img:max()-img:min()
+  local cdfinv = torch.zeros(bins)
+  local cdfinvcount = torch.ones(bins)
+  local cdfsum = 0
+
+  local imgview = img:view(-1)
+
+  -- calculate histogram
+  local hist = torch.histc(img, bins)
+
+  -- calculate the probability mass function...
+  local pmf = hist:div(img:nElement())
+
+  -- ... and use that to generate a cumulative distribution function.
+  local cdf = pmf:apply(function(x)
+      cdfsum = cdfsum + x
+      return cdfsum
+    end
+  )
+
+  -- we then scale and floor the CDF for generating the inverse CDF
+  cdf:mul(bins-1):floor()
+
+  -- and then generate the inverse CDF.
+  imgview:apply(function(x)
+      local y = math.floor(((x-imgmin)/(imgmax+1e-11))*(bins-1)+1)
+      y = cdf[y]+1
+      cdfinv[y] = cdfinv[y] + x
+      cdfinvcount[y] = cdfinvcount[y] + 1
+    end
+  )
+  cdfinv:cdiv(cdfinvcount)
+
+  -- to improve results, replace all unfilled inverse CDF bins with linearly interpolated values.
+  cdfinv[bins] = cdfinv:max()
+  if math.ceil(cdfinv:max()) ~= 0 then
+    for i = 2, cdfinv:size()[1] do
+      local count = 1
+      local temp1 = temp1 or cdfinv[i-1]
+      local temp2 = 0
+      if cdfinv[i] == 0 then
+        while cdfinv[i-1+count] == 0 do
+          count = count + 1
+          temp2 = cdfinv[i-1+count]
+        end
+        if count < 2 then
+        end
+        cdfinv[i] = temp1*(1/count)+temp2*(1-(1/count))
+      else
+        temp1 = cdfinv[i]
+      end
+    end
+  end
+
+  return cdfinv
+end
+
+local function histoMatch(img, cdfinv, bins)
+  -- things we'll need later.
+  local imgmin = img:min()
+  local imgmax = img:max()-img:min()
+  local cdfsum = 0
+
+  local imgview = img:view(-1)
+
+  -- calculate histogram
+  local hist = torch.histc(img, bins)
+
+  -- calculate the probability mass function...
+  local pmf = hist:div(img:nElement())
+
+  -- ... and use that to generate a cumulative distribution function.
+  local cdf = pmf:apply(function(x)
+      cdfsum = cdfsum + x
+      return cdfsum
+    end
+  )
+  -- finally, we use the generated CDF to match the histograms.
+  local function invert(img)
+    img = math.floor(((img-imgmin)/(imgmax+1e-11))*(bins-1)+1)
+    img = math.floor(cdf[img]*(bins-1)+1)
+    return cdfinv[img]
+  end
+  imgview:apply(invert)
+
+  return(img)
+end
+
+local HistoLoss, parent = torch.class('nn.HistoLoss', 'nn.Module')
+
+function HistoLoss:__init(strength, bins, n_threads)
+  parent.__init(self)
+  self.strength = strength
+  self.target = nil
+  self.loss = 0
+  self.bins = bins
+  self.mode = 'none'
+  self.H = nil
+  self.crit = nn.MSECriterion()
+  self.crit.sizeAverage = true
+  self.threads = n_threads or 6
+end
+
+function HistoLoss:updateOutput(input)
+  -- since creating an OpenCL/CUDA kernel for this is non-trivial,
+  -- instead I've chosen to spread the work across CPU worker threads.
+  local pool = threads.Threads(self.threads)
+  local bins_thread = self.bins
+  if self.mode == 'capture' then
+    self.target = torch.Tensor(input:size()[1], input:size()[2], self.bins)
+    for i = 1, input:size()[1] do
+      for j = 1, input:size()[2] do
+        local input_thread = input[i][j]:clone()
+        local target_thread = self.target[i][j]:clone()
+        pool:addjob(
+          function()
+            target_thread = makeCdfInv(input_thread, bins_thread)
+            return target_thread
+          end,
+
+          function(target_thread)
+            self.target[i][j] = target_thread
+          end
+        )
+      end
+    end
+    pool:synchronize()
+    pool:terminate()
+  elseif self.mode == 'loss' then
+    self.H = input:clone()
+    for i = 1, input:size()[1] do
+      for j = 1, input:size()[2] do
+        local target_thread = self.target[1][j]:clone()
+        local input_thread = input[i][j]:clone()
+        local H_thread = self.H[i][j]:clone()
+        pool:addjob(
+          function()
+            H_thread = histoMatch(input_thread, target_thread, bins_thread)
+            return H_thread
+          end,
+          function(H_thread)
+            self.H[i][j] = H_thread
+          end
+        )
+      end
+    end
+    pool:synchronize()
+    pool:terminate()
+    self.loss = self.crit:forward(input, self.H)
+    self.loss = self.loss * self.strength
+  end
+  self.output = input
+  return self.output
+end
+
+function HistoLoss:updateGradInput(input, gradOutput)
+  if self.mode == 'capture' or self.mode == 'none' then
+    self.gradInput = gradOutput
+  elseif self.mode == 'loss' then
+    self.gradInput = self.crit:backward(input, self.H)
+    self.gradInput:mul(self.strength)
+    self.gradInput:add(gradOutput)
+  end
+  return self.gradInput
+end
+
+function HistoLoss:setMode(mode)
+  if mode ~= 'capture' and mode ~= 'loss' and mode ~= 'none' then
+    error(string.format('Invalid mode "%s"', mode))
+  end
+  self.mode = mode
+end
diff --git a/fast_neural_style/PerceptualCriterion.lua b/fast_neural_style/PerceptualCriterion.lua
index af6f604..e775044
--- a/fast_neural_style/PerceptualCriterion.lua
+++ b/fast_neural_style/PerceptualCriterion.lua
@@ -4,6 +4,7 @@ require 'nn'
 require 'fast_neural_style.ContentLoss'
 require 'fast_neural_style.StyleLoss'
 require 'fast_neural_style.DeepDreamLoss'
+require 'fast_neural_style.HistoLoss'
 
 local layer_utils = require 'fast_neural_style.layer_utils'
 
@@ -27,12 +28,14 @@ Input: args is a table with the following keys:
 function crit:__init(args)
   args.content_layers = args.content_layers or {}
   args.style_layers = args.style_layers or {}
+  args.histo_layers = args.histo_layers or {}
   args.deepdream_layers = args.deepdream_layers or {}
 
   self.net = args.cnn
   self.net:evaluate()
   self.content_loss_layers = {}
   self.style_loss_layers = {}
+  self.histo_loss_layers = {}
   self.deepdream_loss_layers = {}
 
   -- Set up content loss layers
@@ -50,6 +53,14 @@ function crit:__init(args)
     layer_utils.insert_after(self.net, layer_string, style_loss_layer)
     table.insert(self.style_loss_layers, style_loss_layer)
   end
+
+  -- Set up histo loss layers
+  for i, layer_string in ipairs(args.histo_layers) do
+    local weight = args.histo_weights[i]
+    local histo_loss_layers = nn.HistoLoss(weight, args.histo_bins, args.histo_threads)
+    layer_utils.insert_after(self.net, layer_string, histo_loss_layers)
+    table.insert(self.histo_loss_layers, histo_loss_layers)
+  end
 
   -- Set up DeepDream layers
   for i, layer_string in ipairs(args.deepdream_layers) do
@@ -75,10 +86,12 @@ function crit:setStyleTarget(target)
   for i, style_loss_layer in ipairs(self.style_loss_layers) do
     style_loss_layer:setMode('capture')
   end
+  for i, histo_loss_layer in ipairs(self.histo_loss_layers) do
+    histo_loss_layer:setMode('capture')
+  end
   self.net:forward(target)
 end
 
-
 --[[
 target: Tensor of shape (N, 3, H, W) giving pixels for content target images
 --]]
@@ -86,6 +99,9 @@ function crit:setContentTarget(target)
   for i, style_loss_layer in ipairs(self.style_loss_layers) do
     style_loss_layer:setMode('none')
   end
+  for i, histo_loss_layer in ipairs(self.histo_loss_layers) do
+    histo_loss_layer:setMode('none')
+  end
   for i, content_loss_layer in ipairs(self.content_loss_layers) do
     content_loss_layer:setMode('capture')
   end
@@ -106,6 +122,11 @@ function crit:setContentWeight(weight)
   end
 end
 
+function crit:setHistoWeight(weight)
+  for i, histo_loss_layer in ipairs(self.histo_loss_layers) do
+    histo_loss_layer.strength = weight
+  end
+end
 
 --[[
 Inputs:
@@ -119,7 +140,7 @@ function crit:updateOutput(input, target)
     self:setContentTarget(target.content_target)
   end
   if target.style_target then
-    self.setStyleTarget(target.style_target)
+    self:setStyleTarget(target.style_target)
   end
 
   -- Make sure to set all content and style loss layers to loss mode before
@@ -130,6 +151,9 @@
   for i, style_loss_layer in ipairs(self.style_loss_layers) do
     style_loss_layer:setMode('loss')
   end
+  for i, histo_loss_layer in ipairs(self.histo_loss_layers) do
+    histo_loss_layer:setMode('loss')
+  end
 
   local output = self.net:forward(input)
 
@@ -141,6 +165,8 @@ self.content_losses = {}
   self.total_style_loss = 0
   self.style_losses = {}
+  self.total_histo_loss = 0
+  self.histo_losses = {}
   for i, content_loss_layer in ipairs(self.content_loss_layers) do
     self.total_content_loss = self.total_content_loss + content_loss_layer.loss
     table.insert(self.content_losses, content_loss_layer.loss)
   end
@@ -149,8 +175,12 @@
     self.total_style_loss = self.total_style_loss + style_loss_layer.loss
     table.insert(self.style_losses, style_loss_layer.loss)
   end
+  for i, histo_loss_layer in ipairs(self.histo_loss_layers) do
+    self.total_histo_loss = self.total_histo_loss + histo_loss_layer.loss
+    table.insert(self.histo_losses, histo_loss_layer.loss)
+  end
 
-  self.output = self.total_style_loss + self.total_content_loss
+  self.output = self.total_style_loss + self.total_content_loss + self.total_histo_loss
   return self.output
 end
 
diff --git a/slow_neural_style.lua b/slow_neural_style.lua
index 03d0e93..6a27143 100644
--- a/slow_neural_style.lua
+++ b/slow_neural_style.lua
@@ -33,6 +33,10 @@ cmd:option('-content_layers', '16')
 cmd:option('-style_weights', '5.0')
 cmd:option('-style_layers', '4,9,16,23')
 cmd:option('-style_image_size', 512)
+cmd:option('-histo_weights', '5.0')
+cmd:option('-histo_layers', '2,21')
+cmd:option('-histo_bins', 256)
+cmd:option('-histo_threads', 4)
 
 -- Options for DeepDream
 cmd:option('-deepdream_layers', '')
@@ -80,6 +84,8 @@ local function main()
   print(loss_net)
   local style_layers, style_weights =
     utils.parse_layers(opt.style_layers, opt.style_weights)
+  local histo_layers, histo_weights =
+    utils.parse_layers(opt.histo_layers, opt.histo_weights)
   local content_layers, content_weights =
     utils.parse_layers(opt.content_layers, opt.content_weights)
   local deepdream_layers, deepdream_weights =
@@ -88,6 +94,10 @@ local function main()
     cnn = loss_net,
     style_layers = style_layers,
     style_weights = style_weights,
+    histo_layers = histo_layers,
+    histo_weights = histo_weights,
+    histo_bins = opt.histo_bins,
+    histo_threads = opt.histo_threads,
     content_layers = content_layers,
     content_weights = content_weights,
     deepdream_layers = deepdream_layers,
diff --git a/train.lua b/train.lua
index 1e80948..82efd06 100644
--- a/train.lua
+++ b/train.lua
@@ -43,6 +43,10 @@ cmd:option('-style_image_size', 256)
 cmd:option('-style_weights', '5.0')
 cmd:option('-style_layers', '4,9,16,23')
 cmd:option('-style_target_type', 'gram', 'gram|mean')
+cmd:option('-histo_weights', '5.0')
+cmd:option('-histo_layers', '4,23')
+cmd:option('-histo_bins', 256)
+cmd:option('-histo_threads', 4)
 
 -- Upsampling options
 cmd:option('-upsample_factor', 4)
@@ -75,6 +79,8 @@ cmd:option('-backend', 'cuda', 'cuda|opencl')
     utils.parse_layers(opt.content_layers, opt.content_weights)
   opt.style_layers, opt.style_weights =
     utils.parse_layers(opt.style_layers, opt.style_weights)
+  opt.histo_layers, opt.histo_weights =
+    utils.parse_layers(opt.histo_layers, opt.histo_weights)
 
   -- Figure out preprocessing
   if not preprocess[opt.preprocessing] then
@@ -119,6 +125,10 @@ cmd:option('-backend', 'cuda', 'cuda|opencl')
     cnn = loss_net,
     style_layers = opt.style_layers,
     style_weights = opt.style_weights,
+    histo_layers = opt.histo_layers,
+    histo_weights = opt.histo_weights,
+    histo_bins = opt.histo_bins,
+    histo_threads = opt.histo_threads,
     content_layers = opt.content_layers,
     content_weights = opt.content_weights,
     agg_type = opt.style_target_type,
@@ -227,6 +237,9 @@ cmd:option('-backend', 'cuda', 'cuda|opencl')
   for i, k in ipairs(opt.style_layers) do
     style_loss_history[string.format('style-%d', k)] = {}
   end
+  for i, k in ipairs(opt.histo_layers) do
+    style_loss_history[string.format('histo-%d', k)] = {}
+  end
   for i, k in ipairs(opt.content_layers) do
     style_loss_history[string.format('content-%d', k)] = {}
   end
@@ -245,6 +258,10 @@ cmd:option('-backend', 'cuda', 'cuda|opencl')
     table.insert(style_loss_history[string.format('style-%d', k)],
      percep_crit.style_losses[i])
   end
+  for i, k in ipairs(opt.histo_layers) do
+    table.insert(style_loss_history[string.format('histo-%d', k)],
+      percep_crit.histo_losses[i])
+  end
   for i, k in ipairs(opt.content_layers) do
     table.insert(style_loss_history[string.format('content-%d', k)],
      percep_crit.content_losses[i])
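
Usage sketch (not part of the patch, untested): the snippet below only illustrates the capture/loss protocol that nn.HistoLoss follows when PerceptualCriterion drives it, using the API added above (setMode, forward/backward, and the .loss field). Running it standalone on random CPU tensors, and the feature-map shapes, are assumptions made purely for illustration.

require 'torch'
require 'nn'
require 'fast_neural_style.HistoLoss'  -- assumes the repo root is on the Lua package path

-- strength, histogram bins, worker threads (mirroring the new CLI flags)
local histo = nn.HistoLoss(5.0, 256, 4)

-- Hypothetical feature maps of shape (N, C, H, W); in the real pipeline these
-- come from the loss-network layers selected by -histo_layers.
local target_feats = torch.rand(1, 64, 32, 32)
local input_feats  = torch.rand(1, 64, 32, 32)

-- 1) Capture: build a per-channel inverse CDF of the target activations.
histo:setMode('capture')
histo:forward(target_feats)

-- 2) Loss: histogram-match the input against the captured target and take the
--    MSE between the input and its matched version, scaled by strength.
histo:setMode('loss')
histo:forward(input_feats)
print('histogram loss:', histo.loss)

-- 3) Backward returns the scaled MSE gradient plus any upstream gradient.
local grad = histo:backward(input_feats, torch.zeros(input_feats:size()))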