-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference.py
More file actions
89 lines (70 loc) · 2.59 KB
/
inference.py
File metadata and controls
89 lines (70 loc) · 2.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from ns3_facodec.facodec import FACodecEncoder, FACodecDecoder
import torch
import librosa
import soundfile as sf
from huggingface_hub import hf_hub_download
def load_models(
    *,
    encoder_ckpt="checkpoint/ns3_facodec_encoder.bin",
    decoder_ckpt="checkpoint/ns3_facodec_decoder.bin",
    from_hub=False,
):
    """Build the FACodec encoder/decoder pair and load pretrained weights.

    Args:
        encoder_ckpt: Path to the encoder state-dict file. Ignored when
            ``from_hub`` is True.
        decoder_ckpt: Path to the decoder state-dict file. Ignored when
            ``from_hub`` is True.
        from_hub: If True, download both checkpoints from the
            ``amphion/naturalspeech3_facodec`` Hugging Face repo instead of
            reading the local paths.

    Returns:
        ``(fa_encoder, fa_decoder)`` with weights loaded and both modules in
        eval mode.
    """
    fa_encoder = FACodecEncoder(
        ngf=32,
        up_ratios=[2, 4, 5, 5],
        out_channels=256,
    )
    # Decoder up_ratios mirror the encoder's in reverse order.
    fa_decoder = FACodecDecoder(
        in_channels=256,
        upsample_initial_channel=1024,
        ngf=32,
        up_ratios=[5, 5, 4, 2],
        vq_num_q_c=2,
        vq_num_q_p=1,
        vq_num_q_r=3,
        vq_dim=256,
        codebook_dim=8,
        codebook_size_prosody=10,
        codebook_size_content=10,
        codebook_size_residual=10,
        use_gr_x_timbre=True,
        use_gr_residual_f0=True,
        use_gr_residual_phone=True,
    )

    if from_hub:
        repo_id = "amphion/naturalspeech3_facodec"
        encoder_ckpt = hf_hub_download(repo_id=repo_id, filename="ns3_facodec_encoder.bin")
        decoder_ckpt = hf_hub_download(repo_id=repo_id, filename="ns3_facodec_decoder.bin")

    # map_location="cpu" lets checkpoints saved from GPU tensors load on
    # CPU-only machines; load_state_dict then copies the values into the
    # modules' own parameters.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    fa_encoder.load_state_dict(torch.load(encoder_ckpt, map_location="cpu"))
    fa_decoder.load_state_dict(torch.load(decoder_ckpt, map_location="cpu"))
    fa_encoder.eval()
    fa_decoder.eval()
    print('Load FACodec Succeed !')
    return fa_encoder, fa_decoder
if __name__ == '__main__':
    import sys

    # Sample rate used both to resample the input and to write the output.
    SAMPLE_RATE = 16000

    # Usage: python inference.py [input.wav] [output.wav]
    # Defaults reproduce the original hard-coded paths.
    test_wav_path = sys.argv[1] if len(sys.argv) > 1 else "default_source_test.wav"
    recon_wav_path = sys.argv[2] if len(sys.argv) > 2 else "recon.wav"

    fa_encoder, fa_decoder = load_models()

    # librosa.load returns (samples, sr); only the samples are kept.
    test_wav = librosa.load(test_wav_path, sr=SAMPLE_RATE)[0]
    print(test_wav.shape)
    # Add two leading singleton axes (batch, channel) in front of the samples.
    test_wav = torch.from_numpy(test_wav).float().unsqueeze(0).unsqueeze(0)

    # Per-group quantizer counts; must match vq_num_q_p / vq_num_q_c in
    # load_models. The residual group takes whatever remains.
    n_prosody_q = 1
    n_content_q = 2

    with torch.no_grad():
        # encode
        enc_out = fa_encoder(test_wav)
        print(enc_out.shape)

        # quantize (the 3rd and 4th outputs are not used by this script)
        vq_post_emb, vq_id, _, _, spk_embs = fa_decoder(enc_out, eval_vq=False, vq=True)
        # latent after quantization
        print(vq_post_emb.shape)

        # codes: the slices below index the first axis of vq_id by quantizer
        # group, in the order prosody, content, residual.
        print("vq id shape:", vq_id.shape)
        prosody_code = vq_id[:n_prosody_q]
        print("prosody code shape:", prosody_code.shape)
        content_code = vq_id[n_prosody_q:n_prosody_q + n_content_q]
        print("content code shape:", content_code.shape)
        # residual code (acoustic detail codes)
        residual_code = vq_id[n_prosody_q + n_content_q:]
        print("residual code shape:", residual_code.shape)

        # speaker embedding
        print("speaker embedding shape:", spk_embs.shape)

        # decode (recommended)
        recon_wav = fa_decoder.inference(vq_post_emb, spk_embs)
        print(recon_wav.shape)
        sf.write(recon_wav_path, recon_wav[0][0].cpu().numpy(), SAMPLE_RATE)