-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
267 lines (232 loc) · 8.48 KB
/
main.py
File metadata and controls
267 lines (232 loc) · 8.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
# -*- coding: utf-8 -*-
# @Author:
# @Date: 2025-06-11 15:25:37
# @LastEditTime: 2025-06-13 20:28:32
# @FilePath: /image_denoise/main.py
# @Description: 主函数,执行图像去噪
import os
import time
from enum import Enum
from tkinter import Image
import numpy as np
import matplotlib.pyplot as plt
from Method import Method
from TotalVariation import TV_L1, TV_L2
from pnp import PnP_BM3D, PnP_CNN
def load_img(path: str):
    """
    Load an image file with PIL.

    Args:
        path (str): Path to the image file.

    Returns:
        PIL.Image.Image: The decoded image. This is a PIL image, not a NumPy
        array; callers convert it themselves (e.g. ``.convert('L')`` followed
        by ``np.asarray``).

    Raises:
        FileNotFoundError: If ``path`` does not exist (``test()`` catches this
        to fall back to a synthetic image).
    """
    from PIL import Image
    img = Image.open(path)
    # Image.open is lazy: decode now so corrupt files fail here rather than at
    # first use, and so Pillow can release the underlying file handle.
    img.load()
    return img
def noise_img(img: np.ndarray, noise_sigma: float, seed=42):
    """
    Add zero-mean Gaussian noise to an image.

    Args:
        img (np.ndarray): Input image, either 2-D grayscale or multi-channel with
            channels last. Presumably on a 0-255 scale, since the noise is
            scaled by 255. The input is never modified.
        noise_sigma (float): Noise standard deviation as a fraction of 255; the
            standard deviation in pixel units is ``noise_sigma * 255``.
        seed (int or None): Seed for a private random generator, so repeated
            calls give identical noise. The default of 42 reproduces the samples
            of the previous ``np.random.seed(42)`` implementation.

    Returns:
        np.ndarray: float64 noisy image with the same shape as ``img``. Values
        are not clipped and may fall outside [0, 255].
    """
    # float64 copy of the input so the caller's array is left untouched
    image = np.array(img, dtype=np.float64)
    # A private RandomState yields the same draws as np.random.seed(seed) followed
    # by np.random.normal, without resetting the global NumPy RNG as a side effect.
    rng = np.random.RandomState(seed)
    noise = rng.normal(0, noise_sigma, image.shape) * 255
    # The noise already has the image's full shape, so a single addition handles
    # grayscale and multi-channel images alike. The old per-channel assignment
    # raised a broadcasting error for (H, W, C) input.
    return image + noise
def save_img(img: np.ndarray, path: str):
    """
    Save image data to a file as 8-bit pixels.

    Args:
        img (np.ndarray): Image data, 2-D (H, W) or 3-D (H, W, C). Values are
            rounded to the nearest integer and clipped to [0, 255] before being
            stored as uint8, so float outputs of the denoisers save correctly.
        path (str): Destination file path.
    """
    from PIL import Image
    # Casting out-of-range floats straight to uint8 wraps around or is undefined
    # (e.g. -1.0 can become 255), and astype truncates (254.9 -> 254).
    # Round and clip first.
    arr = np.clip(np.rint(np.asarray(img, dtype=np.float64)), 0, 255).astype(np.uint8)
    if arr.ndim == 3 and arr.shape[2] == 1:
        arr = arr.squeeze(axis=2)  # (H, W, 1) -> (H, W)
    Image.fromarray(arr).save(path)
class Methods(Enum):
    """Denoising algorithms selectable in main() and test().

    Each member's value is the display label used in printed output, figure
    titles and result file names.
    """

    # Total-variation models with L1 / L2 data fidelity
    TV_L1 = "TV-L1"
    TV_L2 = "TV-L2"
    # Plug-and-play schemes with a BM3D / CNN denoiser
    PnP_BM3D = "PnP-BM3D"
    PnP_CNN = "PnP-CNN"
def main():
    """
    Run one denoising method on 'image.png' and show/save a comparison figure.

    Steps:
      1. Convert the image to 8-bit grayscale, held as a float64 array.
      2. Add Gaussian noise with noise_sigma=0.07, i.e. std 0.07 * 255.
      3. Denoise with the method selected below.
      4. Plot the original, noisy, denoised and residual images side by side,
         save the figure to 'denoised_comparison.png', and display it.
    """
    # Denoising method to run
    method = Methods.PnP_CNN
    # Load the image
    ori_img = load_img('image.png')
    # Single-channel grayscale; convert('L') yields a 2-D (height, width) array
    gray_img = np.asarray(ori_img.convert('L'), np.float64)
    # Add noise (gray_img is already float64)
    noisy_img = noise_img(gray_img, noise_sigma=0.07)
    # ndarray.astype works on every NumPy version; np.astype needs NumPy >= 2.0
    input_img = noisy_img.astype(np.float64)
    # Solver class and hyper-parameters for each method
    solver_settings = {
        Methods.TV_L1: (TV_L1, {'lambda_': 1, 'rho': 0.07, 'max_iter': 1000, 'cost_threshold': 0.1}),
        Methods.TV_L2: (TV_L2, {'lambda_': 7, 'rho': 0.07, 'max_iter': 1000, 'cost_threshold': 1e-3}),
        Methods.PnP_BM3D: (PnP_BM3D, {'lambda_': 1, 'rho': 0.01, 'max_iter': 30, 'cost_threshold': 1e-3}),
        Methods.PnP_CNN: (PnP_CNN, {'lambda_': 1, 'rho': 0.07, 'max_iter': 100, 'cost_threshold': 1e-3}),
    }
    denoised_method = None
    if method in solver_settings:
        solver_cls, params = solver_settings[method]
        denoised_method = solver_cls(input_img, **params)
    # Make sure the selected method is implemented
    assert denoised_method is not None, "去噪方法未实现"
    assert isinstance(denoised_method, Method), "去噪方法必须是Method的子类"
    print(f"使用去噪方法: {method.value}")
    # Run the denoiser
    denoised_img = denoised_method.solve()
    # Residual between denoised output and noisy input
    residual = np.abs(denoised_img - input_img)
    # Comparison figure. The first three panels share a fixed 0-255 gray range so
    # brightness is comparable across them (matplotlib would otherwise rescale each
    # panel independently); the residual panel is left auto-scaled.
    panels = (
        ('Original Image', gray_img, {'vmin': 0, 'vmax': 255}),
        ('Noisy Image', noisy_img, {'vmin': 0, 'vmax': 255}),
        (f'Denoised Image ({method.value})', denoised_img, {'vmin': 0, 'vmax': 255}),
        ('Residual', residual, {}),
    )
    plt.figure(figsize=(16, 4))
    for index, (title, data, value_range) in enumerate(panels, start=1):
        plt.subplot(1, 4, index)
        plt.title(title)
        plt.imshow(data, cmap='gray', **value_range)
        plt.axis('off')
    # Save the comparison figure
    plt.savefig('denoised_comparison.png')
    plt.show()
def _load_test_image():
    """
    Load 'image.png', or build a synthetic 256x256 grayscale image if it is missing.

    Returns:
        PIL.Image.Image: The loaded image, or a mid-gray (128) square containing
        a 200-valued rectangle with a 50-valued rectangle nested inside it.
    """
    # The module-level name `Image` is bound to tkinter.Image, which has no
    # new()/fromarray(). Import PIL's Image locally so the fallback actually works.
    from PIL import Image
    try:
        return load_img('image.png')
    except FileNotFoundError:
        print("Error: 'image.png' not found. Please ensure the image file is in the correct directory.")
    # Create a dummy image for testing since the file was not found
    canvas = np.array(Image.new('L', (256, 256), color=128))
    # Add some shapes to the dummy image
    canvas[64:192, 64:192] = 200
    canvas[100:156, 100:156] = 50
    dummy = Image.fromarray(canvas)
    print("Created a dummy test image.")
    return dummy


def _build_denoiser(method_enum, input_img, params):
    """
    Instantiate the solver selected by ``method_enum``.

    Args:
        method_enum (Methods): Which denoising method to build.
        input_img (np.ndarray): Noisy image handed to the solver.
        params (dict): Per-configuration keyword arguments. They override the
            shared defaults max_iter=1000 and cost_threshold=1e-3.

    Returns:
        tuple: (solver instance, or None if ``method_enum`` has no solver;
        the merged keyword arguments that were passed to the solver).
    """
    solver_classes = {
        Methods.TV_L1: TV_L1,
        Methods.TV_L2: TV_L2,
        Methods.PnP_BM3D: PnP_BM3D,
        Methods.PnP_CNN: PnP_CNN,
    }
    common_params = {'max_iter': 1000, 'cost_threshold': 1e-3}
    # Per-config values win over the shared defaults
    common_params.update(params)
    solver_cls = solver_classes.get(method_enum)
    if solver_cls is None:
        return None, common_params
    return solver_cls(input_img, **common_params), common_params


def _save_comparison(path, fig_title, gray_img, noisy_img, denoised_img, residual):
    """
    Plot original / noisy / denoised / residual side by side and save the figure to ``path``.

    The first three panels share a fixed 0-255 gray range; the residual panel is
    auto-scaled. The figure is always closed, even if saving fails, so figures do
    not pile up across the sweep.
    """
    panels = (
        ('Original Image', gray_img, {'vmin': 0, 'vmax': 255}),
        ('Noisy Image', noisy_img, {'vmin': 0, 'vmax': 255}),
        ('Denoised Image', denoised_img, {'vmin': 0, 'vmax': 255}),
        ('Residual', residual, {}),
    )
    fig = plt.figure(figsize=(20, 5))
    try:
        fig.suptitle(fig_title, fontsize=16)
        for index, (title, data, value_range) in enumerate(panels, start=1):
            ax = fig.add_subplot(1, 4, index)
            ax.set_title(title)
            ax.imshow(data, cmap='gray', **value_range)
            ax.axis('off')
        fig.savefig(path)
    finally:
        plt.close(fig)  # release the figure's memory


def test():
    """
    Sweep several denoising methods and parameters and save every result.

    For each configuration, three files are written to ./test_results/:
      * <method>_<params>_data.txt: parameters, wall time, iterations and final cost
      * <method>_<params>_denoised.png: the denoised image
      * <method>_<params>_comparison.png: original / noisy / denoised / residual
    """
    # 1. Prepare the image
    ori_img = _load_test_image()
    gray_img = np.asarray(ori_img.convert('L'), np.float64)
    noisy_img = noise_img(gray_img, noise_sigma=0.07)
    # ndarray.astype works on every NumPy version; np.astype needs NumPy >= 2.0
    input_img = noisy_img.astype(np.float64)

    # 2. Test configurations
    test_configs = [
        # TV-L1 configs
        {'method': Methods.TV_L1, 'params': {'lambda_': 0.8, 'rho': 0.05}},
        {'method': Methods.TV_L1, 'params': {'lambda_': 1.0, 'rho': 0.07}},
        # TV-L2 configs
        {'method': Methods.TV_L2, 'params': {'lambda_': 6, 'rho': 0.05}},
        {'method': Methods.TV_L2, 'params': {'lambda_': 7, 'rho': 0.07}},
        # PnP-BM3D configs
        {'method': Methods.PnP_BM3D, 'params': {'lambda_': 0.8, 'rho': 0.01, 'max_iter': 30}},
        {'method': Methods.PnP_BM3D, 'params': {'lambda_': 1.0, 'rho': 0.03, 'max_iter': 30}},
        # PnP-CNN configs
        {'method': Methods.PnP_CNN, 'params': {'lambda_': 0.7, 'rho': 0.05, 'max_iter': 50}},
        {'method': Methods.PnP_CNN, 'params': {'lambda_': 1.0, 'rho': 0.07, 'max_iter': 50}},
    ]

    # Directory that collects all test results
    output_dir = "test_results"
    os.makedirs(output_dir, exist_ok=True)
    print(f"Saving results to '{output_dir}/' directory.")

    # 3. Run every configuration
    for config in test_configs:
        method_enum = config['method']
        params = config['params']
        method_name = method_enum.value

        denoised_method, common_params = _build_denoiser(method_enum, input_img, params)
        # Check that the method is implemented
        assert denoised_method is not None, f"Denoising method {method_name} not implemented."
        assert isinstance(denoised_method, Method), "Denoising method must be a subclass of Method."

        # Unique file-name stem for this configuration
        param_str = '_'.join(f"{k}-{v}" for k, v in params.items())
        base_filename = f"{method_name}_{param_str}"
        print(f"\n--- Testing: {method_name} with params: {params} ---")

        # Denoise and time it; perf_counter is monotonic and high-resolution,
        # unlike time.time()
        start_time = time.perf_counter()
        denoised_img = denoised_method.solve()
        execution_time = time.perf_counter() - start_time

        # Collect final statistics
        final_cost = denoised_method._get_cost()
        final_iter = denoised_method.iter
        residual = np.abs(denoised_img - input_img)

        # a) Numeric results as text
        results_filepath = os.path.join(output_dir, f"{base_filename}_data.txt")
        with open(results_filepath, "w", encoding="utf-8") as f:
            f.write(f"Method: {method_name}\n")
            f.write(f"Parameters: {common_params}\n")
            f.write(f"Execution Time (s): {execution_time:.2f}\n")
            f.write(f"Final Iterations: {final_iter}\n")
            f.write(f"Final Cost: {final_cost:.6f}\n")

        # b) Denoised image
        denoised_img_filepath = os.path.join(output_dir, f"{base_filename}_denoised.png")
        save_img(denoised_img, denoised_img_filepath)

        # c) Comparison figure
        fig_title = f'Method: {method_name} | Params: ' + ', '.join(f"{k}={v}" for k, v in params.items())
        comparison_filepath = os.path.join(output_dir, f"{base_filename}_comparison.png")
        _save_comparison(comparison_filepath, fig_title, gray_img, noisy_img, denoised_img, residual)

    print("\n--- All tests completed successfully. ---")
if __name__ == "__main__":
    # By default, run the parameter sweep in test(). Run `python main.py main`
    # to run the single-method demo in main() instead; this replaces toggling
    # between the two by commenting lines out.
    import sys
    if sys.argv[1:2] == ["main"]:
        main()
    else:
        test()