diff --git a/README.md b/README.md
index d85e62a..340f622 100644
--- a/README.md
+++ b/README.md
@@ -100,6 +100,45 @@ Clockwise from upper left: "The Starry Night" + "The Scream", "The Scream" + "Co
+### Learn from multiple styles with a GAN
+When using hundreds of pictures as style images, a discriminator can be used to calculate the style loss. The discriminator takes a Gram matrix as input and is trained to tell whether the generated image belongs to the target style.
+
+**The traditional way of calculating style loss**:
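+
+In symbols (a sketch in standard notation, not lifted verbatim from the code): with $G^l(x)$ the Gram matrix of the activations at layer $l$ and $w_l$ the per-layer weights, the generated image $\hat{x}$ is matched against a fixed target Gram computed from the style image $x_s$:
+
+$$L_{style} = \sum_l w_l \left\lVert G^l(\hat{x}) - G^l(x_s) \right\rVert^2$$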
+
+**The new way of calculating style loss**:
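+
+Here a small per-layer discriminator $D_l$ replaces the fixed target Gram. Sketching it in the same notation: the generator-side loss pushes $D_l$'s output toward the "real" label (an MSE against 1 in this implementation), while $D_l$ itself is trained with binary cross-entropy on real vs. generated Grams:
+
+$$L_{style} = \sum_l w_l \left( D_l\big(G^l(\hat{x})\big) - 1 \right)^2$$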
+
+
+#### Results
+##### Imitate Shinkai Makoto Style
+Transferred with ~160 high-quality style images.
+
+##### Imitate Monet (compared to [CycleGAN](https://github.com/junyanz/CycleGAN))
+![Monet comparison](data/monet.png)
+
+##### Imitate Van Gogh (compared to [CycleGAN](https://github.com/junyanz/CycleGAN))
+![Van Gogh comparison](data/vangogh.png)
+
+
+
+#### Usage
+1. Download a style image set (borrowed from CycleGAN):
+   `bash ./datasets/download_dataset.sh <dataset_name>`
+
+   `<dataset_name>` can be monet2photo, vangogh2photo, ukiyoe2photo, or cezanne2photo.
+2. Do the style transfer:
+   ```
+   th neural_style.lua -style_image `./list_images.sh <style_image_dir>` -content_image <content_image> -gan -content_weight 2 -style_weight 50000 -image_size 256 -backend cudnn -num_iterations 10000 -d_learning_rate 0.000001
+   ```
+
+   The `-gan` flag tells the script to use discriminators to calculate the style losses, and `-d_learning_rate` sets the learning rate for the discriminators. `list_images.sh` lists all images in a directory as a comma-separated string; file names in that directory must not contain spaces, and `<style_image_dir>` must not contain `~`. You will need to tune the parameters for different styles and image sizes.
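+
+   For example, an illustrative run of the helper (the file names here are hypothetical):
+
+   ```
+   $ ./list_images.sh datasets/vangogh2photo/trainA jpg
+   datasets/vangogh2photo/trainA/00001.jpg,datasets/vangogh2photo/trainA/00002.jpg,...
+   ```
+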
+#### Example
+Transfer fj.jpg to Van Gogh's style:
+1. Download Van Gogh's paintings: `bash ./datasets/download_dataset.sh vangogh2photo`
+2. Add the style to the image:
+   ```
+   th neural_style.lua -style_image `./list_images.sh datasets/vangogh2photo/trainA` -content_image data/fj.jpg -gan -content_weight 1 -style_weight 50000 -image_size 256 -backend cudnn -num_iterations 10000 -d_learning_rate 0.0000001
+   ```
+
### Style Interpolation
When using multiple style images, you can control the degree to which they are blended:
diff --git a/data/cmp_manual.png b/data/cmp_manual.png
new file mode 100644
index 0000000..e83df37
Binary files /dev/null and b/data/cmp_manual.png differ
diff --git a/data/fj.jpg b/data/fj.jpg
new file mode 100644
index 0000000..0f82f95
Binary files /dev/null and b/data/fj.jpg differ
diff --git a/data/monet.png b/data/monet.png
new file mode 100644
index 0000000..357c604
Binary files /dev/null and b/data/monet.png differ
diff --git a/data/style-gan.png b/data/style-gan.png
new file mode 100644
index 0000000..01ffad4
Binary files /dev/null and b/data/style-gan.png differ
diff --git a/data/style-gan2.png b/data/style-gan2.png
new file mode 100644
index 0000000..88cd7f5
Binary files /dev/null and b/data/style-gan2.png differ
diff --git a/data/vangogh.png b/data/vangogh.png
new file mode 100644
index 0000000..52437db
Binary files /dev/null and b/data/vangogh.png differ
diff --git a/datasets/download_dataset.sh b/datasets/download_dataset.sh
new file mode 100644
index 0000000..1f0b163
--- /dev/null
+++ b/datasets/download_dataset.sh
@@ -0,0 +1,14 @@
+FILE=$1
+
+if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" ]]; then
+ echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos"
+ exit 1
+fi
+
+URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
+ZIP_FILE=./datasets/$FILE.zip
+TARGET_DIR=./datasets/$FILE/
+wget -N "$URL" -O "$ZIP_FILE"
+mkdir -p "$TARGET_DIR"
+unzip $ZIP_FILE -d ./datasets/
+rm $ZIP_FILE
diff --git a/list_images.sh b/list_images.sh
new file mode 100755
index 0000000..d6eb192
--- /dev/null
+++ b/list_images.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
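+# Usage: ./list_images.sh <dir> [ext]
+# Prints a comma-separated list of the files under <dir> (optionally filtered
+# by extension), suitable for neural_style.lua's -style_image option.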
+if [ $# -eq 1 ]; then
+  # no extension given: list every file under the directory
+  a=$(find "$1" -type f)
+else
+  # otherwise list only files with the given extension
+  a=$(find "$1" -type f -name "*.$2")
+fi
+# join the newline-separated paths with commas and drop the trailing comma
+b=$(echo "$a" | tr '\n' ,)
+b=${b%,}
+echo "$b"
diff --git a/neural_style.lua b/neural_style.lua
index adc7621..a6eb6c3 100644
--- a/neural_style.lua
+++ b/neural_style.lua
@@ -2,7 +2,7 @@ require 'torch'
require 'nn'
require 'image'
require 'optim'
-
+require 'io'
require 'loadcaffe'
@@ -48,10 +48,16 @@ cmd:option('-seed', -1)
cmd:option('-content_layers', 'relu4_2', 'layers for content')
cmd:option('-style_layers', 'relu1_1,relu2_1,relu3_1,relu4_1,relu5_1', 'layers for style')
+--GAN mode: use a per-layer discriminator to calculate the style losses
+cmd:option('-gan', false, 'use a discriminator to calculate style_losses')
+cmd:option('-d_learning_rate', 1e-7, 'learning rate for the discriminators')
-local function main(params)
- local dtype, multigpu = setup_gpu(params)
+local function main(params)
+ io.stdout:setvbuf('no')
+  -- dtype and multigpu are intentionally global so that StyleLoss and
+  -- Discriminator (defined below) can see them
+  dtype, multigpu = setup_gpu(params)
local loadcaffe_backend = params.backend
if params.backend == 'clnn' then loadcaffe_backend = 'nn' end
local cnn = loadcaffe.load(params.proto_file, params.model_file, loadcaffe_backend):type(dtype)
@@ -222,7 +228,7 @@ local function main(params)
local optim_state = nil
if params.optimizer == 'lbfgs' then
optim_state = {
- maxIter = params.num_iterations,
+        maxIter = 1, -- step L-BFGS one iteration at a time so the discriminator can be trained in between
verbose=true,
tolX=-1,
tolFun=-1,
@@ -297,14 +303,72 @@ local function main(params)
return loss, grad:view(grad:nElement())
end
+ --generate optimize function for discriminator for one style layer's loss
+ local function get_discriminator_loss(style_layer)
+ --get params and gradients of discriminator
+ local dis_params, gradParams = style_layer.D.model:getParameters()
+ --create input for discriminator: one real gram, one fake gram
+    local batchInputs = Tensor(2, style_layer.target[1]:size(1), style_layer.target[1]:size(2))
+    --label set: 1 for the real gram, 0 for the generated one
+    local batchLabels = torch.zeros(2) + 1
+    batchLabels[2] = 0
+    batchLabels = batchLabels:type(dtype)
+    local crit = nn.BCECriterion():type(dtype)
+    local d_loss = function(x)
+ --randomly choose a real image as positive input
+ batchInputs[1]:copy(style_layer.target[math.random(#style_layer.target)])
+ --copy fake images to training set
+ batchInputs[2]:copy(style_layer.G)
+ --change dtype
+ batchInputs = batchInputs:type(dtype)
+ --create loss
+ local output = style_layer.D:forward(batchInputs)
+ local loss = crit:forward(output, batchLabels)
+ local verbose = (num_calls % 50 == 0)
+ if verbose then
+ print(string.format('discriminator loss %f',loss))
+ end
+
+ --refresh gradients of discriminator
+ local d_output = crit:backward(output, batchLabels)
+ gradParams:zero()
+ style_layer.D:backward(batchInputs, d_output)
+ return loss, gradParams
+ end
+ return d_loss, dis_params
+ end
+
+  local function train_discriminator()
+    if params.gan then
+      --take one Adam step on each style layer's discriminator,
+      --alternating with the generator (image) update
+      for i = 1, #style_losses do
+        optim.adam(style_losses[i].f, style_losses[i].dis_params, style_losses[i].optim_state)
+      end
+    end
+  end
+
+ --init gan optimization eval functions
+ if params.gan then
+ for i = 1, #style_losses do
+      local f, dis_params = get_discriminator_loss(style_losses[i])
+ style_losses[i].f = f
+ style_losses[i].dis_params = dis_params
+ end
+ end
+
-- Run optimization.
if params.optimizer == 'lbfgs' then
print('Running optimization with L-BFGS')
- local x, losses = optim.lbfgs(feval, img, optim_state)
+ for t = 1, params.num_iterations do
+ local x, losses = optim.lbfgs(feval, img, optim_state)
+ train_discriminator()
+ end
elseif params.optimizer == 'adam' then
print('Running optimization with ADAM')
for t = 1, params.num_iterations do
local x, losses = optim.adam(feval, img, optim_state)
+ train_discriminator()
end
end
end
@@ -321,6 +385,7 @@ function setup_gpu(params)
else
params.gpu = tonumber(params.gpu) + 1
end
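+  -- Tensor is a global constructor matching the chosen backend; the GAN code
+  -- uses it to build discriminator input batches of the right type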
+  Tensor = torch.FloatTensor
local dtype = 'torch.FloatTensor'
if multigpu or params.gpu > 0 then
if params.backend ~= 'clnn' then
@@ -331,6 +396,7 @@ function setup_gpu(params)
else
cutorch.setDevice(params.gpu)
end
+ Tensor = torch.CudaTensor
dtype = 'torch.CudaTensor'
else
require 'clnn'
@@ -340,6 +406,7 @@ function setup_gpu(params)
else
cltorch.setDevice(params.gpu)
end
+      Tensor = torch.ClTensor -- constructor, not an instance; torch.Tensor():cl() would return an empty tensor
dtype = torch.Tensor():cl():type()
end
else
@@ -519,7 +586,15 @@ function StyleLoss:__init(strength, normalize)
parent.__init(self)
self.normalize = normalize or false
self.strength = strength
- self.target = torch.Tensor()
+ --if gan, save all grams
+ if params.gan then
+ self.target = {}
+ self.optim_state = {
+ learningRate = params.d_learning_rate,
+ }
+ else
+ self.target = torch.Tensor()
+ end
self.mode = 'none'
self.loss = 0
@@ -527,6 +602,9 @@ function StyleLoss:__init(strength, normalize)
self.blend_weight = nil
self.G = nil
self.crit = nn.MSECriterion()
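+  --note: in GAN mode this criterion compares the discriminator's output with
+  --the "real" label 1, rather than comparing Gram matrices directly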
+  self.D = nil --discriminator (GAN mode); created lazily once the Gram size is known
end
function StyleLoss:updateOutput(input)
@@ -535,21 +613,45 @@ function StyleLoss:updateOutput(input)
if self.mode == 'capture' then
-    if self.blend_weight == nil then
+    --store every style image's Gram in GAN mode; this branch must come first
+    --because self.target is a table here, not a tensor
+    if params.gan then
+      local gram_of_this_style = torch.Tensor():type(dtype):resizeAs(self.G):copy(self.G)
+      table.insert(self.target, gram_of_this_style)
+    elseif self.blend_weight == nil then
       self.target:resizeAs(self.G):copy(self.G)
elseif self.target:nElement() == 0 then
self.target:resizeAs(self.G):copy(self.G):mul(self.blend_weight)
else
       self.target:add(self.blend_weight, self.G)
end
elseif self.mode == 'loss' then
- self.loss = self.strength * self.crit:forward(self.G, self.target)
- end
+ if params.gan then --if gan mode
+      --create D lazily, once the Gram matrix dimensions are known
+ if self.D == nil then
+ self.D = nn.Discriminator(self.G:size(1),self.G:size(2))
+ self.D:type(dtype)
+ end
+      --classify the Gram matrix of the generated image
+      self.classified = self.D:forward(self.G)
+      --push it toward the "real" (target style) label
+      self.loss = self.strength * self.crit:forward(self.classified, torch.Tensor({1}):type(dtype))
+ else --if not gan mode
+ self.loss = self.strength * self.crit:forward(self.G, self.target)
+ end
+ end
self.output = input
return self.output
end
function StyleLoss:updateGradInput(input, gradOutput)
if self.mode == 'loss' then
- local dG = self.crit:backward(self.G, self.target)
+ local dG
+ if params.gan then
+      local d_classified = self.crit:backward(self.classified, torch.Tensor({1}):type(dtype))
+ dG = self.D:backward(self.G, d_classified)
+ else
+ dG = self.crit:backward(self.G, self.target)
+ end
dG:div(input:nElement())
self.gradInput = self.gram:backward(input, dG)
if self.normalize then
@@ -596,6 +698,46 @@ function TVLoss:updateGradInput(input, gradOutput)
return self.gradInput
end
+-- Define an nn Module wrapping a small discriminator, used to compute the GAN style loss in-place
+local Discriminator, parent = torch.class('nn.Discriminator', 'nn.Module')
+
+function Discriminator:__init(input_H, input_W)
+ --define a simple discriminator
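+  --(a small MLP on the flattened Gram matrix: two ReLU hidden layers and a
+  --sigmoid output estimating the probability that the Gram is from the target style)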
+ self.model = nn.Sequential()
+ --flatten the gram matrix
+ self.model:add(nn.View(input_H * input_W))
+ --hidden layer 1
+ self.model:add(nn.Linear(input_H * input_W,input_W))
+ self.model:add(nn.ReLU())
+  --hidden layer 2
+  self.model:add(nn.Linear(input_W, input_W))
+  self.model:add(nn.ReLU())
+
+ --output layer
+  self.model:add(nn.Linear(input_W, 1))
+ self.model:add(nn.Sigmoid())
+ self.model = self.model:type(dtype)
+end
+
+--forward of discriminator
+function Discriminator:updateOutput(input)
+ --input is the Gram Matrix
+ input = input:type(dtype)
+ self.output = self.model:forward(input)
+ return self.output
+end
+
+--backward of discriminator
+function Discriminator:updateGradInput(input, gradOutput)
+ self.gradInput = self.model:backward(input,gradOutput)
+ return self.gradInput
+end
+
+
-local params = cmd:parse(arg)
+--params is intentionally global: StyleLoss and Discriminator read params.gan
+params = cmd:parse(arg)
main(params)