diff --git a/README.md b/README.md
index d85e62a..340f622 100644
--- a/README.md
+++ b/README.md
@@ -100,6 +100,45 @@ Clockwise from upper left: "The Starry Night" + "The Scream", "The Scream" + "Co
+### Learn from multiple styles with GAN
+When hundreds of pictures are used as style images, a discriminator can be used to calculate the style loss. The discriminator takes a Gram matrix as input and is trained to tell whether the generated image belongs to the target style.
+
+**The traditional way of calculating style loss**:
+![](https://github.com/citymonkeymao/neural-style/blob/gan/data/style-gan.png?raw=true)
+
+**The new way of calculating style loss**:
+![](https://github.com/citymonkeymao/neural-style/blob/gan/data/style-gan2.png?raw=true)
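+
+For intuition, here is a minimal sketch of the idea in Torch; the layer sizes are illustrative, and the actual `nn.Discriminator` module in `neural_style.lua` sizes its layers from each style layer's Gram matrix:
+
+```
+require 'nn'
+
+-- Toy dimensions; in practice this is the C x C Gram matrix of a
+-- style layer's feature maps.
+local H, W = 64, 64
+
+-- A small MLP mapping a flattened Gram matrix to P(target style).
+local D = nn.Sequential()
+D:add(nn.View(H * W))            -- flatten the Gram matrix
+D:add(nn.Linear(H * W, W))
+D:add(nn.ReLU())
+D:add(nn.Linear(W, 1))
+D:add(nn.Sigmoid())
+
+-- The style loss pushes the generated image's Gram matrix towards
+-- being classified as the target style (label 1).
+local crit = nn.BCECriterion()
+local gram = torch.rand(H, W)    -- stand-in for a real Gram matrix
+print(crit:forward(D:forward(gram), torch.Tensor{1}))
+```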
+
+#### Results
+##### Imitate Shinkai Makoto
+Transferred with ~160 high-quality style images.
+![](https://github.com/citymonkeymao/neural-style/blob/gan/data/cmp_manual.png?raw=true)
+##### Imitate Monet (compared to [CycleGAN](https://github.com/junyanz/CycleGAN))
+![](https://github.com/citymonkeymao/neural-style/blob/gan/data/monet.png?raw=true)
+##### Imitate Van Gogh (compared to [CycleGAN](https://github.com/junyanz/CycleGAN))
+![](https://github.com/citymonkeymao/neural-style/blob/gan/data/vangogh.png?raw=true)
+
+#### Usage
+1. Download a style image set (borrowed from CycleGAN):
+
+   `bash ./datasets/download_dataset.sh <dataset_name>`
+
+   `<dataset_name>` can be monet2photo, vangogh2photo, ukiyoe2photo, or cezanne2photo.
+2. Do the style transfer:
+
+   ```
+   th neural_style.lua -style_image `./list_images.sh <style_image_dir>` -content_image <content_image> -gan -content_weight 2 -style_weight 50000 -image_size 256 -backend cudnn -num_iterations 10000 -d_learning_rate 0.000001
+   ```
+
+   The `-gan` flag makes the script use discriminators to calculate the style losses; `-d_learning_rate` is the learning rate for the discriminators. `list_images.sh` lists all images in one directory as a single comma-separated argument (see the sketch below); file names in that directory must not contain spaces, and `<style_image_dir>` must not contain `~`. You may need to tune the parameters for different styles and image sizes.
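+
+`list_images.sh` joins the paths with commas because `-style_image` accepts a comma-separated list of images (as in upstream neural-style). A minimal sketch of that parsing in plain Lua (`split_csv` is illustrative, not the code the script actually uses):
+
+```
+-- Split a comma-separated -style_image value into individual paths.
+local function split_csv(s)
+  local paths = {}
+  for path in string.gmatch(s, '[^,]+') do
+    table.insert(paths, path)
+  end
+  return paths
+end
+
+-- Example: the kind of value `./list_images.sh <style_image_dir>` emits.
+for i, path in ipairs(split_csv('a.jpg,b.jpg,c.jpg')) do
+  print(i, path)
+end
+```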
+
+#### Example
+Transfer `data/fj.jpg` to Van Gogh's style:
+1. Download Van Gogh's paintings:
+
+   `bash ./datasets/download_dataset.sh vangogh2photo`
+2. Add the styles to the image:
+
+   ```
+   th neural_style.lua -style_image `./list_images.sh datasets/vangogh2photo/trainA` -content_image data/fj.jpg -gan -content_weight 1 -style_weight 50000 -image_size 256 -backend cudnn -num_iterations 10000 -d_learning_rate 0.0000001
+   ```
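+
+Under `-gan`, each iteration alternates between one optimizer step on the image and one Adam step on every style layer's discriminator (this is also why the patch below runs L-BFGS with `maxIter = 1` inside a loop). A toy sketch of that alternation, with quadratic losses standing in for the real image and discriminator objectives:
+
+```
+require 'optim'
+
+-- Stand-in objectives: the real code uses the style-transfer loss for
+-- the image and a BCE loss for each discriminator.
+local img, d = torch.Tensor{0}, torch.Tensor{0}
+local feval = function(x) return (x[1] - 3)^2, torch.Tensor{2 * (x[1] - 3)} end
+local d_feval = function(x) return (x[1] + 1)^2, torch.Tensor{2 * (x[1] + 1)} end
+local img_state, d_state = {learningRate = 0.1}, {learningRate = 0.1}
+
+for t = 1, 100 do
+  optim.adam(feval, img, img_state)  -- update the image
+  optim.adam(d_feval, d, d_state)    -- then update each discriminator
+end
+print(img[1], d[1])                  -- both approach their optima
+```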
+
 ### Style Interpolation
 When using multiple style images, you can control the degree to which they are blended:
diff --git a/data/cmp_manual.png b/data/cmp_manual.png
new file mode 100644
index 0000000..e83df37
Binary files /dev/null and b/data/cmp_manual.png differ
diff --git a/data/fj.jpg b/data/fj.jpg
new file mode 100644
index 0000000..0f82f95
Binary files /dev/null and b/data/fj.jpg differ
diff --git a/data/monet.png b/data/monet.png
new file mode 100644
index 0000000..357c604
Binary files /dev/null and b/data/monet.png differ
diff --git a/data/style-gan.png b/data/style-gan.png
new file mode 100644
index 0000000..01ffad4
Binary files /dev/null and b/data/style-gan.png differ
diff --git a/data/style-gan2.png b/data/style-gan2.png
new file mode 100644
index 0000000..88cd7f5
Binary files /dev/null and b/data/style-gan2.png differ
diff --git a/data/vangogh.png b/data/vangogh.png
new file mode 100644
index 0000000..52437db
Binary files /dev/null and b/data/vangogh.png differ
diff --git a/datasets/download_dataset.sh b/datasets/download_dataset.sh
new file mode 100644
index 0000000..1f0b163
--- /dev/null
+++ b/datasets/download_dataset.sh
@@ -0,0 +1,14 @@
+FILE=$1
+
+if [[ $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" && $FILE != "ae_photos" ]]; then
+  echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos"
+  exit 1
+fi
+
+URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
+ZIP_FILE=./datasets/$FILE.zip
+TARGET_DIR=./datasets/$FILE/
+wget -N "$URL" -O "$ZIP_FILE"
+mkdir -p "$TARGET_DIR"
+unzip "$ZIP_FILE" -d ./datasets/
+rm "$ZIP_FILE"
diff --git a/list_images.sh b/list_images.sh
new file mode 100755
index 0000000..d6eb192
--- /dev/null
+++ b/list_images.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+# List every file under $1 (optionally only "*.$2") as one comma-separated string.
+if [ $# -eq 1 ]; then
+  a=$(find "$1" -type f)
+else
+  a=$(find "$1" -type f -name "*.$2")
+fi
+b=$(echo "$a" | tr '\n' ',')
+b=${b::-1}  # drop the trailing comma
+echo "$b"
diff --git a/neural_style.lua b/neural_style.lua
index adc7621..a6eb6c3 100644
--- a/neural_style.lua
+++ b/neural_style.lua
@@ -2,7 +2,7 @@ require 'torch'
 require 'nn'
 require 'image'
 require 'optim'
-
+require 'io'
 require 'loadcaffe'
@@ -48,10 +48,16 @@ cmd:option('-seed', -1)
 cmd:option('-content_layers', 'relu4_2', 'layers for content')
 cmd:option('-style_layers', 'relu1_1,relu2_1,relu3_1,relu4_1,relu5_1', 'layers for style')
+-- GAN mode: use discriminators for the style losses
+cmd:option('-gan', false, 'use a discriminator to calculate style_losses')
+-- learning rate for the discriminators
+cmd:option('-d_learning_rate', 0.0000001)
 
-local function main(params)
-  local dtype, multigpu = setup_gpu(params)
+local function main(params)
+  io.stdout:setvbuf('no')
+  -- dtype is global so StyleLoss and Discriminator can use it
+  dtype, multigpu = setup_gpu(params)
   local loadcaffe_backend = params.backend
   if params.backend == 'clnn' then loadcaffe_backend = 'nn' end
   local cnn = loadcaffe.load(params.proto_file, params.model_file, loadcaffe_backend):type(dtype)
@@ -222,7 +228,7 @@ local function main(params)
   local optim_state = nil
   if params.optimizer == 'lbfgs' then
     optim_state = {
-      maxIter = params.num_iterations,
+      maxIter = 1,
       verbose=true,
       tolX=-1,
      tolFun=-1,
@@ -297,14 +303,72 @@ local function main(params)
     return loss, grad:view(grad:nElement())
   end
+
+  -- Build the optimization closure for one style layer's discriminator
+  local function get_discriminator_loss(style_layer)
+    -- parameters and gradients of the discriminator
+    local dis_params, gradParams = style_layer.D.model:getParameters()
+    -- input batch for the discriminator: one real Gram matrix, one fake
+    local batchInputs = Tensor(2, style_layer.target[1]:size(1), style_layer.target[1]:size(2))
+    -- labels: 1 for the real Gram matrix, 0 for the generated one
+    local batchLabels = torch.zeros(2) + 1
+    batchLabels[2] = 0
+    batchLabels = batchLabels:type(dtype)
+    local crit = nn.BCECriterion():type(dtype)
+    local d_loss = function(x)
+      -- randomly choose a real image's Gram matrix as the positive example
+      batchInputs[1]:copy(style_layer.target[math.random(#style_layer.target)])
+      -- the current generated image's Gram matrix is the negative example
+      batchInputs[2]:copy(style_layer.G)
+      batchInputs = batchInputs:type(dtype)
+      -- discriminator loss on this two-example batch
+      local output = style_layer.D:forward(batchInputs)
+      local loss = crit:forward(output, batchLabels)
+      local verbose = (num_calls % 50 == 0)
+      if verbose then
+        print(string.format('discriminator loss %f', loss))
+      end
+
+      -- refresh the gradients of the discriminator
+      local d_output = crit:backward(output, batchLabels)
+      gradParams:zero()
+      style_layer.D:backward(batchInputs, d_output)
+      return loss, gradParams
+    end
+    return d_loss, dis_params
+  end
+
+  -- One Adam step on each style layer's discriminator
+  local function train_discriminator()
+    if params.gan then
+      for i = 1, #style_losses do
+        local x, losses = optim.adam(style_losses[i].f, style_losses[i].dis_params, style_losses[i].optim_state)
+      end
+    end
+  end
+
+  -- Initialize the GAN optimization closures, one per style layer
+  if params.gan then
+    for i = 1, #style_losses do
+      local f, dis_params = get_discriminator_loss(style_losses[i])
+      style_losses[i].f = f
+      style_losses[i].dis_params = dis_params
+    end
+  end
+
   -- Run optimization.
   if params.optimizer == 'lbfgs' then
     print('Running optimization with L-BFGS')
-    local x, losses = optim.lbfgs(feval, img, optim_state)
+    -- run L-BFGS one iteration at a time (maxIter = 1 above) so the
+    -- discriminators can be trained between image updates
+    for t = 1, params.num_iterations do
+      local x, losses = optim.lbfgs(feval, img, optim_state)
+      train_discriminator()
+    end
   elseif params.optimizer == 'adam' then
     print('Running optimization with ADAM')
     for t = 1, params.num_iterations do
       local x, losses = optim.adam(feval, img, optim_state)
+      train_discriminator()
     end
   end
 end
@@ -321,6 +385,7 @@ function setup_gpu(params)
   else
     params.gpu = tonumber(params.gpu) + 1
   end
+  Tensor = torch.FloatTensor
   local dtype = 'torch.FloatTensor'
   if multigpu or params.gpu > 0 then
     if params.backend ~= 'clnn' then
@@ -331,6 +396,7 @@ function setup_gpu(params)
       else
        cutorch.setDevice(params.gpu)
      end
+      Tensor = torch.CudaTensor
       dtype = 'torch.CudaTensor'
     else
       require 'clnn'
@@ -340,6 +406,7 @@ function setup_gpu(params)
      else
        cltorch.setDevice(params.gpu)
      end
+      Tensor = torch.ClTensor  -- the constructor, not a tensor instance
       dtype = torch.Tensor():cl():type()
     else
@@ -519,7 +586,15 @@ function StyleLoss:__init(strength, normalize)
   parent.__init(self)
   self.normalize = normalize or false
   self.strength = strength
-  self.target = torch.Tensor()
+  -- in GAN mode, keep every style image's Gram matrix as a target
+  if params.gan then
+    self.target = {}
+    self.optim_state = {
+      learningRate = params.d_learning_rate,
+    }
+  else
+    self.target = torch.Tensor()
+  end
   self.mode = 'none'
   self.loss = 0
@@ -527,6 +602,9 @@ function StyleLoss:__init(strength, normalize)
   self.blend_weight = nil
   self.G = nil
   self.crit = nn.MSECriterion()
+  if params.gan then
+    self.D = nil
+  end
 end
 
 function StyleLoss:updateOutput(input)
@@ -535,21 +613,45 @@ function StyleLoss:updateOutput(input)
   if self.mode == 'capture' then
     if self.blend_weight == nil then
       self.target:resizeAs(self.G):copy(self.G)
+    -- in GAN mode, store the Gram matrix of every style image
+    elseif params.gan then
+      local gram_of_this_style = torch.Tensor():type(dtype):resizeAs(self.G):copy(self.G)
+      table.insert(self.target, gram_of_this_style)
     elseif self.target:nElement() == 0 then
       self.target:resizeAs(self.G):copy(self.G):mul(self.blend_weight)
     else
       self.target:add(self.blend_weight, self.G)
     end
   elseif self.mode == 'loss' then
-    self.loss = self.strength * self.crit:forward(self.G, self.target)
-  end
+    if params.gan then
+      -- create D lazily, once the Gram matrix dimensions are known
+      if self.D == nil then
+        self.D = nn.Discriminator(self.G:size(1), self.G:size(2))
+        self.D:type(dtype)
+      end
+      -- classify the generated image's Gram matrix
+      self.classified = self.D:forward(self.G)
+      -- the generated image should look like the target style (label 1)
+      self.loss = self.strength * self.crit:forward(self.classified, torch.Tensor({1}):type(dtype))
+    else
+      self.loss = self.strength * self.crit:forward(self.G, self.target)
+    end
+  end
   self.output = input
   return self.output
 end
 
 function StyleLoss:updateGradInput(input, gradOutput)
   if self.mode == 'loss' then
-    local dG = self.crit:backward(self.G, self.target)
+    local dG
+    if params.gan then
+      -- backpropagate the BCE loss through the discriminator to the Gram matrix
+      local d_classified = self.crit:backward(self.classified, torch.Tensor({1}):type(dtype))
+      dG = self.D:backward(self.G, d_classified)
+    else
+      dG = self.crit:backward(self.G, self.target)
+    end
     dG:div(input:nElement())
     self.gradInput = self.gram:backward(input, dG)
     if self.normalize then
@@ -596,6 +698,46 @@ function TVLoss:updateGradInput(input, gradOutput)
   return self.gradInput
 end
 
+-- A small MLP discriminator over Gram matrices, defined as an nn module
+local Discriminator, parent = torch.class('nn.Discriminator', 'nn.Module')
+
+function Discriminator:__init(input_H, input_W)
+  -- a simple two-hidden-layer MLP over the flattened Gram matrix
+  self.model = nn.Sequential()
+  -- flatten the Gram matrix
+  self.model:add(nn.View(input_H * input_W))
+  -- hidden layer 1
+  self.model:add(nn.Linear(input_H * input_W, input_W))
+  self.model:add(nn.ReLU())
+  -- hidden layer 2
+  self.model:add(nn.Linear(input_W, input_W))
+  self.model:add(nn.ReLU())
+  -- output layer: probability that the input Gram matrix is real
+  self.model:add(nn.Linear(input_W, 1))
+  self.model:add(nn.Sigmoid())
+  self.model = self.model:type(dtype)
+end
+
+-- forward pass of the discriminator
+function Discriminator:updateOutput(input)
+  -- input is a Gram matrix
+  input = input:type(dtype)
+  self.output = self.model:forward(input)
+  return self.output
+end
+
+-- backward pass of the discriminator
+function Discriminator:updateGradInput(input, gradOutput)
+  self.gradInput = self.model:backward(input, gradOutput)
+  return self.gradInput
+end
+
+
-local params = cmd:parse(arg)
+-- params is global so StyleLoss and Discriminator can read -gan
+params = cmd:parse(arg)
 main(params)