Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions style_transfer/style_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,16 +209,16 @@ def size_to_fit(size, max_dim, scale_up=False):
new_h = round(max_dim * h / w)
return new_w, new_h


def gen_scales(start, end, iterations, initial_iterations):
    """Generate the (scale, iterations) schedule for coarse-to-fine stylization.

    Scales run from ``end`` down toward ``start``, each successive scale
    smaller by a factor of sqrt(2); the returned list is ordered coarsest
    (smallest) scale first. The coarsest scale is assigned
    ``initial_iterations``; every other scale gets ``iterations``.

    Args:
        start: Minimum (coarsest) scale; inclusive lower bound.
        end: Maximum (finest) scale; always the last entry when valid.
        iterations: Iteration count used for every scale after the first.
        initial_iterations: Iteration count for the first (coarsest) scale.

    Returns:
        List of ``(scale, iterations)`` tuples, smallest scale first.

    Raises:
        ValueError: If ``start`` exceeds ``end`` so no scales are generated.
            (The original code crashed here with an opaque IndexError on
            ``scales[0]``.)
    """
    scales = []
    i = 0
    scale = end
    while scale >= start:
        # Prepend so the final list is ordered smallest scale first.
        scales.insert(0, (scale, iterations))
        i += 1
        # Each successive scale is smaller by a factor of sqrt(2).
        scale = round(end / pow(2, i / 2))
    if not scales:
        raise ValueError(f'start ({start}) must not exceed end ({end})')
    # The coarsest scale gets the (typically larger) warm-up iteration count.
    scales[0] = (scales[0][0], initial_iterations)
    return scales


def interpolate(*args, **kwargs):
Expand Down Expand Up @@ -297,6 +297,7 @@ def stylize(self, content_image, style_images, *,
tv_weight: float = 2.,
min_scale: int = 128,
end_scale: int = 512,
scales=None,
iterations: int = 500,
initial_iterations: int = 1000,
step_size: float = 0.02,
Expand All @@ -319,9 +320,10 @@ def stylize(self, content_image, style_images, *,

tv_loss = Scale(LayerApply(TVLoss(), 'input'), tv_weight)

scales = gen_scales(min_scale, end_scale)
if scales is None:
scales = gen_scales(min_scale, end_scale, iterations, initial_iterations)

cw, ch = size_to_fit(content_image.size, scales[0], scale_up=True)
cw, ch = size_to_fit(content_image.size, scales[0][0], scale_up=True)
if init == 'content':
self.image = TF.to_tensor(content_image.resize((cw, ch), Image.LANCZOS))[None]
elif init == 'gray':
Expand All @@ -341,7 +343,7 @@ def stylize(self, content_image, style_images, *,

# Stylize the image at successively finer scales, each greater by a factor of sqrt(2).
# This differs from the scheme given in Gatys et al. (2016).
for scale in scales:
for scale, its in scales:
if self.devices[0].type == 'cuda':
torch.cuda.empty_cache()

Expand Down Expand Up @@ -385,16 +387,15 @@ def stylize(self, content_image, style_images, *,

opt2 = optim.Adam([self.image], lr=step_size)
# Warm-start the Adam optimizer if this is not the first scale.
if scale != scales[0]:
if scale != scales[0][0]:
opt_state = scale_adam(opt.state_dict(), (ch, cw))
opt2.load_state_dict(opt_state)
opt = opt2

if self.devices[0].type == 'cuda':
torch.cuda.empty_cache()

actual_its = initial_iterations if scale == scales[0] else iterations
for i in range(1, actual_its + 1):
for i in range(1, its + 1):
feats = self.model(self.image)
loss = crit(feats)
opt.zero_grad()
Expand All @@ -409,7 +410,7 @@ def stylize(self, content_image, style_images, *,
for device in self.devices:
if device.type == 'cuda':
gpu_ram = max(gpu_ram, torch.cuda.max_memory_allocated(device))
callback(STIterate(w=cw, h=ch, i=i, i_max=actual_its, loss=loss.item(),
callback(STIterate(w=cw, h=ch, i=i, i_max=its, loss=loss.item(),
time=time.time(), gpu_ram=gpu_ram))

# Initialize each new scale with the previous scale's averaged iterate.
Expand Down