diff --git a/meldataset.py b/meldataset.py index 44b0bf45a..a2e423f52 100644 --- a/meldataset.py +++ b/meldataset.py @@ -62,9 +62,9 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, y = y.squeeze(1) spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], - center=center, pad_mode='reflect', normalized=False, onesided=True) + center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) - spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) + spec = torch.abs(spec) spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec) spec = spectral_normalize_torch(spec) diff --git a/train.py b/train.py index 191c9f478..8930dcd8a 100644 --- a/train.py +++ b/train.py @@ -119,7 +119,11 @@ def train(rank, a, h, warm_start): if h.num_gpus > 1: train_sampler.set_epoch(epoch) - for i, batch in enumerate(train_loader): + processedData = [] + for batch in train_loader: + processedData.append(batch) + + for i, batch in enumerate(processedData): if rank == 0: start_b = time.time() x, y, _, y_mel = batch