-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMethod.py
More file actions
81 lines (69 loc) · 1.91 KB
/
Method.py
File metadata and controls
81 lines (69 loc) · 1.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# -*- coding: utf-8 -*-
# @Author:
# @Date: 2025-06-12 21:19:37
# @LastEditTime: 2025-06-13 16:17:32
# @FilePath: /image_denoise/Method.py
# @Description: 去噪方法基类
import numpy as np
from abc import ABC, abstractmethod
class Method(ABC):
"""
去噪方法基类
"""
def __init__(self, input: np.ndarray, max_iter=1000, cost_threshold=1e-6):
self.max_iter = max_iter # 最大迭代次数
self.cost_threshold = cost_threshold # 收敛容忍度
self.method = "method"
self.img = input # 确保图像是浮点型并归一化到[0, 1]
assert self.img.ndim == 2, "输入图像必须是2D数组"
# 图像尺寸
[self.height, self.width] = self.img.shape
print(f"输入图像尺寸: {self.img.shape}")
self.iter = 0 # 当前迭代次数
@abstractmethod
def _update(self):
"""
更新图像
Returns:
np.ndarray: 更新后的图像
"""
pass
def set_img(self, img: np.ndarray):
"""
设置输入图像
Args:
img (np.ndarray): 输入图像
"""
assert img.ndim == 2, "输入图像必须是2D数组"
self.img = img.astype(np.float64)
self.iter = 0 # 重置迭代次数
def solve(self) -> np.ndarray:
"""
求解优化问题
Returns:
np.ndarray: 去噪后的图像
"""
print(f"开始{self.method}迭代...")
diff = 100000.0
cost_prev = 1e5
self.iter = 0 # 重置迭代次数
show_iter = 10 if self.max_iter < 200 else 100
while diff > self.cost_threshold and self.iter < self.max_iter:
self.iter += 1
self._update() # 迭代
# 计算当前代价函数
cost_cur = self._get_cost()
diff = abs(cost_cur - cost_prev)
cost_prev = cost_cur
if self.iter % show_iter == 0:
print(f'迭代 {self.iter}: Diff = {diff:.6f}, Cost = {cost_cur:.6f}')
print(f'收敛完成,总迭代次数: {self.iter}')
return self.u
@abstractmethod
def _get_cost(self) -> float:
"""
计算当前代价函数
Returns:
float: 当前代价
"""
pass