Step 1: export the MelGAN generator to ONNX (torch2onnx):

```python
#!/ssd4/exec/huangps/anaconda3/envs/melgan/bin/python
import torch
import numpy as np
from model.generator import Generator
from utils.hparams import HParam, load_hparam_str
from utils.pqmf import PQMF
import wave
checkpoint = torch.load('./chkpt/hps/hps_13efcb4_0600.pt')
hp = load_hparam_str(checkpoint['hp_str'])
vocoder = Generator(hp.audio.n_mel_channels, hp.model.n_residual_layers,
                    ratios=hp.model.generator_ratio, mult=hp.model.mult,
                    out_band=hp.model.out_channels).cuda()
vocoder.load_state_dict(checkpoint['model_g'])
# this repo's Generator overrides eval(); inference=True would also remove weight norm
vocoder.eval(inference=False)
# vocoder.inference(mel)
mel = np.load("/ssd5/exec/huangps/melgan/datasets/LJSpeech-1.1/mels/LJ001-0001.npy")
mel = torch.from_numpy(mel).to(device='cuda', dtype=torch.float32)
mel = mel.unsqueeze(0)
dummy_input = mel
input_names = [ "mel" ]
output_names = [ "output" ]
dynamic_axes = {
"mel" : {0: "batch_size", 2: "seq_len"}
}
torch.onnx.export(vocoder, dummy_input, "melgan.onnx", verbose=True, input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes)
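As a sanity check that is not part of the original script, the exported graph can be loaded with onnxruntime and compared against the PyTorch forward pass. A minimal sketch continuing the script above (onnxruntime is assumed to be installed; the "mel"/"output" names come from the export call):

```python
import onnxruntime as ort

# Load the exported graph; the CPU provider keeps the check simple.
sess = ort.InferenceSession("melgan.onnx", providers=["CPUExecutionProvider"])

# Reuse the mel tensor prepared above; ORT expects a float32 numpy array.
mel_np = mel.cpu().numpy()
(onnx_out,) = sess.run(["output"], {"mel": mel_np})
print(onnx_out.shape)

# Reference output from the PyTorch module that was exported.
with torch.no_grad():
    ref = vocoder(mel).cpu().numpy()
print(np.abs(onnx_out - ref).max())  # expect only a small numeric difference
```

Note that torch.onnx.export traces forward(), not inference(), so neither the mel padding that inference() appears to apply (hence the hop_length*10 trim below) nor the PQMF synthesis ends up in the graph; both have to happen outside ONNX (see the sketch at the end of this post).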
For reference, the rest of the script (commented out in the original) is the repo's PyTorch inference path: PQMF synthesis of the multi-band output, scaling to 16-bit PCM, and writing a WAV:

```python
MAX_WAV_VALUE = 32768.0
with torch.no_grad():
    mel = mel.detach()
    if len(mel.shape) == 2:
        mel = mel.unsqueeze(0)
    mel = mel.cuda()
    audio = vocoder.inference(mel)
    # For multi-band inference: merge the sub-bands back into one waveform.
    if hp.model.out_channels > 1:
        pqmf = PQMF()
        audio = pqmf.synthesis(audio).view(-1)
audio = audio.squeeze()  # collapse all dimensions except the time axis
audio = audio[:-(hp.audio.hop_length * 10)]  # trim the padding added by inference()
audio = MAX_WAV_VALUE * audio
audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE - 1)
audio = audio.short()  # int16 PCM
audio = audio.cpu().detach().numpy()
print(audio.shape)
print(audio[:10])
with wave.open('1.wav', 'wb') as wavfile:
    # mono, 2 bytes per sample, 22050 Hz
    wavfile.setparams((1, 2, 22050, 0, 'NONE', 'NONE'))
    wavfile.writeframes(audio)
```
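To get a playable WAV out of the ONNX model, the same post-processing can run on the ONNX output. A minimal sketch, reusing sess, mel_np, hp, MAX_WAV_VALUE, and the wave/PQMF imports from the snippets above; it assumes the graph returns the raw multi-band tensor of shape (1, out_channels, T), and it skips the hop_length*10 trim because the traced forward() should not apply the padding that inference() does:

```python
# Run the exported generator, then finish synthesis outside the graph.
(bands,) = sess.run(["output"], {"mel": mel_np})
audio = torch.from_numpy(bands)
if hp.model.out_channels > 1:
    audio = PQMF().synthesis(audio).view(-1)  # merge sub-bands on CPU
audio = audio.squeeze()
audio = (MAX_WAV_VALUE * audio).clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE - 1)
audio = audio.short().cpu().numpy()
with wave.open('onnx.wav', 'wb') as f:
    f.setparams((1, 2, 22050, 0, 'NONE', 'NONE'))  # mono, 16-bit, 22.05 kHz
    f.writeframes(audio)
```

Comparing onnx.wav against 1.wav from the PyTorch path above is a quick end-to-end check of the export.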