diff --git a/stopes/eval/auto_pcp/audio_comparator.py b/stopes/eval/auto_pcp/audio_comparator.py
index 88819a2..f2d75cf 100644
--- a/stopes/eval/auto_pcp/audio_comparator.py
+++ b/stopes/eval/auto_pcp/audio_comparator.py
@@ -14,14 +14,15 @@
 import typing as tp
 from pathlib import Path
 
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from tqdm.auto import tqdm
-from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
 
 from stopes.core import utils
 from stopes.modules.speech.audio_load_utils import parallel_audio_read
+from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
 
 logger = logging.getLogger(__name__)
 
@@ -255,13 +256,17 @@
     if use_cuda:
         model.cuda()
 
-    itr = parallel_audio_read(
-        lines=audio_paths,
-        column_offset=0,
-        sampling_factor=sampling_factor,
-        num_process=num_process,
-        collapse_channels=True,
-    )
+    if isinstance(audio_paths[0], np.ndarray):
+        # if the input list consists of waveforms, we don't need to read them
+        itr = [(None, wav) for wav in audio_paths]
+    else:
+        itr = parallel_audio_read(
+            lines=audio_paths,
+            column_offset=0,
+            sampling_factor=sampling_factor,
+            num_process=num_process,
+            collapse_channels=True,
+        )
 
     if progress:
         itr = tqdm(itr, total=len(audio_paths))
@@ -397,6 +402,25 @@
         num_process=num_process,
     )[:, 0]
     logger.info("Comparing source and target embedding")
+    preds = get_comparator_preds(
+        model=model,
+        src_emb=src_emb,
+        tgt_emb=tgt_emb,
+        batch_size=batch_size,
+        symmetrize=symmetrize,
+    )
+    return list(preds)
+
+
+def get_comparator_preds(
+    model, src_emb, tgt_emb, batch_size: int = 16, symmetrize: bool = True
+) -> np.ndarray:
+    if isinstance(src_emb, np.ndarray):
+        src_emb = torch.tensor(src_emb)
+    if isinstance(tgt_emb, np.ndarray):
+        tgt_emb = torch.tensor(tgt_emb)
+    src_emb = src_emb.squeeze(1)
+    tgt_emb = tgt_emb.squeeze(1)
     preds = (
         get_model_pred(
             model,
@@ -421,4 +445,4 @@
             .numpy()
         )
         preds = (preds2 + preds) / 2
-    return list(preds)
+    return preds
diff --git a/stopes/modules/evaluation/compare_audio_module.py b/stopes/modules/evaluation/compare_audio_module.py
index e08eac2..dd511fd 100644
--- a/stopes/modules/evaluation/compare_audio_module.py
+++ b/stopes/modules/evaluation/compare_audio_module.py
@@ -10,12 +10,19 @@
 from dataclasses import dataclass
 from pathlib import Path
 
+import numpy as np
 import torch
 from omegaconf.omegaconf import MISSING
 
 from stopes.core.stopes_module import Requirements, StopesModule
-from stopes.eval.auto_pcp.audio_comparator import compare_audio_pairs
+from stopes.eval.auto_pcp.audio_comparator import (
+    Comparator,
+    compare_audio_pairs,
+    encode_audios,
+    get_comparator_preds,
+)
 from stopes.utils.web import cached_file_download
+from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
 
 logger = logging.getLogger(__name__)
 
@@ -134,3 +141,49 @@
         iteration_index: int = 0,
     ) -> bool:
         return output.exists()
+
+    def load_models(self):
+        """Loading models once, to avoid future recalculation"""
+        use_cuda = self.config.use_cuda and torch.cuda.is_available()
+
+        self.audio_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
+            self.config.encoder_path
+        )
+        self.audio_encoder = Wav2Vec2Model.from_pretrained(self.config.encoder_path)
+        if use_cuda:
+            self.audio_encoder.cuda()
+
+        comparator_path = self.config.comparator_path
+        if comparator_path is None:
+            comparator_path = cached_file_download(
+                self.config.comparator_url,  # type: ignore
+                self.config.comparator_save_name,  # type: ignore
+                unzip=True,
+            )
+        self.comparator = Comparator.load(comparator_path, use_gpu=use_cuda)
+
+    def embed_audios(
+        self, inputs: tp.List[tp.Union[str, np.ndarray]], progress: bool = True
+    ) -> torch.Tensor:
+        """
+        Encode the audios to representations suitable for the comparator model.
+        This is useful for expressive alignment, where for each encoded audio we compute multiple comparisons.
+        """
+        return encode_audios(
+            audio_paths=inputs,
+            model=self.audio_encoder,
+            fex=self.audio_feature_extractor,
+            pick_layer=self.config.pick_layer,
+            batch_size=self.config.batch_size,
+            num_process=self.config.num_process,
+            progress=progress,
+        ).squeeze(1)
+
+    def compare_embeddings(self, src_emb, tgt_emb):
+        return get_comparator_preds(
+            model=self.comparator,
+            src_emb=src_emb,
+            tgt_emb=tgt_emb,
+            batch_size=self.config.batch_size,
+            symmetrize=self.config.symmetrize,
+        )