from collections import Counter

import librosa
import matplotlib.pyplot as plt
import numpy as np
import torch
import umap
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE

from inference import load_models

# The original codebook is assumed to have shape [1024, embedding_dim].
# In this example the codebook vectors come from a pretrained FACodec model.

def create_reduced_codebook(original_codebook, num_clusters=64, random_state=42):
    """
    Cluster a large codebook into a smaller one.

    Args:
        original_codebook: original codebook tensor of shape [1024, embedding_dim]
        num_clusters: size of the codebook after clustering
        random_state: random seed, for reproducibility

    Returns:
        reduced_codebook: the clustered codebook
        mapping: mapping from original codebook indices to new codebook indices
    """
    # Make sure the input is a numpy array
    if isinstance(original_codebook, torch.Tensor):
        original_codebook = original_codebook.detach().cpu().numpy()
    embedding_dim = original_codebook.shape[1]  # kept for reference; not used below

    # Cluster with K-means
    kmeans = KMeans(n_clusters=num_clusters, random_state=random_state, n_init=10)
    cluster_labels = kmeans.fit_predict(original_codebook)

    # The new codebook is the set of K-means cluster centers
    reduced_codebook = kmeans.cluster_centers_

    # Map each original index to the index of its cluster center
    mapping = {i: cluster_labels[i] for i in range(len(cluster_labels))}

    # Visualize the clustering result (optional)
    # visualize_clustering(original_codebook, reduced_codebook, cluster_labels, num_clusters)

    return torch.tensor(reduced_codebook, dtype=torch.float32), mapping
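
# Minimal usage sketch (synthetic data; the 256-dim codebook below is illustrative only):
#   dummy_codebook = torch.randn(1024, 256)
#   small_codebook, idx_map = create_reduced_codebook(dummy_codebook, num_clusters=64)
#   # small_codebook has shape [64, 256]; idx_map[i] is the new index assigned to
#   # original codebook entry i.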

def visualize_clustering(original_codebook, reduced_codebook, cluster_labels, num_clusters):
    """Visualize the clustering result."""
    # Reduce to 2D for visualization.
    # (t-SNE alternative; note that sklearn's TSNE cannot project new points,
    #  so the cluster centers would have to be embedded together with the data.)
    # tsne = TSNE(n_components=2, random_state=42)
    # original_2d = tsne.fit_transform(original_codebook)
    umap_model = umap.UMAP(n_components=2, random_state=42)
    original_2d = umap_model.fit_transform(original_codebook)
    centers_2d = umap_model.transform(reduced_codebook)

    # Plot the original codebook entries and the cluster centers
    plt.figure(figsize=(12, 10))
    plt.scatter(original_2d[:, 0], original_2d[:, 1], c=cluster_labels,
                cmap='viridis', alpha=0.6, s=30)
    plt.scatter(centers_2d[:, 0], centers_2d[:, 1], c=range(num_clusters),
                cmap='viridis', marker='X', s=200, edgecolors='k')
    plt.title(f'Codebook Clustering: {len(original_codebook)} → {num_clusters}')
    plt.savefig(f'codebook_clustering_{num_clusters}.png')
    plt.close()

# Usage example:
# assume we have a FACodec model and extract its codebooks from it.
def get_reduced_codebooks_from_facodec(facodec_model, num_clusters=64):
    """Extract all codebooks from a FACodec model and cluster each of them."""
    reduced_codebooks = {}
    mappings = {}
    # Assume the FACodec model has three VQ stacks, corresponding to
    # prosody, content, and (acoustic) detail.
    for name, codebook_name in [
        ("prosody", "quantizer.0.layers.0.codebook.weight"),
        ("content", "quantizer.1.layers.1.codebook.weight"),
        ("detail", "quantizer.2.layers.0.codebook.weight"),
    ]:
        # Fetch the original codebook.
        # Note: the actual attribute path must be adjusted to the FACodec model structure.
        codebook = get_attr_by_path(facodec_model, codebook_name)
        if codebook is None:
            raise AttributeError(f"Codebook not found at path: {codebook_name}")
        # Cluster it
        reduced_codebook, mapping = create_reduced_codebook(
            codebook, num_clusters=num_clusters
        )
        reduced_codebooks[name] = reduced_codebook
        mappings[name] = mapping
    return reduced_codebooks, mappings

# Helper: look up a model attribute from a dot-separated path string
def get_attr_by_path(obj, path):
    """Resolve an attribute given a dot-separated path, or return None if any step is missing."""
    for attr in path.split('.'):
        if hasattr(obj, attr):
            obj = getattr(obj, attr)
        else:
            return None
    return obj
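
# Hypothetical wiring (sketch): producing the clustered codebooks from a loaded
# FACodec model. `load_models` comes from the local `inference` module; its exact
# signature and return value are assumptions here.
#   facodec_model = load_models()  # assumed to return the FACodec codec model
#   reduced_codebooks, mappings = get_reduced_codebooks_from_facodec(
#       facodec_model, num_clusters=64
#   )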

def vote_with_mode(lst, pool=4):
    """Downsample a token sequence by keeping the mode of every group of `pool` tokens."""
    new_lst = []
    for i in range(0, len(lst), pool):
        group = lst[i:i + pool]
        if not group:
            continue
        mode = Counter(group).most_common(1)[0][0]  # mode (majority vote) of the group
        new_lst.append(mode)
    return new_lst

def replace_with_mode(lst, pool=4):
    """Replace every group of `pool` tokens with the group's mode, preserving sequence length."""
    validity_count = 0
    total_groups = 0
    new_lst = []
    for i in range(0, len(lst), pool):
        group = lst[i:i + pool]  # take `pool` tokens per group
        if not group:
            continue
        mode = Counter(group).most_common(1)[0][0]  # mode of the group
        new_lst.extend([mode] * len(group))  # replace the whole group with its mode
        if check_majority(group, mode):
            validity_count += 1
        total_groups += 1
    print(f"Groups with majority: {validity_count}/{total_groups}")
    return new_lst

def check_majority(group, num):
    """Return True if `num` accounts for at least half of the group (floor division)."""
    return group.count(num) >= len(group) // 2
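
if __name__ == "__main__":
    # Quick self-check of the mode-pooling helpers on a toy token sequence
    # (synthetic data, not part of the original pipeline).
    tokens = [3, 3, 7, 3, 9, 9, 2, 9, 5]
    print("vote_with_mode:   ", vote_with_mode(tokens, pool=4))     # [3, 9, 5]
    print("replace_with_mode:", replace_with_mode(tokens, pool=4))  # [3, 3, 3, 3, 9, 9, 9, 9, 5]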