Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 33 additions & 9 deletions stopes/eval/auto_pcp/audio_comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@
import typing as tp
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model

from stopes.core import utils
from stopes.modules.speech.audio_load_utils import parallel_audio_read
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -255,13 +256,17 @@ def encode_audios(
if use_cuda:
model.cuda()

itr = parallel_audio_read(
lines=audio_paths,
column_offset=0,
sampling_factor=sampling_factor,
num_process=num_process,
collapse_channels=True,
)
if isinstance(audio_paths[0], np.ndarray):
# if the input list consists of waveforms, we don't need to read them
itr = [(None, wav) for wav in audio_paths]
else:
itr = parallel_audio_read(
lines=audio_paths,
column_offset=0,
sampling_factor=sampling_factor,
num_process=num_process,
collapse_channels=True,
)
if progress:
itr = tqdm(itr, total=len(audio_paths))

Expand Down Expand Up @@ -397,6 +402,25 @@ def compare_audio_pairs(
num_process=num_process,
)[:, 0]
logger.info("Comparing source and target embedding")
preds = get_comparator_preds(
model=model,
src_emb=src_emb,
tgt_emb=tgt_emb,
batch_size=batch_size,
symmetrize=symmetrize,
)
return list(preds)


def get_comparator_preds(
model, src_emb, tgt_emb, batch_size: int = 16, symmetrize: bool = True
) -> np.ndarray:
if isinstance(src_emb, np.ndarray):
src_emb = torch.tensor(src_emb)
if isinstance(tgt_emb, np.ndarray):
tgt_emb = torch.tensor(tgt_emb)
src_emb = src_emb.squeeze(1)
tgt_emb = tgt_emb.squeeze(1)
preds = (
get_model_pred(
model,
Expand All @@ -421,4 +445,4 @@ def compare_audio_pairs(
.numpy()
)
preds = (preds2 + preds) / 2
return list(preds)
return preds
55 changes: 54 additions & 1 deletion stopes/modules/evaluation/compare_audio_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,19 @@
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import torch
from omegaconf.omegaconf import MISSING

from stopes.core.stopes_module import Requirements, StopesModule
from stopes.eval.auto_pcp.audio_comparator import compare_audio_pairs
from stopes.eval.auto_pcp.audio_comparator import (
Comparator,
compare_audio_pairs,
encode_audios,
get_comparator_preds,
)
from stopes.utils.web import cached_file_download
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -134,3 +141,49 @@ def validate(
iteration_index: int = 0,
) -> bool:
return output.exists()

def load_models(self):
    """Load the wav2vec2 feature extractor/encoder and the comparator model,
    storing them as instance attributes so later calls can reuse them.
    """
    gpu_enabled = self.config.use_cuda and torch.cuda.is_available()

    encoder_path = self.config.encoder_path
    self.audio_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        encoder_path
    )
    self.audio_encoder = Wav2Vec2Model.from_pretrained(encoder_path)
    if gpu_enabled:
        self.audio_encoder.cuda()

    # When no local comparator checkpoint is configured, fall back to
    # downloading (and unzipping) it from the configured URL.
    comparator_path = self.config.comparator_path
    if comparator_path is None:
        comparator_path = cached_file_download(
            self.config.comparator_url,  # type: ignore
            self.config.comparator_save_name,  # type: ignore
            unzip=True,
        )
    self.comparator = Comparator.load(comparator_path, use_gpu=gpu_enabled)

def embed_audios(
    self, inputs: tp.List[tp.Union[str, np.ndarray]], progress: bool = True
) -> torch.Tensor:
    """
    Encode audios (file paths or raw waveforms) into representations suitable
    for the comparator model.

    Useful for expressive alignment, where each encoded audio takes part in
    multiple comparisons, so encoding once and reusing the result pays off.
    """
    cfg = self.config
    embeddings = encode_audios(
        audio_paths=inputs,
        model=self.audio_encoder,
        fex=self.audio_feature_extractor,
        pick_layer=cfg.pick_layer,
        batch_size=cfg.batch_size,
        num_process=cfg.num_process,
        progress=progress,
    )
    # Collapse the singleton dimension at axis 1 of the encoder output.
    return embeddings.squeeze(1)

def compare_embeddings(self, src_emb, tgt_emb):
    """Score source/target embedding pairs with the loaded comparator model."""
    pred_kwargs = dict(
        model=self.comparator,
        src_emb=src_emb,
        tgt_emb=tgt_emb,
        batch_size=self.config.batch_size,
        symmetrize=self.config.symmetrize,
    )
    return get_comparator_preds(**pred_kwargs)