From f6c9cb99a12755101789ea31d45055b7e966a60e Mon Sep 17 00:00:00 2001
From: Angus Russell
Date: Mon, 14 Jun 2021 11:48:37 +1000
Subject: [PATCH 1/2] Allow specifying custom scales and iterations

---
 style_transfer/style_transfer.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/style_transfer/style_transfer.py b/style_transfer/style_transfer.py
index fdabd15..ccf0915 100644
--- a/style_transfer/style_transfer.py
+++ b/style_transfer/style_transfer.py
@@ -209,16 +209,16 @@ def size_to_fit(size, max_dim, scale_up=False):
         new_h = round(max_dim * h / w)
     return new_w, new_h
 
-
-def gen_scales(start, end):
+def gen_scales(start, end, iterations, initial_iterations):
     scale = end
     i = 0
-    scales = set()
+    scales = []
     while scale >= start:
-        scales.add(scale)
+        scales.insert(0, (scale, iterations))
         i += 1
         scale = round(end / pow(2, i/2))
-    return sorted(scales)
+    scales[0] = (scales[0][0], initial_iterations)
+    return scales
 
 
 def interpolate(*args, **kwargs):
@@ -319,9 +319,10 @@ def stylize(self, content_image, style_images, *,
 
         tv_loss = Scale(LayerApply(TVLoss(), 'input'), tv_weight)
 
-        scales = gen_scales(min_scale, end_scale)
+        if scales is None:
+            scales = gen_scales(min_scale, end_scale, iterations, initial_iterations)
 
-        cw, ch = size_to_fit(content_image.size, scales[0], scale_up=True)
+        cw, ch = size_to_fit(content_image.size, scales[0][0], scale_up=True)
         if init == 'content':
             self.image = TF.to_tensor(content_image.resize((cw, ch), Image.LANCZOS))[None]
         elif init == 'gray':
@@ -341,7 +342,7 @@ def stylize(self, content_image, style_images, *,
 
         # Stylize the image at successively finer scales, each greater by a factor of sqrt(2).
         # This differs from the scheme given in Gatys et al. (2016).
-        for scale in scales:
+        for scale, its in scales:
             if self.devices[0].type == 'cuda':
                 torch.cuda.empty_cache()
 
@@ -385,7 +386,7 @@ def stylize(self, content_image, style_images, *,
            opt2 = optim.Adam([self.image], lr=step_size)

            # Warm-start the Adam optimizer if this is not the first scale.
-            if scale != scales[0]:
+            if scale != scales[0][0]:
                opt_state = scale_adam(opt.state_dict(), (ch, cw))
                opt2.load_state_dict(opt_state)
                opt = opt2
@@ -393,8 +394,7 @@ def stylize(self, content_image, style_images, *,
            if self.devices[0].type == 'cuda':
                torch.cuda.empty_cache()

-            actual_its = initial_iterations if scale == scales[0] else iterations
-            for i in range(1, actual_its + 1):
+            for i in range(1, its + 1):
                feats = self.model(self.image)
                loss = crit(feats)
                opt.zero_grad()
@@ -409,7 +409,7 @@ def stylize(self, content_image, style_images, *,
                    for device in self.devices:
                        if device.type == 'cuda':
                            gpu_ram = max(gpu_ram, torch.cuda.max_memory_allocated(device))
-                callback(STIterate(w=cw, h=ch, i=i, i_max=actual_its, loss=loss.item(),
+                callback(STIterate(w=cw, h=ch, i=i, i_max=its, loss=loss.item(),
                                   time=time.time(), gpu_ram=gpu_ram))

            # Initialize each new scale with the previous scale's averaged iterate.
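
Note on the new gen_scales return value: after this patch, gen_scales produces an ordered coarse-to-fine list of (scale, iterations) pairs instead of a sorted set of bare scales, with the coarsest scale assigned initial_iterations. The sketch below is illustrative only; it assumes the patched module is importable as style_transfer.style_transfer, and the argument values mirror the defaults visible in the stylize() signature (see PATCH 2/2 below).

# Sketch only: assumes the patched gen_scales() can be imported from
# style_transfer.style_transfer; argument values mirror the stylize() defaults.
from style_transfer.style_transfer import gen_scales

schedule = gen_scales(start=128, end=512, iterations=500, initial_iterations=1000)
print(schedule)
# Scales grow by roughly sqrt(2); the coarsest pass gets the larger
# initial_iterations budget:
# [(128, 1000), (181, 500), (256, 500), (362, 500), (512, 500)]

Keeping the per-scale iteration count next to the scale it applies to is what lets the follow-up patch accept a user-supplied schedule directly.
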
From 7cfcfad2fb8488eb6152a4767179246f5cd14315 Mon Sep 17 00:00:00 2001
From: Angus Russell
Date: Mon, 14 Jun 2021 12:15:58 +1000
Subject: [PATCH 2/2] Add scales argument - accidentally omitted from previous commit

---
 style_transfer/style_transfer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/style_transfer/style_transfer.py b/style_transfer/style_transfer.py
index ccf0915..26e2fd9 100644
--- a/style_transfer/style_transfer.py
+++ b/style_transfer/style_transfer.py
@@ -297,6 +297,7 @@ def stylize(self, content_image, style_images, *,
                 tv_weight: float = 2.,
                 min_scale: int = 128,
                 end_scale: int = 512,
+                scales=None,
                 iterations: int = 500,
                 initial_iterations: int = 1000,
                 step_size: float = 0.02,
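
Note on usage: with both patches applied, a caller can bypass gen_scales entirely by passing scales as a coarse-to-fine list of (scale, iterations) pairs; leaving it as None keeps the previous behaviour (a schedule generated from min_scale to end_scale). A hypothetical usage sketch follows; the StyleTransfer class name and its no-argument constructor are assumptions about the surrounding library and are not shown in this diff.

# Hypothetical usage sketch: StyleTransfer and its no-argument constructor are
# assumptions about the surrounding library, not part of this patch.
from PIL import Image
from style_transfer.style_transfer import StyleTransfer

content = Image.open('content.jpg')
styles = [Image.open('style.jpg')]

st = StyleTransfer()
# Custom schedule: 1000 iterations at 128px, then 500 at 256px, then 300 at 512px.
st.stylize(content, styles, scales=[(128, 1000), (256, 500), (512, 300)])
# Retrieving the stylized image afterwards depends on the library's existing API,
# which this patch does not touch.
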