From e02d08317905771eccdcd931956ea286e6c9e93c Mon Sep 17 00:00:00 2001 From: "veronika.spieker" Date: Wed, 9 Aug 2023 15:34:03 +0200 Subject: [PATCH] add K & KT as customizable input for recon --- medutils/optimization/base_optimizer.py | 8 ++++++- .../optimization_th/recon_optimizer_th.py | 21 ++++++++++--------- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/medutils/optimization/base_optimizer.py b/medutils/optimization/base_optimizer.py index fc3eb61..286483b 100644 --- a/medutils/optimization/base_optimizer.py +++ b/medutils/optimization/base_optimizer.py @@ -29,7 +29,13 @@ def solve(self, f, max_iter): raise NotImplementedError class BaseReconOptimizer(BaseOptimizer): - def __init__(self, A, AH, mode, lambd, beta=None, tau=None): + def __init__(self, A, AH, mode, lambd, K=None, KT=None, beta=None, tau=None): + ''' + :param K: expected to be a lamdba function, e.g. K = lambda x, beta, mode: Nabla(mode=mode, beta=beta).forward(x) + :param KT: expected to be a lamdba function, e.g. KT = lambda x, beta, mode: NablaT(mode=mode, beta=beta).forward(x) + ''' self.A = A self.AH = AH + self.K = K + self.KT = KT super().__init__(mode, lambd, beta, tau) \ No newline at end of file diff --git a/medutils/optimization_th/recon_optimizer_th.py b/medutils/optimization_th/recon_optimizer_th.py index 870aa9e..1b3fec1 100644 --- a/medutils/optimization_th/recon_optimizer_th.py +++ b/medutils/optimization_th/recon_optimizer_th.py @@ -14,6 +14,7 @@ import tqdm import torch import numpy as np +from functools import partial class TVReconOptimizer(BaseReconOptimizer): """ Total Variation @@ -24,14 +25,14 @@ class TVReconOptimizer(BaseReconOptimizer): """ def solve(self, y, max_iter): # setup operators - K = Nabla(self.mode, self.beta) - KT = NablaT(self.mode, self.beta) + K = Nabla(self.mode, self.beta) if self.K is None else partial(self.K, mode=self.mode, beta=self.beta) + KT = NablaT(self.mode, self.beta) if self.KT is None else partial(self.KT, mode=self.mode, beta=self.beta) A = self.A AH = self.AH # setup constants - L = K.L + L = Nabla(self.mode, self.beta).L # ToDo: Adjust to input K if self.tau != None: tau = self.tau else: @@ -86,7 +87,7 @@ def __init__(self, A, AH, mode, lambd, alpha0, alpha1, beta=None): self.alpha1 = alpha1 def solve(self, y, max_iter): - # setup operators + # setup operator # ToDo: Adapt to custom operators K = Nabla(self.mode, self.beta) KT = NablaT(self.mode, self.beta) E = NablaSym(self.mode, self.beta) @@ -168,11 +169,11 @@ def __init__(self, A, AH, mode, lambd, alpha1, s, beta1=None, beta2=None): self.beta2 = beta2 def solve(self, y, max_iter): - # setup operators - K_beta1 = Nabla(self.mode, self.beta1) - KT_beta1 = NablaT(self.mode, self.beta1) - K_beta2 = Nabla(self.mode, self.beta2) - KT_beta2 = NablaT(self.mode, self.beta2) + # setup operators # ToDo: Adapt to custom operators + K_beta1 = Nabla(self.mode, self.beta1) if self.K is None else partial(self.K, mode=self.mode, beta=self.beta1) + KT_beta1 = NablaT(self.mode, self.beta1) if self.KT is None else partial(self.KT, mode=self.mode, beta=self.beta1) + K_beta2 = Nabla(self.mode, self.beta2) if self.K is None else partial(self.KT, mode=self.mode, beta=self.beta2) + KT_beta2 = NablaT(self.mode, self.beta2) if self.KT is None else partial(self.KT, mode=self.mode, beta=self.beta2) A = self.A AH = self.AH @@ -266,7 +267,7 @@ def __init__(self, A, AH, mode, lambd, alpha0, alpha1, s, beta1=None, beta2=None self.beta2 = beta2 def solve(self, y, max_iter): - # setup operators + # setup operators # ToDo: Adapt to custom operators K_beta1 = Nabla(self.mode, self.beta1) KT_beta1 = NablaT(self.mode, self.beta1) E_beta1 = NablaSym(self.mode, self.beta1)