# -*- coding: utf-8 -*-
# @Author:
# @Date: 2025-06-11 15:25:51
# @LastEditTime: 2025-06-13 15:59:47
# @FilePath: /image_denoise/TotalVariation.py
# @Description: TV-L1 denoising algorithm implementation
import numpy as np
from utils import psf2otf, soft_threshold, divergence, gradient_x, gradient_y
from Method import Method
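
# ---------------------------------------------------------------------------
# This module relies on the following helpers from utils: psf2otf,
# soft_threshold, divergence, gradient_x, gradient_y. For reference, below
# are minimal sketches of the two less common ones, written under the usual
# conventions (MATLAB-style psf2otf; soft thresholding as the L1 proximal
# operator). They are illustrative only, are NOT used by this module, and
# the project's own utils implementations may differ in detail.
# ---------------------------------------------------------------------------
def _psf2otf_reference(psf: np.ndarray, shape) -> np.ndarray:
    """Sketch: zero-pad the PSF to `shape`, circularly shift its centre to
    index (0, 0), then take the 2-D FFT (the standard psf2otf convention)."""
    padded = np.zeros(shape)
    padded[:psf.shape[0], :psf.shape[1]] = psf
    for axis, size in enumerate(psf.shape):
        padded = np.roll(padded, -(size // 2), axis=axis)
    return np.fft.fft2(padded)


def _soft_threshold_reference(x: np.ndarray, t: float) -> np.ndarray:
    """Sketch: proximal operator of t*||x||_1, i.e. shrink x towards 0 by t."""
    return np.sign(x) * np.maximum(np.abs(x) - t, 0.0)
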
class TotalVariation(Method):
    """
    Base class for TV denoising; the default implementation is TV-L1.
    The objective min_u ||u - f||_1 + λ*||∇u||_1 is split by introducing
    z = u - f and y = ∇u, and the constrained problem is solved with ADMM
    (see the step methods below).
    """

    def __init__(self, input: np.ndarray, lambda_=1.0, rho=0.07, max_iter=1000, cost_threshold=1e-6):
        """
        Args:
            input (ndarray): noisy input image, shape (height, width)
            lambda_ (float): regularization weight λ
            rho (float): per-iteration increment of the ADMM penalty
                parameters, default 0.07
            max_iter (int): maximum number of iterations
            cost_threshold (float): convergence tolerance on the cost
        """
        super().__init__(input, max_iter, cost_threshold)
        self.lambda_ = lambda_  # regularization weight
        self.rho = rho  # ADMM penalty increment per iteration
        self.method = 'TV-L1 ADMM'
        # Precompute the Fourier symbol of the Laplacian
        self.laplacian_fourier = self._laplacian_fourier()
        # Initialize the variables
        self.u = np.copy(self.img)  # current estimate (u)
        self.grad_ux = np.zeros_like(self.img)
        self.grad_uy = np.zeros_like(self.img)
        # Residual u - f (z)
        self.z = np.zeros_like(self.img)
        # Gradient splitting variables (y)
        self.y_x = np.zeros_like(self.img)  # y_x, x-component
        self.y_y = np.zeros_like(self.img)  # y_y, y-component
        # Lagrange multipliers
        self.a = np.zeros_like(self.img)    # a, for z = u - f
        self.b_x = np.zeros_like(self.img)  # b_x, for y_x = ∇_x u
        self.b_y = np.zeros_like(self.img)  # b_y, for y_y = ∇_y u
        # ADMM penalty parameters
        self.alpha = 0.02
        self.beta = 0.02

    def _laplacian_fourier(self):
        """Fourier symbol of the Laplacian, |F(∇_x)|² + |F(∇_y)|²."""
        # Horizontal finite-difference kernel
        otf_x = psf2otf(np.array([[1, -1]]), [self.height, self.width])
        # Vertical finite-difference kernel
        otf_y = psf2otf(np.array([[1], [-1]]), [self.height, self.width])
        res = np.abs(otf_x) ** 2 + np.abs(otf_y) ** 2
        return res

    def _update_u(self):
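        """
        u-subproblem. Minimizing the augmented Lagrangian over u gives the
        normal equation
            (β + α∇ᵀ∇) u = β(z + f + a/β) + α∇ᵀ(y + b/α),
        which the FFT diagonalizes under periodic boundary conditions.
        """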
        y_x_tmp = self.y_x + self.b_x / self.alpha  # x direction: y + b/α
        y_y_tmp = self.y_y + self.b_y / self.alpha  # y direction: y + b/α
        # Compute F(z + f + a/β)
        term1 = np.fft.fft2(self.z + self.img + self.a / self.beta)
        # Compute the divergence ∇ᵀ(y + b/α)
        term2 = divergence(y_x_tmp, y_y_tmp)
        # Numerator: F(z + f + a/β) + α/β · F(∇ᵀ(y + b/α))
        numerator = term1 + self.alpha / self.beta * np.fft.fft2(term2)
        # Denominator: 1 + α/β · F(∇²)
        denominator = 1 + self.alpha / self.beta * self.laplacian_fourier
        # Update u
        self.u = np.real(np.fft.ifft2(numerator / denominator))

    def _update_y(self):
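        """
        y-subproblem: proximal step on λ*||y||_1, i.e. soft-threshold the
        shifted gradients ∇u - b/α with threshold λ/α.
        """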
        # Shift the gradients by the scaled multipliers
        result_x = self.grad_ux - self.b_x / self.alpha
        result_y = self.grad_uy - self.b_y / self.alpha
        self.y_x = soft_threshold(result_x, self.lambda_ / self.alpha)
        self.y_y = soft_threshold(result_y, self.lambda_ / self.alpha)

    def _update_z(self):
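        """
        z-subproblem (L1 data term): proximal step on ||z||_1 at
        u - f - a/β, with threshold 1/β.
        """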
        res = self.u - self.img - self.a / self.beta
        self.z = soft_threshold(res, 1.0 / self.beta)

    def _get_cost(self):
        """
        Compute the current cost
            ||u - f||_1 + λ*||∇u||_1
        Returns:
            float: current cost
        """
        return (np.sum(np.abs(self.u - self.img)) +
                self.lambda_ * np.sum(np.abs(self.grad_ux) + np.abs(self.grad_uy)))

    def _update(self):
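        """
        One full ADMM sweep: u-step, y-step, z-step, dual ascent on the
        multipliers, then a linear increase of the penalty parameters.
        """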
        # Step 1: update the denoised image (u-update)
        self._update_u()
        # Refresh the gradients of u (reused by the y-step and the cost)
        self.grad_ux = gradient_x(self.u)
        self.grad_uy = gradient_y(self.u)
        # Step 2: update the gradient splitting variables (y-update)
        self._update_y()
        # Step 3: update the residual (z-update)
        self._update_z()
        # Step 4: dual ascent on the Lagrange multipliers
        self.a += self.beta * (self.z - (self.u - self.img))
        self.b_x += self.alpha * (self.y_x - self.grad_ux)
        self.b_y += self.alpha * (self.y_y - self.grad_uy)
        # Step 5: increase the penalty parameters
        self.alpha += self.rho
        self.beta += self.rho


class TV_L1(TotalVariation):
    '''
    TV-L1 ADMM image denoising; see README.md for the derivation.
    Objective: min_u ||u - f||_1 + λ*||∇u||_1
    where:
        - u: denoised image
        - f: noisy input image
        - λ: regularization weight
    '''
    pass


class TV_L2(TV_L1):
    """TV-L2 ADMM image denoising; see README.md for the derivation.
    Objective: min_u (1/2)||u - f||_2^2 + λ*||∇u||_1
    where:
        - u: denoised image
        - f: noisy input image
        - λ: regularization weight
    """
    def __init__(self, input, lambda_=7, rho=0.07, max_iter=1000, cost_threshold=1e-6):
        super().__init__(input, lambda_, rho, max_iter, cost_threshold)
        self.method = "TV_L2 ADMM"

    def _update_z(self):
        """
        z-update; per the derivation, TV-L1 and TV-L2 differ only here.
        Minimizing (1/2)z² + (β/2)(z - v)² with v = u - f - a/β has the
        closed form z = β/(1+β)·v (a scaling instead of a soft threshold).
        """
        tmp = self.u - self.img - self.a / self.beta
        self.z = self.beta / (1 + self.beta) * tmp

    def _get_cost(self):
        """
        Compute the current cost
            (1/2)*||u - f||_2^2 + λ*||∇u||_1
        Returns:
            float: current cost
        """
        return (0.5 * np.sum((self.u - self.img) ** 2) +
                self.lambda_ * np.sum(np.abs(self.grad_ux) + np.abs(self.grad_uy)))
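

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original
    # module). It drives the ADMM steps directly through _update() and
    # _get_cost(), since the public interface of the Method base class is
    # not shown in this file; the project may expose a proper entry point
    # (e.g. a run/solve method) instead.
    rng = np.random.default_rng(0)
    clean = np.zeros((64, 64))
    clean[16:48, 16:48] = 1.0  # piecewise-constant test image, ideal for TV
    noisy = clean + 0.1 * rng.standard_normal(clean.shape)

    solver = TV_L2(noisy, lambda_=7)
    prev_cost = np.inf
    for it in range(200):  # fixed iteration budget for the demo
        solver._update()
        cost = solver._get_cost()
        if abs(prev_cost - cost) < 1e-6:  # simple convergence check
            break
        prev_cost = cost
    print(f"{solver.method}: {it + 1} iterations, final cost {cost:.4f}")
    print(f"mean abs error vs. clean: {np.abs(solver.u - clean).mean():.4f}")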