diff --git a/.github/workflows/lint_and_test.yml b/.github/workflows/lint_and_test.yml index 232d25b..dfb0909 100644 --- a/.github/workflows/lint_and_test.yml +++ b/.github/workflows/lint_and_test.yml @@ -28,8 +28,7 @@ jobs: run: | sudo apt-get install libsndfile1 python -m pip install --upgrade pip - pip install -r requirements-dev.txt - pip install -e . + pip install -e ".[dev]" - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names @@ -58,9 +57,7 @@ jobs: run: | sudo apt-get install libsndfile1 python -m pip install --upgrade pip - pip install --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/pt2.5.1/cpu -r requirements.txt - pip install -r requirements-dev.txt - pip install -e . + pip install -e ".[dev]" --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/pt2.6.0/cpu - name: mypy run: mypy --install-types --non-interactive ./ --cache-dir=.mypy_cache/ @@ -81,8 +78,7 @@ jobs: - name: Install dependencies run: | sudo apt-get install libsndfile1 - pip install -r requirements-dev.txt - pip install -e . + pip install -e ".[dev]" - name: pytest_unit run: pytest -s -v tests/unit_tests/ @@ -104,8 +100,7 @@ jobs: run: | sudo apt-get install libsndfile1 python -m pip install --upgrade pip - pip install -r requirements-dev.txt - pip install -e . 
+ pip install -e ".[dev]" - name: Free Disk Space (Ubuntu) uses: jlumbroso/free-disk-space@main with: diff --git a/LICENSE.md b/LICENSE.md index 2e93333..d69c5b8 100644 --- a/LICENSE.md +++ b/LICENSE.md @@ -68,5 +68,5 @@ BUT BEWARE, the following speech encoders are released under a non commercial li | vie | vietnamese | https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.vie.pt | | yue | yue | https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.yue.pt | -The SONAR text encoder and decoder, as well as the BLASER 2.0 models, are released under the same non-commercial +The SONAR text encoder and decoder, as well as the BLASER 2.0 models, are released under the same non-commercial license ([NC_MODEL_LICENSE](NC_MODEL_LICENSE.md)). diff --git a/NC_MODEL_LICENSE.md b/NC_MODEL_LICENSE.md index 2a96631..b7f6692 100644 --- a/NC_MODEL_LICENSE.md +++ b/NC_MODEL_LICENSE.md @@ -58,7 +58,7 @@ exhaustive, and do not form part of our licenses. such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. More_considerations - for the public: + for the public: wiki.creativecommons.org/Considerations_for_licensees ======================================================================= diff --git a/README.md b/README.md index 525d7fa..b47218b 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [[Paper]](https://ai.meta.com/research/publications/sonar-sentence-level-multimodal-and-language-agnostic-representations/) [[Demo]](#usage) -We introduce SONAR, a new multilingual and multimodal fixed-size sentence embedding space, with a full suite of speech and text encoders and decoders. It substantially outperforms existing sentence embeddings such as LASER3 and LabSE on the xsim and xsim++ multilingual similarity search tasks. +We introduce SONAR, a new multilingual and multimodal fixed-size sentence embedding space, with a full suite of speech and text encoders and decoders. 
It substantially outperforms existing sentence embeddings such as LASER3 and LabSE on the xsim and xsim++ multilingual similarity search tasks. Speech segments can be embedded in the same SONAR embedding space using language-specific speech encoders trained in a teacher-student setting on speech transcription data. We also provide a single text decoder, which allows us to perform text-to-text and speech-to-text machine translation, including for zero-shot language and modality combinations. @@ -37,7 +37,7 @@ pip install fairseq2 --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/ ``` If [fairseq2](https://github.com/facebookresearch/fairseq2) does not provide a build for your machine, check the readme of that project to build it locally. -We recommend installing SONAR only after you have a correct version of `fairseq2` installed. Note that SONAR currently relies on the stable version of fairseq2 0.4.5 (with minor variations possible). +We recommend installing SONAR only after you have a correct version of `fairseq2` installed. Note that SONAR currently relies on the stable version of `fairseq2>=0.5.2` (with minor variations possible). If you want to install SONAR manually, you can install it localy: @@ -46,6 +46,14 @@ pip install --upgrade pip pip install -e . ``` +### Versions +Unfortunately, SONAR code is very much tied to fairseq2 code, and thus only specific versions are compatible with each other: +- `sonar-space~=0.5.0` (the current version) requires `fairseq2>=0.5.2` +- `sonar-space~=0.4.0` required `fairseq2~=0.4.0` +- `sonar-space~=0.2.0` required `fairseq2~=0.2.0` + +In the future, when the `fairseq2` interface stabilizes, we hope to keep the version dependencies less tightly coupled. + ## Usage fairseq2 will automatically download models into your `$TORCH_HOME/hub` directory upon using the commands below. 
@@ -150,7 +158,7 @@ assert sr == 16000, "Sample rate should be 16kHz" s2t_model.predict([inp], target_lang="eng_Latn") # ['Television reports show white smoke coming from the plant.'] -# passing multiple wav files +# passing multiple wav files s2t_model.predict(["./tests/integration_tests/data/audio_files/audio_1.wav", "./tests/integration_tests/data/audio_files/audio_2.wav"], target_lang="eng_Latn") # ['Television reports show white smoke coming from the plant.', @@ -161,8 +169,8 @@ s2t_model.predict(["./tests/integration_tests/data/audio_files/audio_1.wav", ### Predicting sentence similarity with BLASER 2.0 models BLASER 2.0 is a family of models for automatic evaluation of machine translation quality based on SONAR embeddings. -They predict [cross-lingual semantic similarity](https://github.com/facebookresearch/fairseq/tree/nllb/examples/nllb/human_XSTS_eval) -between the translation and the source (optionally, also using a reference translation). +They predict [cross-lingual semantic similarity](https://github.com/facebookresearch/fairseq/tree/nllb/examples/nllb/human_XSTS_eval) +between the translation and the source (optionally, also using a reference translation). ```Python from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline @@ -182,7 +190,7 @@ with torch.inference_mode(): ``` Detailed model cards with more examples: [facebook/blaser-2.0-ref](https://huggingface.co/facebook/blaser-2.0-ref), -[facebook/blaser-2.0-qe](https://huggingface.co/facebook/blaser-2.0-qe). +[facebook/blaser-2.0-qe](https://huggingface.co/facebook/blaser-2.0-qe). 
### Classifying the toxicity of sentences with MuTox @@ -237,6 +245,10 @@ See more complete demo notebooks : * [sonar speech2text and other data pipeline examples](examples/inference_pipelines.ipynb) * [sonar bilingual document alignment with sonar text similarity](examples/bilingual_document.ipynb) +### Troubleshooting + +- In case of errors like `fairseq2.assets.card.AssetCardError: Model checkpoint of the blaser_2_0_qe asset card cannot be loaded`, try removing the fairseq2 assets cache (located in `~/.cache/fairseq2`); it might be that some of the downloaded model checkpoints are invalid. + ## Supported languages and download links The SONAR text encoder & decoder supports 200 languages. SONAR speech encoders support 37 languages. @@ -554,6 +566,6 @@ See the [CONTRIBUTING](CONTRIBUTING.md) file for how to help out. SONAR code is released under the MIT license (see [CODE_LICENSE](CODE_LICENSE.md)). -Some of SONAR models are released with the same MIT license, BUT BEWARE, +Some of SONAR models are released with the same MIT license, BUT BEWARE, some of them are released under a non commercial license (see [NC_MODEL_LICENSE](NC_MODEL_LICENSE.md)). Please refer to [LICENSE](LICENSE.md) for the details. 
diff --git a/examples/data/eng_flores200_dev_sample.tsv b/examples/eng_flores200_dev_sample.tsv similarity index 100% rename from examples/data/eng_flores200_dev_sample.tsv rename to examples/eng_flores200_dev_sample.tsv diff --git a/examples/inference_pipelines.ipynb b/examples/inference_pipelines.ipynb index cc7acfb..eee95b8 100644 --- a/examples/inference_pipelines.ipynb +++ b/examples/inference_pipelines.ipynb @@ -12,24 +12,6 @@ "* Speech to Text translation" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Install sonar\n", - "\n", - "if sonar is not yet installed, install it:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%pip install --quiet sonar-space" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -61,11 +43,13 @@ "metadata": {}, "outputs": [], "source": [ - "from sonar.models.sonar_speech.loader import load_sonar_speech_model\n", + "from fairseq2 import init_fairseq2\n", + "from fairseq2.data.text import get_text_tokenizer_hub\n", + "\n", + "from sonar.models.sonar_speech import get_sonar_speech_encoder_hub\n", "from sonar.models.sonar_text import (\n", - " load_sonar_text_decoder_model,\n", - " load_sonar_text_encoder_model,\n", - " load_sonar_tokenizer,\n", + " get_sonar_text_decoder_hub,\n", + " get_sonar_text_encoder_hub,\n", ")" ] }, @@ -75,7 +59,17 @@ "metadata": {}, "outputs": [], "source": [ - "speech_encoder_model = load_sonar_speech_model(\n", + "init_fairseq2()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "speech_encoder_hub = get_sonar_speech_encoder_hub()\n", + "speech_encoder_model = speech_encoder_hub.load(\n", " \"sonar_speech_encoder_eng\", device=device\n", ").eval()" ] @@ -86,7 +80,8 @@ "metadata": {}, "outputs": [], "source": [ - "text_encoder_model = load_sonar_text_encoder_model(\n", + "text_encoder_hub = get_sonar_text_encoder_hub()\n", + "text_encoder_model = 
text_encoder_hub.load(\n", " \"text_sonar_basic_encoder\", device=device\n", ").eval()" ] @@ -97,7 +92,8 @@ "metadata": {}, "outputs": [], "source": [ - "text_decoder_model = load_sonar_text_decoder_model(\n", + "text_decoder_hub = get_sonar_text_decoder_hub()\n", + "text_decoder_model = text_decoder_hub.load(\n", " \"text_sonar_basic_decoder\", device=device\n", ").eval()" ] @@ -109,7 +105,8 @@ "outputs": [], "source": [ "# tokenizer is compatible with nllb tokenizer logic already\n", - "text_tokenizer = load_sonar_tokenizer(\"text_sonar_basic_encoder\")" + "text_tokenizer_hub = get_text_tokenizer_hub()\n", + "text_tokenizer = text_tokenizer_hub.load(\"text_sonar_basic_encoder\")" ] }, { @@ -260,7 +257,7 @@ } ], "source": [ - "data_source = \"./data/eng_flores200_dev_sample.tsv\"\n", + "data_source = \"./eng_flores200_dev_sample.tsv\"\n", "text_emb = text_embedding_pipeline.predict(data_source, source_lang=\"eng_Latn\")\n", "text_emb" ] @@ -297,11 +294,18 @@ ")\n", "text_translation" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": "sonar_fairseq2", "language": "python", "name": "python3" }, @@ -315,7 +319,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.8.17" }, "orig_nbformat": 4 }, diff --git a/examples/sonar_text_demo.ipynb b/examples/sonar_text_demo.ipynb index c793d74..763f846 100644 --- a/examples/sonar_text_demo.ipynb +++ b/examples/sonar_text_demo.ipynb @@ -1,21 +1,5 @@ { "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Install the dependencies" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%pip install --quiet sonar-space seaborn pandas" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -219,14 +203,14 @@ " french_translated_sentences, 
[french_sentences], tokenize=\"flores200\"\n", ")\n", "print(\"*\" * 120)\n", - "print(\"english to french translation bleu score :\")\n", + "print(\"english to spanish translation bleu score :\")\n", "print(bleu_obj)\n", "\n", "bleu_obj = sacrebleu.corpus_bleu(\n", " english_translated_sentences, [english_sentences], tokenize=\"flores200\"\n", ")\n", "print(\"*\" * 120)\n", - "print(\"french to english translation bleu score :\")\n", + "print(\"spanish to english translation bleu score :\")\n", "print(bleu_obj)" ] }, diff --git a/huggingface_pipelines/text.py b/huggingface_pipelines/text.py index 4ebbc87..79a46f0 100644 --- a/huggingface_pipelines/text.py +++ b/huggingface_pipelines/text.py @@ -333,9 +333,9 @@ def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]: ): all_embeddings = np.asarray(embeddings, dtype=self.config.dtype) all_decoded_texts = self.decode_embeddings(all_embeddings) - batch[ - f"{column}_{self.config.output_column_suffix}" - ] = all_decoded_texts + batch[f"{column}_{self.config.output_column_suffix}"] = ( + all_decoded_texts + ) elif all(isinstance(item, list) for item in embeddings): all_embeddings = np.vstack( [ @@ -351,9 +351,9 @@ def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]: all_decoded_texts[start:end] for start, end in zip([0] + indices[:-1], indices) ] - batch[ - f"{column}_{self.config.output_column_suffix}" - ] = reconstructed_texts + batch[f"{column}_{self.config.output_column_suffix}"] = ( + reconstructed_texts + ) else: raise ValueError(f"Invalid input type for column {column}") logger.debug( @@ -490,9 +490,9 @@ def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]: # Case: List of individual strings all_texts = batch[column] all_embeddings = self.encode_texts(all_texts) - batch[ - f"{column}_{self.config.output_column_suffix}" - ] = all_embeddings + batch[f"{column}_{self.config.output_column_suffix}"] = ( + all_embeddings + ) elif all(isinstance(item, list) for item in batch[column]): # 
Case: List of lists (sentences) all_sentences = [ @@ -513,9 +513,9 @@ def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]: for start, end in zip([0] + indices[:-1], indices) ] - batch[ - f"{column}_{self.config.output_column_suffix}" - ] = sentence_embeddings + batch[f"{column}_{self.config.output_column_suffix}"] = ( + sentence_embeddings + ) else: raise ValueError( diff --git a/pyproject.toml b/pyproject.toml index 531037f..729974d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,10 @@ classifiers=[ dependencies = [ # fairseq2 is installed with the cpu deps; if you want a gpu build, you need to install it manually. # see https://github.com/facebookresearch/fairseq2 - "fairseq2~=0.4.0", + # we require a relaxed version of fairseq2 so that the users can be flexible with it + # (depending on the other dependencies of their project) + "fairseq2>=0.5.2", + # "mpmath==1.3.0", # I am not sure if we need it "numpy>=1.21", "torch", "torchaudio", @@ -42,17 +45,19 @@ dependencies = [ "pytest-cov>=2.6.1", "coverage[toml]>=5.1", # Format - "black==22.3.0", + "black==25.1.0", "isort>=5.10.1", # Linters "mypy>=0.782", "pylint>=2.8.0", + "flake8", + "types-tqdm" ] cpu = [ - "torch==2.5.1+cpu", - "torchaudio==2.5.1+cpu", - "fairseq2n~=0.4.0", - "fairseq2~=0.4.0", + "torch==2.6.0+cpu", + "torchaudio==2.6.0+cpu", + "fairseq2n>=0.5.2", + "fairseq2>=0.5.2", ] hg = [ "transformers>=4.44.0", diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index 0390385..0000000 --- a/requirements-dev.txt +++ /dev/null @@ -1,13 +0,0 @@ -# Test -pytest>=4.3.0 -pytest-asyncio>=0.15.0 -pytest-cov>=2.6.1 -coverage[toml]>=5.1 -# Format -black==22.3.0 -isort>=5.10.1 -# Linter -mypy>=0.782 -pylint>=2.8.0 -flake8 -types-tqdm diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 62c4e04..0000000 --- a/requirements.txt +++ /dev/null @@ -1,8 +0,0 @@ -fairseq2>=0.4.0 -mpmath==1.3.0 -numpy>=1.21 -torch -torchaudio -sox -soundfile 
-tqdm diff --git a/sonar/__init__.py b/sonar/__init__.py index 79c15e0..29e5685 100644 --- a/sonar/__init__.py +++ b/sonar/__init__.py @@ -8,143 +8,155 @@ SONAR provides a set of speech and text encoders for multilingual, multimodal semantic embedding. """ -from fairseq2 import setup_fairseq2 -from fairseq2.config_registry import ConfigProvider -from fairseq2.context import RuntimeContext -from fairseq2.data.text.tokenizers import TextTokenizerHandler -from fairseq2.models import ModelHandler -from fairseq2.setup import register_package_metadata_provider -from fairseq2.utils.file import TorchTensorLoader +from types import NoneType + +from fairseq2.composition.assets import register_package_assets +from fairseq2.composition.models import register_model_family +from fairseq2.composition.tokenizers import register_tokenizer_family +from fairseq2.runtime.dependency import DependencyContainer from sonar.models.blaser import ( + BLASER_FAMILY, BlaserConfig, - BlaserModelHandler, - register_blaser_configs, + BlaserModel, + _convert_blaser_checkpoint, + _create_blaser_model, + _register_blaser_configs, ) from sonar.models.laser2_text import ( + LASER2_FAMILY, Laser2Config, - Laser2ModelHandler, - Laser2TokenizerHandler, - register_laser2_configs, + LaserLstmEncoder, + _convert_laser2_checkpoint, + _create_laser2_model, + _load_laser2_tokenizer, + _register_laser2_configs, +) +from sonar.models.laser2_text.tokenizer import Laser2Tokenizer +from sonar.models.mutox import ( + MUTOX_FAMILY, + MutoxClassifier, + MutoxConfig, + _convert_mutox_checkpoint, + _create_mutox_model, + _register_mutox_configs, ) -from sonar.models.mutox import MutoxConfig, MutoxModelHandler, register_mutox_configs from sonar.models.sonar_speech import ( + SONAR_SPEECH_FAMILY, SonarSpeechEncoderConfig, - SonarSpeechEncoderHandler, - register_sonar_speech_encoder_configs, + SonarSpeechEncoderModel, + _convert_sonar_speech_checkpoint, + _create_sonar_speech_encoder_model, + 
_register_sonar_speech_encoder_configs, ) from sonar.models.sonar_text import ( + SONAR_TEXT_DECODER_FAMILY, + SONAR_TEXT_ENCODER_FAMILY, + ConditionalTransformerDecoderModel, SonarTextDecoderConfig, - SonarTextDecoderHandler, SonarTextEncoderConfig, - SonarTextEncoderHandler, - register_sonar_text_decoder_configs, - register_sonar_text_encoder_configs, + SonarTextTransformerEncoderModel, + _convert_sonar_text_decoder_checkpoint, + _convert_sonar_text_encoder_checkpoint, + _create_sonar_text_decoder_model, + _create_sonar_text_encoder_model, + _register_sonar_text_decoder_configs, + _register_sonar_text_encoder_configs, ) -__version__ = "0.4.0" +__version__ = "0.5.0" -def setup_fairseq2_extension(context: RuntimeContext) -> None: +def setup_fairseq2_extension(container: DependencyContainer) -> None: # Make sure that the default fairseq2 asset store can resolve cards under # the directory /cards. - register_package_metadata_provider(context, "sonar.cards") - - _register_models(context) - - _register_text_tokenizers(context) - - -def _register_models(context: RuntimeContext) -> None: - asset_download_manager = context.asset_download_manager - - tensor_loader = TorchTensorLoader(context.file_system, restrict=False) + register_package_assets(container, "sonar.cards") - registry = context.get_registry(ModelHandler) + _register_models(container) - handler: ModelHandler + _register_text_tokenizers(container) - configs: ConfigProvider[object] +def _register_models(container: DependencyContainer) -> None: # Blaser - configs = context.get_config_registry(BlaserConfig) - - default_arch = "basic_ref" - - handler = BlaserModelHandler( - configs, default_arch, asset_download_manager, tensor_loader + register_model_family( + container, + BLASER_FAMILY, + kls=BlaserModel, + config_kls=BlaserConfig, + factory=_create_blaser_model, + state_dict_converter=_convert_blaser_checkpoint, ) - registry.register(handler.family, handler) - - register_blaser_configs(context) + 
_register_blaser_configs(container) # Laser2 - configs = context.get_config_registry(Laser2Config) - - default_arch = "laser2" - - handler = Laser2ModelHandler( - configs, default_arch, asset_download_manager, tensor_loader + register_model_family( + container, + LASER2_FAMILY, + kls=LaserLstmEncoder, + config_kls=Laser2Config, + factory=_create_laser2_model, + state_dict_converter=_convert_laser2_checkpoint, ) - registry.register(handler.family, handler) - - register_laser2_configs(context) + _register_laser2_configs(container) # mutox - configs = context.get_config_registry(MutoxConfig) - default_arch = "mutox" - handler = MutoxModelHandler( - configs, default_arch, asset_download_manager, tensor_loader + register_model_family( + container, + MUTOX_FAMILY, + kls=MutoxClassifier, + config_kls=MutoxConfig, + factory=_create_mutox_model, + state_dict_converter=_convert_mutox_checkpoint, ) - registry.register(handler.family, handler) - register_mutox_configs(context) - # SONAR Speech Encoder - configs = context.get_config_registry(SonarSpeechEncoderConfig) - - default_arch = "english" + _register_mutox_configs(container) - handler = SonarSpeechEncoderHandler( - configs, default_arch, asset_download_manager, tensor_loader + # SONAR Speech Encoder + register_model_family( + container, + SONAR_SPEECH_FAMILY, + kls=SonarSpeechEncoderModel, + config_kls=SonarSpeechEncoderConfig, + factory=_create_sonar_speech_encoder_model, + state_dict_converter=_convert_sonar_speech_checkpoint, ) - registry.register(handler.family, handler) - - register_sonar_speech_encoder_configs(context) + _register_sonar_speech_encoder_configs(container) # SONAR Text Encoder - configs = context.get_config_registry(SonarTextEncoderConfig) - - default_arch = "basic" - - handler = SonarTextEncoderHandler( - configs, default_arch, asset_download_manager, tensor_loader + register_model_family( + container, + SONAR_TEXT_ENCODER_FAMILY, + kls=SonarTextTransformerEncoderModel, + 
config_kls=SonarTextEncoderConfig, + factory=_create_sonar_text_encoder_model, + state_dict_converter=_convert_sonar_text_encoder_checkpoint, ) - registry.register(handler.family, handler) - - register_sonar_text_encoder_configs(context) + _register_sonar_text_encoder_configs(container) # SONAR Text Decoder - configs = context.get_config_registry(SonarTextDecoderConfig) - - default_arch = "basic" - - handler = SonarTextDecoderHandler( - configs, default_arch, asset_download_manager, tensor_loader + register_model_family( + container, + SONAR_TEXT_DECODER_FAMILY, + kls=ConditionalTransformerDecoderModel, + config_kls=SonarTextDecoderConfig, + state_dict_converter=_convert_sonar_text_decoder_checkpoint, + factory=_create_sonar_text_decoder_model, ) - registry.register(handler.family, handler) - - register_sonar_text_decoder_configs(context) + _register_sonar_text_decoder_configs(container) -def _register_text_tokenizers(context: RuntimeContext) -> None: - registry = context.get_registry(TextTokenizerHandler) - +def _register_text_tokenizers(container: DependencyContainer) -> None: # Laser2 - handler = Laser2TokenizerHandler(context.asset_download_manager) - - registry.register(handler.family, handler) + register_tokenizer_family( + container, + LASER2_FAMILY, + kls=Laser2Tokenizer, + config_kls=NoneType, + loader=_load_laser2_tokenizer, + ) diff --git a/sonar/cards/sonar_mutox.yaml b/sonar/cards/sonar_mutox.yaml index fef19aa..3f63fae 100644 --- a/sonar/cards/sonar_mutox.yaml +++ b/sonar/cards/sonar_mutox.yaml @@ -13,4 +13,4 @@ name: sonar_mutox model_family: mutox_classifier model_arch: mutox checkpoint: "https://dl.fbaipublicfiles.com/seamless/models/mutox.pt" -input_size: 1024 \ No newline at end of file +input_size: 1024 diff --git a/sonar/cards/sonar_speech_encoder.yaml b/sonar/cards/sonar_speech_encoder.yaml index bbeffc7..2e48c2b 100644 --- a/sonar/cards/sonar_speech_encoder.yaml +++ b/sonar/cards/sonar_speech_encoder.yaml @@ -13,45 +13,55 @@ model_arch: 
non_english name: sonar_speech_encoder_arb base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v3ap.arb.pt" -default_lang: arb -langs: - - arb +tokenizer_config: + _set_: + default_lang: arb + langs: + - arb --- name: sonar_speech_encoder_cat base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v3ap.cat.pt" -default_lang: cat -langs: - - cat +tokenizer_config: + _set_: + default_lang: cat + langs: + - cat --- name: sonar_speech_encoder_cym base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v3ap.cym.pt" -default_lang: cym -langs: - - cym +tokenizer_config: + _set_: + default_lang: cym + langs: + - cym --- name: sonar_speech_encoder_dan base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v3ap.dan.pt" -default_lang: dan -langs: - - dan +tokenizer_config: + _set_: + default_lang: dan + langs: + - dan --- name: sonar_speech_encoder_deu base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v3ap.deu.pt" -default_lang: deu -langs: - - deu +tokenizer_config: + _set_: + default_lang: deu + langs: + - deu --- @@ -59,465 +69,569 @@ name: sonar_speech_encoder_eng base: sonar_speech_encoder_base model_arch: english checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.eng.pt" -default_lang: eng -langs: - - eng +tokenizer_config: + _set_: + default_lang: eng + langs: + - eng --- name: sonar_speech_encoder_est base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v3ap.est.pt" -default_lang: est -langs: - - est +tokenizer_config: + _set_: + default_lang: est + langs: + - est --- name: sonar_speech_encoder_fin base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v3ap.fin.pt" -default_lang: fin -langs: - - fin +tokenizer_config: + _set_: + default_lang: fin + langs: + - fin --- name: sonar_speech_encoder_fra base: sonar_speech_encoder_base 
checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v3ap.fra.pt" -default_lang: fra -langs: - - fra +tokenizer_config: + _set_: + default_lang: fra + langs: + - fra --- name: sonar_speech_encoder_ind base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v3ap.ind.pt" -default_lang: ind -langs: - - ind +tokenizer_config: + _set_: + default_lang: ind + langs: + - ind --- name: sonar_speech_encoder_ita base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v3ap.ita.pt" -default_lang: ita -langs: - - ita +tokenizer_config: + _set_: + default_lang: ita + langs: + - ita --- name: sonar_speech_encoder_kor base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v3ap.kor.pt" -default_lang: kor -langs: - - kor +tokenizer_config: + _set_: + default_lang: kor + langs: + - kor --- name: sonar_speech_encoder_nld base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v3ap.nld.pt" -default_lang: nld -langs: - - nld +tokenizer_config: + _set_: + default_lang: nld + langs: + - nld --- name: sonar_speech_encoder_pes base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v3ap.pes.pt" -default_lang: pes -langs: - - pes +tokenizer_config: + _set_: + default_lang: pes + langs: + - pes --- name: sonar_speech_encoder_por base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v3ap.por.pt" -default_lang: por -langs: - - por +tokenizer_config: + _set_: + default_lang: por + langs: + - por --- name: sonar_speech_encoder_ron base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v3ap.ron.pt" -default_lang: ron -langs: - - ron +tokenizer_config: + _set_: + default_lang: ron + langs: + - ron --- name: sonar_speech_encoder_spa base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v3ap.spa.pt" -default_lang: spa -langs: - - spa 
+tokenizer_config: + _set_: + default_lang: spa + langs: + - spa --- name: sonar_speech_encoder_swh base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v3ap.swh.pt" -default_lang: swh -langs: - - swh +tokenizer_config: + _set_: + default_lang: swh + langs: + - swh --- name: sonar_speech_encoder_tgl base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v3ap.tgl.pt" -default_lang: tgl -langs: - - tgl +tokenizer_config: + _set_: + default_lang: tgl + langs: + - tgl --- name: sonar_speech_encoder_tur base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v3ap.tur.pt" -default_lang: tur -langs: - - tur +tokenizer_config: + _set_: + default_lang: tur + langs: + - tur --- name: sonar_speech_encoder_uzn base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v3ap.uzn.pt" -default_lang: uzn -langs: - - uzn +tokenizer_config: + _set_: + default_lang: uzn + langs: + - uzn --- name: sonar_speech_encoder_asm base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.asm.pt" -default_lang: asm -langs: - - asm +tokenizer_config: + _set_: + default_lang: asm + langs: + - asm --- name: sonar_speech_encoder_bel base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.bel.pt" -default_lang: bel -langs: - - bel +tokenizer_config: + _set_: + default_lang: bel + langs: + - bel --- name: sonar_speech_encoder_ben base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.ben.pt" -default_lang: ben -langs: - - ben +tokenizer_config: + _set_: + default_lang: ben + langs: + - ben --- name: sonar_speech_encoder_bos base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.bos.pt" -default_lang: bos -langs: - - bos +tokenizer_config: + _set_: + default_lang: bos + langs: + - bos --- name: sonar_speech_encoder_bul base: 
sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.bul.pt" -default_lang: bul -langs: - - bul +tokenizer_config: + _set_: + default_lang: bul + langs: + - bul --- name: sonar_speech_encoder_ces base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.ces.pt" -default_lang: ces -langs: - - ces +tokenizer_config: + _set_: + default_lang: ces + langs: + - ces --- name: sonar_speech_encoder_cmn base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.cmn.pt" -default_lang: cmn -langs: - - cmn +tokenizer_config: + _set_: + default_lang: cmn + langs: + - cmn --- name: sonar_speech_encoder_guj base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.guj.pt" -default_lang: guj -langs: - - guj +tokenizer_config: + _set_: + default_lang: guj + langs: + - guj --- name: sonar_speech_encoder_heb base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.heb.pt" -default_lang: heb -langs: - - heb +tokenizer_config: + _set_: + default_lang: heb + langs: + - heb --- name: sonar_speech_encoder_hin base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.hin.pt" -default_lang: hin -langs: - - hin +tokenizer_config: + _set_: + default_lang: hin + langs: + - hin --- name: sonar_speech_encoder_hrv base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.hrv.pt" -default_lang: hrv -langs: - - hrv +tokenizer_config: + _set_: + default_lang: hrv + langs: + - hrv --- name: sonar_speech_encoder_jpn base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.jpn.pt" -default_lang: jpn -langs: - - jpn +tokenizer_config: + _set_: + default_lang: jpn + langs: + - jpn --- name: sonar_speech_encoder_kan base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.kan.pt" -default_lang: 
kan -langs: - - kan +tokenizer_config: + _set_: + default_lang: kan + langs: + - kan --- name: sonar_speech_encoder_lao base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.lao.pt" -default_lang: lao -langs: - - lao +tokenizer_config: + _set_: + default_lang: lao + langs: + - lao --- name: sonar_speech_encoder_lit base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.lit.pt" -default_lang: lit -langs: - - lit +tokenizer_config: + _set_: + default_lang: lit + langs: + - lit --- name: sonar_speech_encoder_lvs base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.lvs.pt" -default_lang: lvs -langs: - - lvs +tokenizer_config: + _set_: + default_lang: lvs + langs: + - lvs --- name: sonar_speech_encoder_mal base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.mal.pt" -default_lang: mal -langs: - - mal +tokenizer_config: + _set_: + default_lang: mal + langs: + - mal --- name: sonar_speech_encoder_mar base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.mar.pt" -default_lang: mar -langs: - - mar +tokenizer_config: + _set_: + default_lang: mar + langs: + - mar --- name: sonar_speech_encoder_mkd base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.mkd.pt" -default_lang: mkd -langs: - - mkd +tokenizer_config: + _set_: + default_lang: mkd + langs: + - mkd --- name: sonar_speech_encoder_mlt base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.mlt.pt" -default_lang: mlt -langs: - - mlt +tokenizer_config: + _set_: + default_lang: mlt + langs: + - mlt --- name: sonar_speech_encoder_npi base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.npi.pt" -default_lang: npi -langs: - - npi +tokenizer_config: + _set_: + default_lang: npi + langs: + - npi --- name: 
sonar_speech_encoder_ory base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.ory.pt" -default_lang: ory -langs: - - ory +tokenizer_config: + _set_: + default_lang: ory + langs: + - ory --- name: sonar_speech_encoder_pan base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.pan.pt" -default_lang: pan -langs: - - pan +tokenizer_config: + _set_: + default_lang: pan + langs: + - pan --- name: sonar_speech_encoder_pol base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.pol.pt" -default_lang: pol -langs: - - pol +tokenizer_config: + _set_: + default_lang: pol + langs: + - pol --- name: sonar_speech_encoder_rus base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.rus.pt" -default_lang: rus -langs: - - rus +tokenizer_config: + _set_: + default_lang: rus + langs: + - rus --- name: sonar_speech_encoder_slk base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.slk.pt" -default_lang: slk -langs: - - slk +tokenizer_config: + _set_: + default_lang: slk + langs: + - slk --- name: sonar_speech_encoder_slv base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.slv.pt" -default_lang: slv -langs: - - slv +tokenizer_config: + _set_: + default_lang: slv + langs: + - slv --- name: sonar_speech_encoder_snd base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.snd.pt" -default_lang: snd -langs: - - snd +tokenizer_config: + _set_: + default_lang: snd + langs: + - snd --- name: sonar_speech_encoder_srp base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.srp.pt" -default_lang: srp -langs: - - srp +tokenizer_config: + _set_: + default_lang: srp + langs: + - srp --- name: sonar_speech_encoder_tam base: sonar_speech_encoder_base checkpoint: 
"https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.tam.pt" -default_lang: tam -langs: - - tam +tokenizer_config: + _set_: + default_lang: tam + langs: + - tam --- name: sonar_speech_encoder_tel base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.tel.pt" -default_lang: tel -langs: - - tel +tokenizer_config: + _set_: + default_lang: tel + langs: + - tel --- name: sonar_speech_encoder_tha base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.tha.pt" -default_lang: tha -langs: - - tha +tokenizer_config: + _set_: + default_lang: tha + langs: + - tha --- name: sonar_speech_encoder_ukr base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.ukr.pt" -default_lang: ukr -langs: - - ukr +tokenizer_config: + _set_: + default_lang: ukr + langs: + - ukr --- name: sonar_speech_encoder_urd base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.urd.pt" -default_lang: urd -langs: - - urd +tokenizer_config: + _set_: + default_lang: urd + langs: + - urd --- name: sonar_speech_encoder_vie base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.vie.pt" -default_lang: vie -langs: - - vie +tokenizer_config: + _set_: + default_lang: vie + langs: + - vie --- name: sonar_speech_encoder_yue base: sonar_speech_encoder_base checkpoint: "https://dl.fbaipublicfiles.com/SONAR/spenc.v5ap.yue.pt" -default_lang: yue -langs: - - yue +tokenizer_config: + _set_: + default_lang: yue + langs: + - yue diff --git a/sonar/cards/text_sonar_basic_decoder.yaml b/sonar/cards/text_sonar_basic_decoder.yaml index b90cbbb..78f3f5f 100644 --- a/sonar/cards/text_sonar_basic_decoder.yaml +++ b/sonar/cards/text_sonar_basic_decoder.yaml @@ -10,207 +10,209 @@ model_arch: basic checkpoint: "https://dl.fbaipublicfiles.com/SONAR/sonar_text_decoder.pt" tokenizer: "https://dl.fbaipublicfiles.com/SONAR/sentencepiece.source.256000.model"
tokenizer_family: nllb -default_lang: eng_Latn -langs: - - ace_Arab - - ace_Latn - - acm_Arab - - acq_Arab - - aeb_Arab - - afr_Latn - - ajp_Arab - - aka_Latn - - amh_Ethi - - apc_Arab - - arb_Arab - - ars_Arab - - ary_Arab - - arz_Arab - - asm_Beng - - ast_Latn - - awa_Deva - - ayr_Latn - - azb_Arab - - azj_Latn - - bak_Cyrl - - bam_Latn - - ban_Latn - - bel_Cyrl - - bem_Latn - - ben_Beng - - bho_Deva - - bjn_Arab - - bjn_Latn - - bod_Tibt - - bos_Latn - - bug_Latn - - bul_Cyrl - - cat_Latn - - ceb_Latn - - ces_Latn - - cjk_Latn - - ckb_Arab - - crh_Latn - - cym_Latn - - dan_Latn - - deu_Latn - - dik_Latn - - dyu_Latn - - dzo_Tibt - - ell_Grek - - eng_Latn - - epo_Latn - - est_Latn - - eus_Latn - - ewe_Latn - - fao_Latn - - pes_Arab - - fij_Latn - - fin_Latn - - fon_Latn - - fra_Latn - - fur_Latn - - fuv_Latn - - gla_Latn - - gle_Latn - - glg_Latn - - grn_Latn - - guj_Gujr - - hat_Latn - - hau_Latn - - heb_Hebr - - hin_Deva - - hne_Deva - - hrv_Latn - - hun_Latn - - hye_Armn - - ibo_Latn - - ilo_Latn - - ind_Latn - - isl_Latn - - ita_Latn - - jav_Latn - - jpn_Jpan - - kab_Latn - - kac_Latn - - kam_Latn - - kan_Knda - - kas_Arab - - kas_Deva - - kat_Geor - - knc_Arab - - knc_Latn - - kaz_Cyrl - - kbp_Latn - - kea_Latn - - khm_Khmr - - kik_Latn - - kin_Latn - - kir_Cyrl - - kmb_Latn - - kon_Latn - - kor_Hang - - kmr_Latn - - lao_Laoo - - lvs_Latn - - lij_Latn - - lim_Latn - - lin_Latn - - lit_Latn - - lmo_Latn - - ltg_Latn - - ltz_Latn - - lua_Latn - - lug_Latn - - luo_Latn - - lus_Latn - - mag_Deva - - mai_Deva - - mal_Mlym - - mar_Deva - - min_Latn - - mkd_Cyrl - - plt_Latn - - mlt_Latn - - mni_Beng - - khk_Cyrl - - mos_Latn - - mri_Latn - - zsm_Latn - - mya_Mymr - - nld_Latn - - nno_Latn - - nob_Latn - - npi_Deva - - nso_Latn - - nus_Latn - - nya_Latn - - oci_Latn - - gaz_Latn - - ory_Orya - - pag_Latn - - pan_Guru - - pap_Latn - - pol_Latn - - por_Latn - - prs_Arab - - pbt_Arab - - quy_Latn - - ron_Latn - - run_Latn - - rus_Cyrl - - sag_Latn - - san_Deva - - 
sat_Beng - - scn_Latn - - shn_Mymr - - sin_Sinh - - slk_Latn - - slv_Latn - - smo_Latn - - sna_Latn - - snd_Arab - - som_Latn - - sot_Latn - - spa_Latn - - als_Latn - - srd_Latn - - srp_Cyrl - - ssw_Latn - - sun_Latn - - swe_Latn - - swh_Latn - - szl_Latn - - tam_Taml - - tat_Cyrl - - tel_Telu - - tgk_Cyrl - - tgl_Latn - - tha_Thai - - tir_Ethi - - taq_Latn - - taq_Tfng - - tpi_Latn - - tsn_Latn - - tso_Latn - - tuk_Latn - - tum_Latn - - tur_Latn - - twi_Latn - - tzm_Tfng - - uig_Arab - - ukr_Cyrl - - umb_Latn - - urd_Arab - - uzn_Latn - - vec_Latn - - vie_Latn - - war_Latn - - wol_Latn - - xho_Latn - - ydd_Hebr - - yor_Latn - - yue_Hant - - zho_Hans - - zho_Hant - - zul_Latn +tokenizer_config: + _set_: + default_lang: eng_Latn + langs: + - ace_Arab + - ace_Latn + - acm_Arab + - acq_Arab + - aeb_Arab + - afr_Latn + - ajp_Arab + - aka_Latn + - amh_Ethi + - apc_Arab + - arb_Arab + - ars_Arab + - ary_Arab + - arz_Arab + - asm_Beng + - ast_Latn + - awa_Deva + - ayr_Latn + - azb_Arab + - azj_Latn + - bak_Cyrl + - bam_Latn + - ban_Latn + - bel_Cyrl + - bem_Latn + - ben_Beng + - bho_Deva + - bjn_Arab + - bjn_Latn + - bod_Tibt + - bos_Latn + - bug_Latn + - bul_Cyrl + - cat_Latn + - ceb_Latn + - ces_Latn + - cjk_Latn + - ckb_Arab + - crh_Latn + - cym_Latn + - dan_Latn + - deu_Latn + - dik_Latn + - dyu_Latn + - dzo_Tibt + - ell_Grek + - eng_Latn + - epo_Latn + - est_Latn + - eus_Latn + - ewe_Latn + - fao_Latn + - pes_Arab + - fij_Latn + - fin_Latn + - fon_Latn + - fra_Latn + - fur_Latn + - fuv_Latn + - gla_Latn + - gle_Latn + - glg_Latn + - grn_Latn + - guj_Gujr + - hat_Latn + - hau_Latn + - heb_Hebr + - hin_Deva + - hne_Deva + - hrv_Latn + - hun_Latn + - hye_Armn + - ibo_Latn + - ilo_Latn + - ind_Latn + - isl_Latn + - ita_Latn + - jav_Latn + - jpn_Jpan + - kab_Latn + - kac_Latn + - kam_Latn + - kan_Knda + - kas_Arab + - kas_Deva + - kat_Geor + - knc_Arab + - knc_Latn + - kaz_Cyrl + - kbp_Latn + - kea_Latn + - khm_Khmr + - kik_Latn + - kin_Latn + - kir_Cyrl + - kmb_Latn + - 
kon_Latn + - kor_Hang + - kmr_Latn + - lao_Laoo + - lvs_Latn + - lij_Latn + - lim_Latn + - lin_Latn + - lit_Latn + - lmo_Latn + - ltg_Latn + - ltz_Latn + - lua_Latn + - lug_Latn + - luo_Latn + - lus_Latn + - mag_Deva + - mai_Deva + - mal_Mlym + - mar_Deva + - min_Latn + - mkd_Cyrl + - plt_Latn + - mlt_Latn + - mni_Beng + - khk_Cyrl + - mos_Latn + - mri_Latn + - zsm_Latn + - mya_Mymr + - nld_Latn + - nno_Latn + - nob_Latn + - npi_Deva + - nso_Latn + - nus_Latn + - nya_Latn + - oci_Latn + - gaz_Latn + - ory_Orya + - pag_Latn + - pan_Guru + - pap_Latn + - pol_Latn + - por_Latn + - prs_Arab + - pbt_Arab + - quy_Latn + - ron_Latn + - run_Latn + - rus_Cyrl + - sag_Latn + - san_Deva + - sat_Beng + - scn_Latn + - shn_Mymr + - sin_Sinh + - slk_Latn + - slv_Latn + - smo_Latn + - sna_Latn + - snd_Arab + - som_Latn + - sot_Latn + - spa_Latn + - als_Latn + - srd_Latn + - srp_Cyrl + - ssw_Latn + - sun_Latn + - swe_Latn + - swh_Latn + - szl_Latn + - tam_Taml + - tat_Cyrl + - tel_Telu + - tgk_Cyrl + - tgl_Latn + - tha_Thai + - tir_Ethi + - taq_Latn + - taq_Tfng + - tpi_Latn + - tsn_Latn + - tso_Latn + - tuk_Latn + - tum_Latn + - tur_Latn + - twi_Latn + - tzm_Tfng + - uig_Arab + - ukr_Cyrl + - umb_Latn + - urd_Arab + - uzn_Latn + - vec_Latn + - vie_Latn + - war_Latn + - wol_Latn + - xho_Latn + - ydd_Hebr + - yor_Latn + - yue_Hant + - zho_Hans + - zho_Hant + - zul_Latn diff --git a/sonar/cards/text_sonar_basic_encoder.yaml b/sonar/cards/text_sonar_basic_encoder.yaml index c58bd7b..1799941 100644 --- a/sonar/cards/text_sonar_basic_encoder.yaml +++ b/sonar/cards/text_sonar_basic_encoder.yaml @@ -10,207 +10,209 @@ model_arch: basic checkpoint: "https://dl.fbaipublicfiles.com/SONAR/sonar_text_encoder.pt" tokenizer: "https://dl.fbaipublicfiles.com/SONAR/sentencepiece.source.256000.model" tokenizer_family: nllb -default_lang: eng_Latn -langs: - - ace_Arab - - ace_Latn - - acm_Arab - - acq_Arab - - aeb_Arab - - afr_Latn - - ajp_Arab - - aka_Latn - - amh_Ethi - - apc_Arab - - arb_Arab - - 
ars_Arab - - ary_Arab - - arz_Arab - - asm_Beng - - ast_Latn - - awa_Deva - - ayr_Latn - - azb_Arab - - azj_Latn - - bak_Cyrl - - bam_Latn - - ban_Latn - - bel_Cyrl - - bem_Latn - - ben_Beng - - bho_Deva - - bjn_Arab - - bjn_Latn - - bod_Tibt - - bos_Latn - - bug_Latn - - bul_Cyrl - - cat_Latn - - ceb_Latn - - ces_Latn - - cjk_Latn - - ckb_Arab - - crh_Latn - - cym_Latn - - dan_Latn - - deu_Latn - - dik_Latn - - dyu_Latn - - dzo_Tibt - - ell_Grek - - eng_Latn - - epo_Latn - - est_Latn - - eus_Latn - - ewe_Latn - - fao_Latn - - pes_Arab - - fij_Latn - - fin_Latn - - fon_Latn - - fra_Latn - - fur_Latn - - fuv_Latn - - gla_Latn - - gle_Latn - - glg_Latn - - grn_Latn - - guj_Gujr - - hat_Latn - - hau_Latn - - heb_Hebr - - hin_Deva - - hne_Deva - - hrv_Latn - - hun_Latn - - hye_Armn - - ibo_Latn - - ilo_Latn - - ind_Latn - - isl_Latn - - ita_Latn - - jav_Latn - - jpn_Jpan - - kab_Latn - - kac_Latn - - kam_Latn - - kan_Knda - - kas_Arab - - kas_Deva - - kat_Geor - - knc_Arab - - knc_Latn - - kaz_Cyrl - - kbp_Latn - - kea_Latn - - khm_Khmr - - kik_Latn - - kin_Latn - - kir_Cyrl - - kmb_Latn - - kon_Latn - - kor_Hang - - kmr_Latn - - lao_Laoo - - lvs_Latn - - lij_Latn - - lim_Latn - - lin_Latn - - lit_Latn - - lmo_Latn - - ltg_Latn - - ltz_Latn - - lua_Latn - - lug_Latn - - luo_Latn - - lus_Latn - - mag_Deva - - mai_Deva - - mal_Mlym - - mar_Deva - - min_Latn - - mkd_Cyrl - - plt_Latn - - mlt_Latn - - mni_Beng - - khk_Cyrl - - mos_Latn - - mri_Latn - - zsm_Latn - - mya_Mymr - - nld_Latn - - nno_Latn - - nob_Latn - - npi_Deva - - nso_Latn - - nus_Latn - - nya_Latn - - oci_Latn - - gaz_Latn - - ory_Orya - - pag_Latn - - pan_Guru - - pap_Latn - - pol_Latn - - por_Latn - - prs_Arab - - pbt_Arab - - quy_Latn - - ron_Latn - - run_Latn - - rus_Cyrl - - sag_Latn - - san_Deva - - sat_Beng - - scn_Latn - - shn_Mymr - - sin_Sinh - - slk_Latn - - slv_Latn - - smo_Latn - - sna_Latn - - snd_Arab - - som_Latn - - sot_Latn - - spa_Latn - - als_Latn - - srd_Latn - - srp_Cyrl - - ssw_Latn - 
- sun_Latn - - swe_Latn - - swh_Latn - - szl_Latn - - tam_Taml - - tat_Cyrl - - tel_Telu - - tgk_Cyrl - - tgl_Latn - - tha_Thai - - tir_Ethi - - taq_Latn - - taq_Tfng - - tpi_Latn - - tsn_Latn - - tso_Latn - - tuk_Latn - - tum_Latn - - tur_Latn - - twi_Latn - - tzm_Tfng - - uig_Arab - - ukr_Cyrl - - umb_Latn - - urd_Arab - - uzn_Latn - - vec_Latn - - vie_Latn - - war_Latn - - wol_Latn - - xho_Latn - - ydd_Hebr - - yor_Latn - - yue_Hant - - zho_Hans - - zho_Hant - - zul_Latn +tokenizer_config: + _set_: + default_lang: eng_Latn + langs: + - ace_Arab + - ace_Latn + - acm_Arab + - acq_Arab + - aeb_Arab + - afr_Latn + - ajp_Arab + - aka_Latn + - amh_Ethi + - apc_Arab + - arb_Arab + - ars_Arab + - ary_Arab + - arz_Arab + - asm_Beng + - ast_Latn + - awa_Deva + - ayr_Latn + - azb_Arab + - azj_Latn + - bak_Cyrl + - bam_Latn + - ban_Latn + - bel_Cyrl + - bem_Latn + - ben_Beng + - bho_Deva + - bjn_Arab + - bjn_Latn + - bod_Tibt + - bos_Latn + - bug_Latn + - bul_Cyrl + - cat_Latn + - ceb_Latn + - ces_Latn + - cjk_Latn + - ckb_Arab + - crh_Latn + - cym_Latn + - dan_Latn + - deu_Latn + - dik_Latn + - dyu_Latn + - dzo_Tibt + - ell_Grek + - eng_Latn + - epo_Latn + - est_Latn + - eus_Latn + - ewe_Latn + - fao_Latn + - pes_Arab + - fij_Latn + - fin_Latn + - fon_Latn + - fra_Latn + - fur_Latn + - fuv_Latn + - gla_Latn + - gle_Latn + - glg_Latn + - grn_Latn + - guj_Gujr + - hat_Latn + - hau_Latn + - heb_Hebr + - hin_Deva + - hne_Deva + - hrv_Latn + - hun_Latn + - hye_Armn + - ibo_Latn + - ilo_Latn + - ind_Latn + - isl_Latn + - ita_Latn + - jav_Latn + - jpn_Jpan + - kab_Latn + - kac_Latn + - kam_Latn + - kan_Knda + - kas_Arab + - kas_Deva + - kat_Geor + - knc_Arab + - knc_Latn + - kaz_Cyrl + - kbp_Latn + - kea_Latn + - khm_Khmr + - kik_Latn + - kin_Latn + - kir_Cyrl + - kmb_Latn + - kon_Latn + - kor_Hang + - kmr_Latn + - lao_Laoo + - lvs_Latn + - lij_Latn + - lim_Latn + - lin_Latn + - lit_Latn + - lmo_Latn + - ltg_Latn + - ltz_Latn + - lua_Latn + - lug_Latn + - luo_Latn + - lus_Latn + 
- mag_Deva + - mai_Deva + - mal_Mlym + - mar_Deva + - min_Latn + - mkd_Cyrl + - plt_Latn + - mlt_Latn + - mni_Beng + - khk_Cyrl + - mos_Latn + - mri_Latn + - zsm_Latn + - mya_Mymr + - nld_Latn + - nno_Latn + - nob_Latn + - npi_Deva + - nso_Latn + - nus_Latn + - nya_Latn + - oci_Latn + - gaz_Latn + - ory_Orya + - pag_Latn + - pan_Guru + - pap_Latn + - pol_Latn + - por_Latn + - prs_Arab + - pbt_Arab + - quy_Latn + - ron_Latn + - run_Latn + - rus_Cyrl + - sag_Latn + - san_Deva + - sat_Beng + - scn_Latn + - shn_Mymr + - sin_Sinh + - slk_Latn + - slv_Latn + - smo_Latn + - sna_Latn + - snd_Arab + - som_Latn + - sot_Latn + - spa_Latn + - als_Latn + - srd_Latn + - srp_Cyrl + - ssw_Latn + - sun_Latn + - swe_Latn + - swh_Latn + - szl_Latn + - tam_Taml + - tat_Cyrl + - tel_Telu + - tgk_Cyrl + - tgl_Latn + - tha_Thai + - tir_Ethi + - taq_Latn + - taq_Tfng + - tpi_Latn + - tsn_Latn + - tso_Latn + - tuk_Latn + - tum_Latn + - tur_Latn + - twi_Latn + - tzm_Tfng + - uig_Arab + - ukr_Cyrl + - umb_Latn + - urd_Arab + - uzn_Latn + - vec_Latn + - vie_Latn + - war_Latn + - wol_Latn + - xho_Latn + - ydd_Hebr + - yor_Latn + - yue_Hant + - zho_Hans + - zho_Hant + - zul_Latn diff --git a/sonar/cards/text_sonar_finetuned_decoder.yaml b/sonar/cards/text_sonar_finetuned_decoder.yaml index bd738b7..2b491a6 100644 --- a/sonar/cards/text_sonar_finetuned_decoder.yaml +++ b/sonar/cards/text_sonar_finetuned_decoder.yaml @@ -11,207 +11,209 @@ model_arch: basic checkpoint: "https://dl.fbaipublicfiles.com/SONAR/finetuned_decoder.pt" tokenizer: "https://dl.fbaipublicfiles.com/SONAR/sentencepiece.source.256000.model" tokenizer_family: nllb -default_lang: eng_Latn -langs: - - ace_Arab - - ace_Latn - - acm_Arab - - acq_Arab - - aeb_Arab - - afr_Latn - - ajp_Arab - - aka_Latn - - amh_Ethi - - apc_Arab - - arb_Arab - - ars_Arab - - ary_Arab - - arz_Arab - - asm_Beng - - ast_Latn - - awa_Deva - - ayr_Latn - - azb_Arab - - azj_Latn - - bak_Cyrl - - bam_Latn - - ban_Latn - - bel_Cyrl - - bem_Latn - - ben_Beng - 
- bho_Deva - - bjn_Arab - - bjn_Latn - - bod_Tibt - - bos_Latn - - bug_Latn - - bul_Cyrl - - cat_Latn - - ceb_Latn - - ces_Latn - - cjk_Latn - - ckb_Arab - - crh_Latn - - cym_Latn - - dan_Latn - - deu_Latn - - dik_Latn - - dyu_Latn - - dzo_Tibt - - ell_Grek - - eng_Latn - - epo_Latn - - est_Latn - - eus_Latn - - ewe_Latn - - fao_Latn - - pes_Arab - - fij_Latn - - fin_Latn - - fon_Latn - - fra_Latn - - fur_Latn - - fuv_Latn - - gla_Latn - - gle_Latn - - glg_Latn - - grn_Latn - - guj_Gujr - - hat_Latn - - hau_Latn - - heb_Hebr - - hin_Deva - - hne_Deva - - hrv_Latn - - hun_Latn - - hye_Armn - - ibo_Latn - - ilo_Latn - - ind_Latn - - isl_Latn - - ita_Latn - - jav_Latn - - jpn_Jpan - - kab_Latn - - kac_Latn - - kam_Latn - - kan_Knda - - kas_Arab - - kas_Deva - - kat_Geor - - knc_Arab - - knc_Latn - - kaz_Cyrl - - kbp_Latn - - kea_Latn - - khm_Khmr - - kik_Latn - - kin_Latn - - kir_Cyrl - - kmb_Latn - - kon_Latn - - kor_Hang - - kmr_Latn - - lao_Laoo - - lvs_Latn - - lij_Latn - - lim_Latn - - lin_Latn - - lit_Latn - - lmo_Latn - - ltg_Latn - - ltz_Latn - - lua_Latn - - lug_Latn - - luo_Latn - - lus_Latn - - mag_Deva - - mai_Deva - - mal_Mlym - - mar_Deva - - min_Latn - - mkd_Cyrl - - plt_Latn - - mlt_Latn - - mni_Beng - - khk_Cyrl - - mos_Latn - - mri_Latn - - zsm_Latn - - mya_Mymr - - nld_Latn - - nno_Latn - - nob_Latn - - npi_Deva - - nso_Latn - - nus_Latn - - nya_Latn - - oci_Latn - - gaz_Latn - - ory_Orya - - pag_Latn - - pan_Guru - - pap_Latn - - pol_Latn - - por_Latn - - prs_Arab - - pbt_Arab - - quy_Latn - - ron_Latn - - run_Latn - - rus_Cyrl - - sag_Latn - - san_Deva - - sat_Beng - - scn_Latn - - shn_Mymr - - sin_Sinh - - slk_Latn - - slv_Latn - - smo_Latn - - sna_Latn - - snd_Arab - - som_Latn - - sot_Latn - - spa_Latn - - als_Latn - - srd_Latn - - srp_Cyrl - - ssw_Latn - - sun_Latn - - swe_Latn - - swh_Latn - - szl_Latn - - tam_Taml - - tat_Cyrl - - tel_Telu - - tgk_Cyrl - - tgl_Latn - - tha_Thai - - tir_Ethi - - taq_Latn - - taq_Tfng - - tpi_Latn - - tsn_Latn 
- - tso_Latn - - tuk_Latn - - tum_Latn - - tur_Latn - - twi_Latn - - tzm_Tfng - - uig_Arab - - ukr_Cyrl - - umb_Latn - - urd_Arab - - uzn_Latn - - vec_Latn - - vie_Latn - - war_Latn - - wol_Latn - - xho_Latn - - ydd_Hebr - - yor_Latn - - yue_Hant - - zho_Hans - - zho_Hant - - zul_Latn +tokenizer_config: + _set_: + default_lang: eng_Latn + langs: + - ace_Arab + - ace_Latn + - acm_Arab + - acq_Arab + - aeb_Arab + - afr_Latn + - ajp_Arab + - aka_Latn + - amh_Ethi + - apc_Arab + - arb_Arab + - ars_Arab + - ary_Arab + - arz_Arab + - asm_Beng + - ast_Latn + - awa_Deva + - ayr_Latn + - azb_Arab + - azj_Latn + - bak_Cyrl + - bam_Latn + - ban_Latn + - bel_Cyrl + - bem_Latn + - ben_Beng + - bho_Deva + - bjn_Arab + - bjn_Latn + - bod_Tibt + - bos_Latn + - bug_Latn + - bul_Cyrl + - cat_Latn + - ceb_Latn + - ces_Latn + - cjk_Latn + - ckb_Arab + - crh_Latn + - cym_Latn + - dan_Latn + - deu_Latn + - dik_Latn + - dyu_Latn + - dzo_Tibt + - ell_Grek + - eng_Latn + - epo_Latn + - est_Latn + - eus_Latn + - ewe_Latn + - fao_Latn + - pes_Arab + - fij_Latn + - fin_Latn + - fon_Latn + - fra_Latn + - fur_Latn + - fuv_Latn + - gla_Latn + - gle_Latn + - glg_Latn + - grn_Latn + - guj_Gujr + - hat_Latn + - hau_Latn + - heb_Hebr + - hin_Deva + - hne_Deva + - hrv_Latn + - hun_Latn + - hye_Armn + - ibo_Latn + - ilo_Latn + - ind_Latn + - isl_Latn + - ita_Latn + - jav_Latn + - jpn_Jpan + - kab_Latn + - kac_Latn + - kam_Latn + - kan_Knda + - kas_Arab + - kas_Deva + - kat_Geor + - knc_Arab + - knc_Latn + - kaz_Cyrl + - kbp_Latn + - kea_Latn + - khm_Khmr + - kik_Latn + - kin_Latn + - kir_Cyrl + - kmb_Latn + - kon_Latn + - kor_Hang + - kmr_Latn + - lao_Laoo + - lvs_Latn + - lij_Latn + - lim_Latn + - lin_Latn + - lit_Latn + - lmo_Latn + - ltg_Latn + - ltz_Latn + - lua_Latn + - lug_Latn + - luo_Latn + - lus_Latn + - mag_Deva + - mai_Deva + - mal_Mlym + - mar_Deva + - min_Latn + - mkd_Cyrl + - plt_Latn + - mlt_Latn + - mni_Beng + - khk_Cyrl + - mos_Latn + - mri_Latn + - zsm_Latn + - mya_Mymr + - nld_Latn 
+ - nno_Latn + - nob_Latn + - npi_Deva + - nso_Latn + - nus_Latn + - nya_Latn + - oci_Latn + - gaz_Latn + - ory_Orya + - pag_Latn + - pan_Guru + - pap_Latn + - pol_Latn + - por_Latn + - prs_Arab + - pbt_Arab + - quy_Latn + - ron_Latn + - run_Latn + - rus_Cyrl + - sag_Latn + - san_Deva + - sat_Beng + - scn_Latn + - shn_Mymr + - sin_Sinh + - slk_Latn + - slv_Latn + - smo_Latn + - sna_Latn + - snd_Arab + - som_Latn + - sot_Latn + - spa_Latn + - als_Latn + - srd_Latn + - srp_Cyrl + - ssw_Latn + - sun_Latn + - swe_Latn + - swh_Latn + - szl_Latn + - tam_Taml + - tat_Cyrl + - tel_Telu + - tgk_Cyrl + - tgl_Latn + - tha_Thai + - tir_Ethi + - taq_Latn + - taq_Tfng + - tpi_Latn + - tsn_Latn + - tso_Latn + - tuk_Latn + - tum_Latn + - tur_Latn + - twi_Latn + - tzm_Tfng + - uig_Arab + - ukr_Cyrl + - umb_Latn + - urd_Arab + - uzn_Latn + - vec_Latn + - vie_Latn + - war_Latn + - wol_Latn + - xho_Latn + - ydd_Hebr + - yor_Latn + - yue_Hant + - zho_Hans + - zho_Hant + - zul_Latn diff --git a/sonar/inference_pipelines/mutox_speech.py b/sonar/inference_pipelines/mutox_speech.py index 6ab2f1b..e8efc22 100644 --- a/sonar/inference_pipelines/mutox_speech.py +++ b/sonar/inference_pipelines/mutox_speech.py @@ -7,8 +7,8 @@ from typing import Union import torch -from fairseq2.data import DataPipelineBuilder -from fairseq2.typing import Device +from fairseq2.data.data_pipeline import DataPipelineBuilder +from fairseq2.device import Device from sonar.inference_pipelines.speech import ( AudioToFbankDataPipelineBuilder, @@ -44,7 +44,7 @@ def __init__( self.model.to(device).eval() if isinstance(mutox_classifier, str): - self.mutox_classifier = get_mutox_model_hub().load( + self.mutox_classifier = get_mutox_model_hub().load_model( mutox_classifier, device=device, ) @@ -61,8 +61,8 @@ def load_model_from_name( device: Device = CPU_DEVICE, ) -> "MutoxSpeechClassifierPipeline": encoder_hub = get_sonar_speech_encoder_hub() - encoder = encoder_hub.load(encoder_name, device=device) - mutox_classifier = 
get_mutox_model_hub().load( + encoder = encoder_hub.load_model(encoder_name, device=device) + mutox_classifier = get_mutox_model_hub().load_model( mutox_classifier_name, device=device, ) diff --git a/sonar/inference_pipelines/speech.py b/sonar/inference_pipelines/speech.py index 4e8e41f..9460337 100644 --- a/sonar/inference_pipelines/speech.py +++ b/sonar/inference_pipelines/speech.py @@ -8,25 +8,24 @@ from dataclasses import dataclass from functools import lru_cache from pathlib import Path -from typing import Iterable, List, Optional, Sequence, Union, cast +from typing import Iterable, List, Optional, Sequence, Union -import fairseq2 import torch -from fairseq2.data import ( +from fairseq2.data._memory import MemoryBlock +from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter +from fairseq2.data.data_pipeline import ( Collater, DataPipeline, DataPipelineBuilder, FileMapper, - MemoryBlock, read_sequence, ) -from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter from fairseq2.data.text import StrSplitter, read_text -from fairseq2.data.text.tokenizers import TextTokenizer, get_text_tokenizer_hub -from fairseq2.generation import BeamSearchSeq2SeqGenerator +from fairseq2.data.tokenizers import Tokenizer, load_tokenizer +from fairseq2.data_type import DataType +from fairseq2.device import Device +from fairseq2.generation.beam_search.generator import BeamSearchSeq2SeqGenerator from fairseq2.generation.text import SequenceToTextConverter -from fairseq2.models.sequence import SequenceBatch -from fairseq2.typing import DataType, Device from sonar.inference_pipelines.utils import add_progress_bar, extract_sequence_batch from sonar.models.encoder_model import SonarEncoderModel @@ -179,7 +178,7 @@ def __init__(self, model: SonarSpeechEncoderModel) -> None: @classmethod def load_model_from_name(cls, encoder_name: str) -> "SpeechToEmbeddingPipeline": encoder_hub = get_sonar_speech_encoder_hub() - encoder = encoder_hub.load(encoder_name, 
device=CPU_DEVICE) + encoder = encoder_hub.load_model(encoder_name, device=CPU_DEVICE) return cls(model=encoder) def prebuild_pipeline(self, context: SpeechInferenceParams) -> DataPipelineBuilder: @@ -196,7 +195,8 @@ def prebuild_pipeline(self, context: SpeechInferenceParams) -> DataPipelineBuild @torch.inference_mode() def run_inference(self, data: dict) -> dict: # TODO assert all(data['sample_rate'] == 16000.0) - return self.model(data["fbank"]) + # Note the fan-out which unpacks the Tuple[Tensor, BatchLayout] + return self.model(*data["fbank"]) class SpeechToTextPipeline(SpeechInferencePipeline): @@ -225,11 +225,9 @@ class SpeechToTextPipeline(SpeechInferencePipeline): AudioToFbankDataPipelineBuilder() ) model: SonarEncoderDecoderModel - tokenizer: TextTokenizer + tokenizer: Tokenizer - def __init__( - self, model: SonarEncoderDecoderModel, tokenizer: TextTokenizer - ) -> None: + def __init__(self, model: SonarEncoderDecoderModel, tokenizer: Tokenizer) -> None: self.model = model.eval() self.tokenizer = tokenizer @@ -237,21 +235,22 @@ def __init__( def load_model_from_name( cls, encoder_name: str, decoder_name: str ) -> "SpeechToTextPipeline": - tokenizer_hub = get_text_tokenizer_hub() - tokenizer = tokenizer_hub.load(decoder_name) + tokenizer = load_tokenizer(decoder_name) encoder_hub = get_sonar_speech_encoder_hub() - encoder = encoder_hub.load(encoder_name, device=CPU_DEVICE) + encoder = encoder_hub.load_model(encoder_name, device=CPU_DEVICE) decoder_hub = get_sonar_text_decoder_hub() - decoder = decoder_hub.load(decoder_name, device=CPU_DEVICE) + decoder = decoder_hub.load_model(decoder_name, device=CPU_DEVICE) model = SonarEncoderDecoderModel(encoder, decoder).eval() return cls(model=model, tokenizer=tokenizer) def prebuild_pipeline(self, context: SpeechInferenceParams) -> DataPipelineBuilder: assert context.target_lang is not None - generator = BeamSearchSeq2SeqGenerator(self.model.to(context.device)) + generator = BeamSearchSeq2SeqGenerator( + 
self.model.to(context.device), self.tokenizer.vocab_info + ) converter = SequenceToTextConverter( generator, self.tokenizer, @@ -260,8 +259,8 @@ def prebuild_pipeline(self, context: SpeechInferenceParams) -> DataPipelineBuild ) def _do_generate(data: dict) -> List[str]: - batch = cast(SequenceBatch, data["fbank"]) - texts, _ = converter.batch_convert(batch.seqs, batch.padding_mask) + seqs, seqs_layout = data["fbank"] + texts, _ = converter.batch_convert(seqs, seqs_layout) return texts return ( @@ -279,7 +278,6 @@ class SpeechModelPipelineInterface(torch.nn.Module): def __init__(self, fbank_dtype: DataType) -> None: super().__init__() - fairseq2.setup_fairseq2() self.convert_to_fbank = WaveformToFbankConverter( num_mel_bins=80, waveform_scale=2**15, @@ -310,13 +308,13 @@ def _decode_audio(self, inp: Union[str, torch.Tensor]) -> dict: class SpeechToTextModelPipeline(SpeechModelPipelineInterface): model: SonarEncoderDecoderModel - tokenizer: TextTokenizer + tokenizer: Tokenizer def __init__( self, encoder: Union[str, SonarEncoderModel], decoder: Union[str, ConditionalTransformerDecoderModel], - tokenizer: Union[str, TextTokenizer], + tokenizer: Union[str, Tokenizer], device: Device = CPU_DEVICE, fbank_dtype: DataType = torch.float32, ) -> None: @@ -324,7 +322,7 @@ def __init__( Args: encoder (Union[str, SonarEncoderModel]): either cart name or model object decoder (Union[str, ConditionalTransformerDecoderModel]): either cart name or model object - tokenizer (Union[str, TextTokenizer]): either cart name or tokenizer object + tokenizer (Union[str, Tokenizer]): either cart name or tokenizer object device (device, optional): . Defaults to CPU_DEVICE. fbank_dtype (DataType, optional):. Defaults to torch.float32. 
""" @@ -332,16 +330,15 @@ def __init__( super().__init__(fbank_dtype) if isinstance(encoder, str): encoder_hub = get_sonar_speech_encoder_hub() - encoder = encoder_hub.load(encoder, device=device) + encoder = encoder_hub.load_model(encoder, device=device) if isinstance(decoder, str): decoder_hub = get_sonar_text_decoder_hub() - decoder = decoder_hub.load(decoder, device=device) + decoder = decoder_hub.load_model(decoder, device=device) if isinstance(tokenizer, str): - tokenizer_hub = get_text_tokenizer_hub() - tokenizer = tokenizer_hub.load(tokenizer) + tokenizer = load_tokenizer(tokenizer) self.tokenizer = tokenizer - self.model = SonarEncoderDecoderModel(encoder, decoder).to(device).eval() + self.model = SonarEncoderDecoderModel(encoder, decoder).to(device).eval() # type: ignore # Only quantize the model in CUDA to bypass the error "LayerNormKernelImpl" not implemented for 'Half' # in some CUDAs and torch versions @@ -361,7 +358,7 @@ def predict( **generator_kwargs, ) -> List[str]: generator = BeamSearchSeq2SeqGenerator( - self.model.to(self.device), **generator_kwargs + self.model.to(self.device), self.tokenizer.vocab_info, **generator_kwargs ) converter = SequenceToTextConverter( generator, @@ -371,8 +368,8 @@ def predict( ) def _do_generate(data: dict) -> List[str]: - batch = cast(SequenceBatch, data["fbank"]) - texts, _ = converter.batch_convert(batch.seqs, batch.padding_mask) + seqs, seqs_layout = data["fbank"] + texts, _ = converter.batch_convert(seqs, seqs_layout) return texts pipeline: Iterable = ( @@ -401,7 +398,7 @@ def _do_generate(data: dict) -> List[str]: class SpeechToEmbeddingModelPipeline(SpeechModelPipelineInterface): model: SonarEncoderModel - tokenizer: TextTokenizer + tokenizer: Tokenizer def __init__( self, @@ -420,8 +417,8 @@ def __init__( if isinstance(encoder, str): encoder_hub = get_sonar_speech_encoder_hub() - encoder = encoder_hub.load(encoder, device=device) - self.model = encoder.to(device).eval() + encoder = 
encoder_hub.load_model(encoder, device=device) + self.model = encoder.to(device).eval() # type: ignore # Only quantize the model in CUDA to bypass the error "LayerNormKernelImpl" not implemented for 'Half' # in some CUDAs and torch versions @@ -449,7 +446,7 @@ def build_predict_pipeline( lambda fbank: extract_sequence_batch(fbank, self.device), selector="fbank", ) - .map(lambda data: self.model(data["fbank"]).sentence_embeddings) + .map(lambda data: self.model(*data["fbank"]).sentence_embeddings) ) return pipeline diff --git a/sonar/inference_pipelines/text.py b/sonar/inference_pipelines/text.py index 34ab79b..0bb5124 100644 --- a/sonar/inference_pipelines/text.py +++ b/sonar/inference_pipelines/text.py @@ -8,26 +8,25 @@ from pathlib import Path from typing import Dict, Iterable, List, Optional, Sequence, Union, cast -import fairseq2 import torch -from fairseq2.data import Collater, read_sequence +from fairseq2.data.data_pipeline import Collater, read_sequence from fairseq2.data.text import read_text -from fairseq2.data.text.tokenizers import TextTokenizer, get_text_tokenizer_hub -from fairseq2.generation import ( - BeamSearchSeq2SeqGenerator, - Sampler, - SamplingSeq2SeqGenerator, - Seq2SeqGenerator, +from fairseq2.data.tokenizers import Tokenizer, load_tokenizer +from fairseq2.data_type import DataType +from fairseq2.device import CPU, Device +from fairseq2.generation import Seq2SeqGenerator +from fairseq2.generation.beam_search.generator import BeamSearchSeq2SeqGenerator +from fairseq2.generation.sampling import Sampler, SamplingSeq2SeqGenerator +from fairseq2.models import load_model +from fairseq2.nn import BatchLayout + +from sonar.inference_pipelines.utils import ( + SequenceToTextConverter, + TextTranslator, + add_progress_bar, + extract_sequence_batch, ) -from fairseq2.generation.text import SequenceToTextConverter, TextTranslator -from fairseq2.typing import CPU, DataType, Device - -from sonar.inference_pipelines.utils import add_progress_bar, 
extract_sequence_batch from sonar.models.encoder_model import SonarEncoderModel -from sonar.models.sonar_text import ( - get_sonar_text_decoder_hub, - get_sonar_text_encoder_hub, -) from sonar.models.sonar_translation import SonarEncoderDecoderModel from sonar.models.sonar_translation.model import DummyEncoderModel from sonar.nn.conditional_decoder_model import ConditionalTransformerDecoderModel @@ -56,13 +55,17 @@ def __exit__(self, exc_type, exc_value, traceback): class TextToTextModelPipeline(torch.nn.Module): model: SonarEncoderDecoderModel - tokenizer: TextTokenizer + tokenizer: Tokenizer def __init__( self, encoder: Union[str, SonarEncoderModel], decoder: Union[str, ConditionalTransformerDecoderModel], - tokenizer: Union[str, TextTokenizer], + tokenizer: Optional[ + Union[str, Tokenizer] + ] = None, # did not remove this to avoid breaking existing code + encoder_tokenizer: Optional[Union[str, Tokenizer]] = None, + decoder_tokenizer: Optional[Union[str, Tokenizer]] = None, device: Device = CPU, dtype: Optional[DataType] = None, ) -> None: @@ -70,23 +73,40 @@ def __init__( Args: encoder (Union[str, SonarEncoderModel]): either card name or model object decoder (Union[str, ConditionalTransformerDecoderModel]): either card name or model object - tokenizer (Union[str, TextTokenizer]): either card name or tokenizer object + tokenizer (Union[str, Tokenizer], optional): either card name or tokenizer object. Defaults to None. + encoder_tokenizer (Union[str, Tokenizer], optional): either card name or tokenizer object. Defaults to None. + decoder_tokenizer (Union[str, Tokenizer], optional): either card name or tokenizer object. Defaults to None. device (Device, optional): Defaults to CPU. dtype (DataType, optional): The data type of the model parameters and buffers. 
""" super().__init__() - fairseq2.setup_fairseq2() if isinstance(encoder, str): - encoder_hub = get_sonar_text_encoder_hub() - encoder = encoder_hub.load(encoder, device=device, dtype=dtype) + encoder = load_model(encoder, device=device, dtype=dtype) # type: ignore if isinstance(decoder, str): - decoder_hub = get_sonar_text_decoder_hub() - decoder = decoder_hub.load(decoder, device=device, dtype=dtype) + decoder = load_model(decoder, device=device, dtype=dtype) # type: ignore if isinstance(tokenizer, str): - tokenizer_hub = get_text_tokenizer_hub() - tokenizer = tokenizer_hub.load(tokenizer) + tokenizer = load_tokenizer(tokenizer) + + assert tokenizer is not None or ( + encoder_tokenizer is not None and decoder_tokenizer is not None + ), "Either tokenizer or both encoder_tokenizer and decoder_tokenizer must be provided" + + if tokenizer is not None: + if isinstance(tokenizer, str): + tokenizer = load_tokenizer(tokenizer) + self.encoder_tokenizer = tokenizer + self.decoder_tokenizer = tokenizer + else: + if isinstance(encoder_tokenizer, str): + encoder_tokenizer = load_tokenizer(encoder_tokenizer) + if isinstance(decoder_tokenizer, str): + decoder_tokenizer = load_tokenizer(decoder_tokenizer) + assert ( + encoder_tokenizer is not None and decoder_tokenizer is not None + ), "we need both encoder_tokenizer and decoder_tokenizer if tokenizer is not provided" # noqa + self.encoder_tokenizer = encoder_tokenizer + self.decoder_tokenizer = decoder_tokenizer - self.tokenizer = tokenizer self.model = SonarEncoderDecoderModel(encoder, decoder).eval() # type: ignore @torch.inference_mode() @@ -95,23 +115,43 @@ def predict( input: Union[Path, Sequence[str]], source_lang: str, target_lang: str, + source_mode: str = "source", + target_mode: str = "target", batch_size: int = 5, progress_bar: bool = False, **generator_kwargs, ) -> List[str]: # truncate the max seq len to avoid model to fail generator_kwargs = generator_kwargs or {} - model_max_seq_len = 
self.model.decoder.decoder_frontend.pos_encoder.max_seq_len # type: ignore[union-attr] - generator_kwargs["max_seq_len"] = min( - model_max_seq_len, generator_kwargs.get("max_seq_len", model_max_seq_len) + + model_max_seq_len = cast( + int | None, + ( + self.model.decoder.decoder_frontend.pos_encoder.max_seq_len # type: ignore + if self.model.decoder.decoder_frontend.pos_encoder is not None + else self.model.decoder.decoder.layers[ + 0 + ].self_attn.pos_encoder.max_seq_len + ), ) + if model_max_seq_len is None: + model_max_seq_len = generator_kwargs.get("max_seq_len", model_max_seq_len) - generator = BeamSearchSeq2SeqGenerator(self.model, **generator_kwargs) + generator_kwargs["max_seq_len"] = min( + model_max_seq_len, + generator_kwargs.get("max_seq_len", model_max_seq_len), # type: ignore + ) + generator = BeamSearchSeq2SeqGenerator( + self.model, self.decoder_tokenizer.vocab_info, **generator_kwargs + ) translator = TextTranslator( generator, - tokenizer=self.tokenizer, + encoder_tokenizer=self.encoder_tokenizer, # type: ignore + decoder_tokenizer=self.decoder_tokenizer, # type: ignore source_lang=source_lang, target_lang=target_lang, + source_mode=source_mode, + target_mode=target_mode, ) def _do_translate(src_texts: List[str]) -> List[str]: @@ -139,30 +179,27 @@ def _do_translate(src_texts: List[str]) -> List[str]: class TextToEmbeddingModelPipeline(torch.nn.Module): model: SonarEncoderModel - tokenizer: TextTokenizer + tokenizer: Tokenizer def __init__( self, encoder: Union[str, SonarEncoderModel], - tokenizer: Union[str, TextTokenizer], + tokenizer: Union[str, Tokenizer], device: Device = CPU, dtype: Optional[DataType] = None, ) -> None: """ Args: encoder (Union[str, SonarEncoderModel]): either card name or model object - tokenizer (Union[str, TextTokenizer]): either card name or tokenizer object + tokenizer (Union[str, Tokenizer]): either card name or tokenizer object device (device, optional): Defaults to CPU. 
dtype (DataType, optional): The data type of the model parameters and buffers. """ super().__init__() - fairseq2.setup_fairseq2() if isinstance(encoder, str): - encoder_hub = get_sonar_text_encoder_hub() - encoder = encoder_hub.load(encoder, device=device, dtype=dtype) + encoder = load_model(encoder, device=device, dtype=dtype) # type: ignore if isinstance(tokenizer, str): - tokenizer_hub = get_text_tokenizer_hub() - tokenizer = tokenizer_hub.load(tokenizer) + tokenizer = load_tokenizer(tokenizer) self.tokenizer = tokenizer @@ -174,7 +211,7 @@ def __init__( def predict( self, input: Union[Path, Sequence[str]], - source_lang: str, + source_lang: str | None = None, batch_size: Optional[int] = 5, batch_max_tokens: Optional[int] = None, max_seq_len: Optional[int] = None, @@ -199,7 +236,14 @@ def predict( tokenizer_encoder = self.tokenizer.create_encoder( lang=source_lang, device=self.device ) - model_max_len = cast(int | None, self.model.encoder_frontend.pos_encoder.max_seq_len) # type: ignore[union-attr] + model_max_len = cast( + int | None, + ( + self.model.encoder_frontend.pos_encoder.max_seq_len # type: ignore + if self.model.encoder_frontend.pos_encoder is not None # type: ignore + else self.model.encoder.layers[0].self_attn.pos_encoder.max_seq_len # type: ignore + ), + ) if max_seq_len is None: max_seq_len = model_max_len if max_seq_len is not None and model_max_len is not None: @@ -241,7 +285,7 @@ def truncate(x: torch.Tensor) -> torch.Tensor: .map(Collater(self.tokenizer.vocab_info.pad_idx)) .map(lambda x: extract_sequence_batch(x, self.device)) .prefetch(2) - .map(self.model) + .map(lambda x: self.model(*x)) .map(lambda x: x.sentence_embeddings.to(target_device or self.device)) .and_return() ) @@ -271,32 +315,29 @@ def truncate(x: torch.Tensor) -> torch.Tensor: class EmbeddingToTextModelPipeline(torch.nn.Module): model: SonarEncoderDecoderModel - tokenizer: TextTokenizer + tokenizer: Tokenizer def __init__( self, decoder: Union[str, 
ConditionalTransformerDecoderModel], - tokenizer: Union[str, TextTokenizer], + tokenizer: Union[str, Tokenizer], device: Device = CPU, dtype: Optional[DataType] = None, ) -> None: """ Args: decoder (Union[str, ConditionalTransformerDecoderModel]): either card name or model object - tokenizer (Union[str, TextTokenizer]): either card name or tokenizer object + tokenizer (Union[str, Tokenizer]): either card name or tokenizer object device (device, optional): Defaults to CPU. dtype (DataType, optional): The data type of the model parameters and buffers. """ super().__init__() - fairseq2.setup_fairseq2() if isinstance(decoder, str): - decoder_hub = get_sonar_text_decoder_hub() - decoder = decoder_hub.load(decoder, device=device, dtype=dtype) + decoder = load_model(decoder, device=device, dtype=dtype) # type: ignore if isinstance(tokenizer, str): - tokenizer_hub = get_text_tokenizer_hub() - tokenizer = tokenizer_hub.load(tokenizer) + tokenizer = load_tokenizer(tokenizer) - encoder = DummyEncoderModel(decoder.model_dim) # type: ignore + encoder = DummyEncoderModel() # type: ignore self.device = device self.tokenizer = tokenizer @@ -307,6 +348,7 @@ def predict( self, inputs: torch.Tensor, target_lang: str, + target_mode: str = "target", batch_size: int = 5, progress_bar: bool = False, sampler: Optional[Sampler] = None, @@ -314,22 +356,25 @@ def predict( ) -> List[str]: if sampler is not None: generator: Seq2SeqGenerator = SamplingSeq2SeqGenerator( - self.model, sampler, **generator_kwargs + self.model, self.tokenizer.vocab_info, sampler, **generator_kwargs ) else: - generator = BeamSearchSeq2SeqGenerator(self.model, **generator_kwargs) + generator = BeamSearchSeq2SeqGenerator( + self.model, self.tokenizer.vocab_info, **generator_kwargs + ) converter = SequenceToTextConverter( generator, self.tokenizer, task="translation", target_lang=target_lang, + mode=target_mode, ) def _do_translate(src_tensors: List[torch.Tensor]) -> List[str]: - texts, _ = converter.batch_convert( - 
torch.stack(src_tensors).to(self.device), None - ) + seqs = torch.stack(src_tensors).to(self.device) + seqs_layout = BatchLayout.of(seqs) + texts, _ = converter.batch_convert(seqs, seqs_layout) return texts pipeline: Iterable = ( diff --git a/sonar/inference_pipelines/utils.py b/sonar/inference_pipelines/utils.py index 93dbf96..5463cb6 100644 --- a/sonar/inference_pipelines/utils.py +++ b/sonar/inference_pipelines/utils.py @@ -5,20 +5,39 @@ # LICENSE file in the root directory of this source tree. import math +from collections.abc import Sequence from pathlib import Path -from typing import Iterable, Optional, Union +from typing import Iterable, Optional, Tuple, Union, final -from fairseq2.data import SequenceData -from fairseq2.models.sequence import SequenceBatch -from fairseq2.nn.padding import get_seqs_and_padding_mask -from fairseq2.typing import Device +from fairseq2.data.data_pipeline import SequenceData +from fairseq2.data.tokenizers import TokenDecoder, TokenEncoder, Tokenizer +from fairseq2.device import Device +from fairseq2.error import InternalError +from fairseq2.generation import Seq2SeqGenerator, SequenceGeneratorOutput +from fairseq2.nn import BatchLayout +from fairseq2.nn.utils.module import maybe_infer_device +from fairseq2.nn.utils.padding import pad_seqs +from torch import Tensor from tqdm.auto import tqdm -def extract_sequence_batch(x: SequenceData, device: Device) -> SequenceBatch: - seqs, padding_mask = get_seqs_and_padding_mask(x, device=device) +def extract_sequence_batch( + x: SequenceData, device: Device +) -> Tuple[Tensor, BatchLayout]: + """ + Naive conversion from `SequenceData` to `SequenceBatch` without padding or packing. + Moving `x` to device for backward compatibility of this function definition. + + This was a call to deprecated `get_seqs_and_padding_mask` in fs2:v0.4.6. 
- return SequenceBatch(seqs, padding_mask) + Args: + x (SequenceData): holding sequences and their lengths + device (Device): the computing device (cuda, cpu, etc.) + Returns: + SequenceBatch: rewrapped `x` and moved to `device` + """ + seqs, seq_lens = x["seqs"].to(device), x["seq_lens"] + return seqs, BatchLayout.of(seqs, seq_lens) def add_progress_bar( @@ -44,3 +63,261 @@ def add_progress_bar( total = math.ceil(len(inputs) / batch_size) # type: ignore return tqdm(sequence, total=total, **kwargs) + + +@final +class SequenceToTextConverter: + """Converts source sequences to text.""" + + # cirquit: This is a carbon copy of fs2:v0.5 with additional `mode` + # parameter passed to the tokenizer encoder. Should be upstreamed. + _generator: Seq2SeqGenerator + _target_prefix_seq: Tensor + _text_decoder: TokenDecoder + + def __init__( + self, + generator: Seq2SeqGenerator, + tokenizer: Tokenizer, + task: str, + target_lang: str | None = None, + mode: str = "target", + skip_special_tokens: bool = True, + ) -> None: + """ + :param generator: + The sequence-to-sequence generator. + :param tokenizer: + The text tokenizer. + :param task: + The conversion task (e.g. translation, transcription). + :param target_lang: + The target language for conversion. + :param mode: + The mode in which to generate token indices. Typically, translation + tasks use ``mode`` to distinguish between different modes such as + 'source' or 'target'. + :param skip_special_tokens: + Whether the tokenizer decoder skips outputting special tokens like . + """ + self._generator = generator + + try: + device = maybe_infer_device(generator.model) + except ValueError as ex: + raise ValueError( + "The device of `generator.model` is not valid. See the nested exception for details." 
+ ) from ex + + target_text_encoder = tokenizer.create_encoder( + task=task, lang=target_lang, mode=mode, device=device + ) + + # (S) + target_prefix_seq = target_text_encoder.prefix_indices + if target_prefix_seq is None: + raise ValueError( + "`tokenizer` must specify a prefix sequence for the target language." + ) + + self._target_prefix_seq = target_prefix_seq + self._text_decoder = tokenizer.create_decoder( + skip_special_tokens=skip_special_tokens + ) + + def __call__(self, source_seqs: Tensor) -> tuple[str, SequenceGeneratorOutput]: + """ + :param source_seqs: + The source sequence. *Shape:* :math:`(S,*)`, where :math:`S` is the + sequence length and :math:`*` is any number of sequence-specific + dimensions including none. + + :returns: + - The converted text. + - The output of the underlying sequence-to-sequence generator. + """ + source_seqs = source_seqs.unsqueeze(0) + source_seqs_layout = BatchLayout.of(source_seqs) + texts, generator_output = self._do_convert(source_seqs, source_seqs_layout) + + return texts[0], generator_output + + def batch_convert( + self, source_seqs: Tensor, source_seqs_layout: BatchLayout + ) -> tuple[list[str], SequenceGeneratorOutput]: + """ + :param source_seqs: + The source sequences. *Shape:* :math:`(N,S,*)`, where :math:`N` is + the batch size, :math:`S` is the sequence length, and :math:`*` is + any number of sequence-specific dimensions including none. + + :returns: + - The converted texts. + - The output of the underlying sequence-to-sequence generator. + """ + if len(source_seqs) == 0: + raise ValueError( + "`source_seqs` must contain at least one element, but is empty instead." + ) + + return self._do_convert(source_seqs, source_seqs_layout) + + def _do_convert( + self, + source_seqs: Tensor, + source_seqs_layout: BatchLayout, + ) -> tuple[list[str], SequenceGeneratorOutput]: + """A subclass should call this method for actual text conversion. + + :param source_seqs: + The source sequences. 
*Shape:* :math:`(N,S,*)`, where :math:`N` is + the batch size, :math:`S` is the sequence length, and :math:`*` is + any number of sequence-specific dimensions including none. + + :returns: + - The converted texts. + - The output of the underlying sequence-to-sequence generator. + """ + batch_size = source_seqs.size(0) + + # (S) -> (N, S) + target_prefix_seqs = self._target_prefix_seq.expand(batch_size, -1) + target_prefix_layout = BatchLayout.of(target_prefix_seqs) + + generator_output = self._generator( + source_seqs, source_seqs_layout, target_prefix_seqs, target_prefix_layout + ) + + texts: list[str] = [] + + for idx, hypotheses in enumerate(generator_output.hypotheses): + if len(hypotheses) == 0: + raise InternalError( + f"The sequence generator returned no hypothesis at index {idx}." + ) + + texts.append(self._text_decoder(hypotheses[0].seq)) + + return texts, generator_output + + +@final +class TextTranslator: + """Translates text from one language to another.""" + + # TODO: cirquit - this is a carbon copy of fs2:v0.5 TextTranslator except for + # - SequenceToTextConverter.skip_special_tokens=True and the source_mode + # - call to pad_seqs with self._pad_ixd which comes from the encoder_tokenizer.vocab_info + + _converter: SequenceToTextConverter + _pad_idx: int + _source_text_encoder: TokenEncoder + _max_source_len: int | None + + def __init__( + self, + generator: Seq2SeqGenerator, + encoder_tokenizer: Tokenizer, + decoder_tokenizer: Tokenizer, + source_lang: str | None = None, + target_lang: str | None = None, + source_mode: str = "source", # this was also added + target_mode: str = "target", + *, + max_source_len: int | None = None, + skip_special_tokens: bool = True, + ) -> None: + """ + :param generator: + The sequence-to-sequence generator. + :param tokenizer: + The text tokenizer. + :param source_lang: + The source language. + :param target_lang: + The target language. 
+ :param max_source_len: + The maximum number of tokens above which the source sequence gets + truncated. + :param skip_special_tokens: + Whether the tokenizer decoder skips outputting special tokens like . + """ + task = "translation" + + self._converter = SequenceToTextConverter( + generator=generator, + tokenizer=decoder_tokenizer, + task=task, + target_lang=target_lang, + mode=target_mode, + skip_special_tokens=skip_special_tokens, + ) + + pad_idx = encoder_tokenizer.vocab_info.pad_idx + if pad_idx is None: + raise ValueError( + "``vocab_info` of `tokenizer` must have a PAD symbol defined." + ) + + self._pad_idx = pad_idx + + try: + device = maybe_infer_device(generator.model) + except ValueError as ex: + raise ValueError( + "The device of `generator.model` is not valid. See the nested exception for details." + ) from ex + + self._source_text_encoder = encoder_tokenizer.create_encoder( + task="translation", lang=source_lang, mode=source_mode, device=device + ) + + if max_source_len is not None and max_source_len <= 0: + raise ValueError( + f"`max_source_len` must be greater than or equal to 1, but is {max_source_len} instead." + ) + + self._max_source_len = max_source_len + + def __call__(self, source_text: str) -> tuple[str, SequenceGeneratorOutput]: + """ + :param source_text: + The text in the source language. + + :returns: + - The translated text. + - The output of the underlying sequence-to-sequence generator. + """ + source_seq = self._source_text_encoder(source_text) + + if self._max_source_len: + source_seq = source_seq[: self._max_source_len] + + return self._converter(source_seq) + + def batch_translate( + self, source_texts: Sequence[str] + ) -> tuple[list[str], SequenceGeneratorOutput]: + """ + :param source_texts: + The texts in the source language. + + :returns: + - The translated texts. + - The output of the underlying sequence-to-sequence generator. 
+ """ + if len(source_texts) == 0: + raise ValueError( + "`source_texts` must contain at least one element, but is empty instead." + ) + + source_seq_list = [self._source_text_encoder(t) for t in source_texts] + + if self._max_source_len: + source_seq_list = [seq[: self._max_source_len] for seq in source_seq_list] + + source_seqs, source_seqs_layout = pad_seqs( + source_seq_list, pad_value=self._pad_idx + ) + + return self._converter.batch_convert(source_seqs, source_seqs_layout) diff --git a/sonar/models/blaser/__init__.py b/sonar/models/blaser/__init__.py index 597db68..0334a42 100644 --- a/sonar/models/blaser/__init__.py +++ b/sonar/models/blaser/__init__.py @@ -4,16 +4,19 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +from sonar.models.blaser.checkpoint import ( + _convert_blaser_checkpoint as _convert_blaser_checkpoint, +) +from sonar.models.blaser.config import BLASER_FAMILY as BLASER_FAMILY from sonar.models.blaser.config import BlaserConfig as BlaserConfig from sonar.models.blaser.config import ( - register_blaser_configs as register_blaser_configs, + _register_blaser_configs as _register_blaser_configs, ) -from sonar.models.blaser.factory import create_blaser_model as create_blaser_model -from sonar.models.blaser.handler import BlaserModelHandler as BlaserModelHandler +from sonar.models.blaser.factory import _create_blaser_model as _create_blaser_model from sonar.models.blaser.model import BlaserModel as BlaserModel # isort: split from fairseq2.models import ModelHubAccessor -get_blaser_model_hub = ModelHubAccessor(BlaserModel, BlaserConfig) +get_blaser_model_hub = ModelHubAccessor(BLASER_FAMILY, BlaserModel, BlaserConfig) diff --git a/sonar/models/blaser/checkpoint.py b/sonar/models/blaser/checkpoint.py new file mode 100644 index 0000000..f749223 --- /dev/null +++ b/sonar/models/blaser/checkpoint.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import cast + +from sonar.models.blaser.config import BlaserConfig + + +def _convert_blaser_checkpoint( + state_dict: dict[str, object], config: BlaserConfig +) -> dict[str, object]: + # fairseq2 does not use a top-level "model" keyword anymore (v0.5+) + try: + state_dict = cast(dict[str, object], state_dict["model"]) + except KeyError: + pass + + return state_dict diff --git a/sonar/models/blaser/config.py b/sonar/models/blaser/config.py index 58f919e..21bb382 100644 --- a/sonar/models/blaser/config.py +++ b/sonar/models/blaser/config.py @@ -5,12 +5,15 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass, field -from typing import List +from typing import Final, List -from fairseq2.context import RuntimeContext +from fairseq2.runtime.config_registry import ConfigRegistrar +from fairseq2.runtime.dependency import DependencyContainer from sonar.models.blaser.model import ACTIVATIONS, BLASER_INPUT_FORMS +BLASER_FAMILY: Final = "blaser" + @dataclass class BlaserConfig: @@ -35,10 +38,8 @@ def __post__init__(self): ) -def register_blaser_configs(context: RuntimeContext) -> None: - registry = context.get_config_registry(BlaserConfig) - - arch = registry.decorator +def _register_blaser_configs(container: DependencyContainer) -> None: + arch = ConfigRegistrar(container, BlaserConfig) @arch("basic_ref") def basic_ref() -> BlaserConfig: diff --git a/sonar/models/blaser/factory.py b/sonar/models/blaser/factory.py index b7f59a0..5b9da64 100644 --- a/sonar/models/blaser/factory.py +++ b/sonar/models/blaser/factory.py @@ -10,5 +10,5 @@ from sonar.models.blaser.model import BlaserModel -def create_blaser_model(config: BlaserConfig) -> BlaserModel: +def _create_blaser_model(config: BlaserConfig) -> BlaserModel: return BlaserModel(**asdict(config)) diff --git 
a/sonar/models/blaser/handler.py b/sonar/models/blaser/handler.py deleted file mode 100644 index f5f0ecf..0000000 --- a/sonar/models/blaser/handler.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import cast, final - -from fairseq2.models import AbstractModelHandler -from torch.nn import Module -from typing_extensions import override - -from sonar.models.blaser.config import BlaserConfig -from sonar.models.blaser.factory import create_blaser_model -from sonar.models.blaser.model import BlaserModel - - -@final -class BlaserModelHandler(AbstractModelHandler): - @override - @property - def family(self) -> str: - return "blaser" - - @override - @property - def kls(self) -> type[Module]: - return BlaserModel - - @override - def _create_model(self, config: object) -> Module: - config = cast(BlaserConfig, config) - - return create_blaser_model(config) - - @override - def _convert_checkpoint( - self, checkpoint: dict[str, object], config: object - ) -> dict[str, object]: - # Return directly if found fairseq2 attribute in state dict - if "model" in checkpoint: - return checkpoint - - # Othewise (the old checkpoint format), move the whole state dict to the "model" section - return {"model": checkpoint} diff --git a/sonar/models/blaser/loader.py b/sonar/models/blaser/loader.py index 84efe5f..4b23a84 100644 --- a/sonar/models/blaser/loader.py +++ b/sonar/models/blaser/loader.py @@ -4,8 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-import fairseq2 - from sonar.models.blaser import get_blaser_model_hub from sonar.models.blaser.model import BlaserModel @@ -13,9 +11,8 @@ def load_blaser_model(model_name: str) -> BlaserModel: """ This file exists purely for backward compatibility of the package interface! - Normally, the user is encouraged to call `setup_fairseq2` and `get_blaser_model_hub` on their own. + Normally, the user is encouraged to call `get_blaser_model_hub` on their own. """ - fairseq2.setup_fairseq2() model_hub = get_blaser_model_hub() - model = model_hub.load(model_name) + model = model_hub.load_model(model_name) return model diff --git a/sonar/models/encoder_model.py b/sonar/models/encoder_model.py index ff1df72..91e5dd0 100644 --- a/sonar/models/encoder_model.py +++ b/sonar/models/encoder_model.py @@ -6,10 +6,8 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional -from fairseq2.models.sequence import SequenceBatch -from fairseq2.nn.padding import PaddingMask +from fairseq2.nn.batch_layout import BatchLayout from torch import Tensor from torch.nn import Module @@ -31,37 +29,21 @@ class SonarEncoderOutput: dimensionality of the model. """ - padding_mask: Optional[PaddingMask] - """Optional, the floating padding mask over sequences (-inf means masked element) - *Shape:* :math:`(N,S)`, where :math:`N` is the batch size, - :math:`S` is the sequence length. + encoded_seqs_layout: BatchLayout + """The batchlayout of the ``encoded_seqs``. Holds the information of sequence length, + optional padding and whether the batch is packed. """ class SonarEncoderModel(ABC, Module): """Abstract class for both speech and text SONAR encoder models""" - model_dim: int - - def __init__(self, model_dim: int) -> None: - """ - - :param model_dim: - The dimensionality of the model. 
- """ + def __init__(self) -> None: super().__init__() - self.model_dim = model_dim - @property def dtype(self): return next(self.parameters()).dtype @abstractmethod - def forward(self, batch: SequenceBatch) -> SonarEncoderOutput: - """ - :param batch: - The batch of sequences to process. - :returns: - SonarEncoderOutput - """ + def forward(self, seqs: Tensor, seqs_layout: BatchLayout) -> SonarEncoderOutput: ... diff --git a/sonar/models/laser2_text/__init__.py b/sonar/models/laser2_text/__init__.py index d375829..f7b4f28 100644 --- a/sonar/models/laser2_text/__init__.py +++ b/sonar/models/laser2_text/__init__.py @@ -4,20 +4,33 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +from types import NoneType + +from sonar.models.laser2_text.checkpoint import ( + _convert_laser2_checkpoint as _convert_laser2_checkpoint, +) +from sonar.models.laser2_text.config import LASER2_FAMILY as LASER2_FAMILY from sonar.models.laser2_text.config import Laser2Config as Laser2Config from sonar.models.laser2_text.config import ( - register_laser2_configs as register_laser2_configs, + _register_laser2_configs as _register_laser2_configs, ) -from sonar.models.laser2_text.handler import Laser2ModelHandler as Laser2ModelHandler -from sonar.models.laser2_text.handler import ( - Laser2TokenizerHandler as Laser2TokenizerHandler, +from sonar.models.laser2_text.factory import ( + _create_laser2_model as _create_laser2_model, ) from sonar.models.laser2_text.tokenizer import Laser2Tokenizer as Laser2Tokenizer +from sonar.models.laser2_text.tokenizer import ( + _load_laser2_tokenizer as _load_laser2_tokenizer, +) # isort: split +from fairseq2.data.tokenizers import TokenizerHubAccessor from fairseq2.models import ModelHubAccessor from sonar.nn.laser_lstm_encoder import LaserLstmEncoder -get_laser2_model_hub = ModelHubAccessor(LaserLstmEncoder, Laser2Config) +get_laser2_model_hub = ModelHubAccessor(LASER2_FAMILY, 
LaserLstmEncoder, Laser2Config) + +get_laser2_tokenizer_hub = TokenizerHubAccessor( + LASER2_FAMILY, Laser2Tokenizer, NoneType +) diff --git a/sonar/models/laser2_text/checkpoint.py b/sonar/models/laser2_text/checkpoint.py new file mode 100644 index 0000000..f8f3af1 --- /dev/null +++ b/sonar/models/laser2_text/checkpoint.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import cast + +from sonar.models.laser2_text.config import Laser2Config + + +def _convert_laser2_checkpoint( + state_dict: dict[str, object], config: Laser2Config +) -> dict[str, object]: + # fairseq2 does not use a top-level "model" keyword anymore (v0.5+) + try: + state_dict = cast(dict[str, object], state_dict["model"]) + except KeyError: + pass + + return state_dict diff --git a/sonar/models/laser2_text/config.py b/sonar/models/laser2_text/config.py index e271a41..8b6aa6e 100644 --- a/sonar/models/laser2_text/config.py +++ b/sonar/models/laser2_text/config.py @@ -5,8 +5,12 @@ # LICENSE file in the root directory of this source tree. 
from dataclasses import dataclass +from typing import Final -from fairseq2.context import RuntimeContext +from fairseq2.runtime.config_registry import ConfigRegistrar +from fairseq2.runtime.dependency import DependencyContainer + +LASER2_FAMILY: Final = "lstm" @dataclass @@ -20,10 +24,8 @@ class Laser2Config: padding_value: float = 0.0 -def register_laser2_configs(context: RuntimeContext) -> None: - registry = context.get_config_registry(Laser2Config) - - arch = registry.decorator +def _register_laser2_configs(container: DependencyContainer) -> None: + arch = ConfigRegistrar(container, Laser2Config) @arch("laser2") def laser2() -> Laser2Config: diff --git a/sonar/models/laser2_text/factory.py b/sonar/models/laser2_text/factory.py new file mode 100644 index 0000000..f9ed917 --- /dev/null +++ b/sonar/models/laser2_text/factory.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from sonar.models.laser2_text.config import Laser2Config +from sonar.nn import LaserLstmEncoder + + +def _create_laser2_model(config: Laser2Config) -> LaserLstmEncoder: + return LaserLstmEncoder( + num_embeddings=config.vocabulary_size, + padding_idx=config.pad_idx, + embed_dim=config.model_dim, + hidden_size=config.hidden_size, + num_layers=config.num_layers, + bidirectional=config.bidirectional, + padding_value=config.padding_value, + ) diff --git a/sonar/models/laser2_text/handler.py b/sonar/models/laser2_text/handler.py deleted file mode 100644 index ad6194f..0000000 --- a/sonar/models/laser2_text/handler.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -from pathlib import Path -from typing import cast, final - -from fairseq2.assets import AssetCard -from fairseq2.data.text.tokenizers import AbstractTextTokenizerHandler, TextTokenizer -from fairseq2.models import AbstractModelHandler -from torch.nn import Module -from typing_extensions import override - -from sonar.models.laser2_text.config import Laser2Config -from sonar.models.laser2_text.tokenizer import Laser2Tokenizer -from sonar.nn.laser_lstm_encoder import LaserLstmEncoder - - -@final -class Laser2ModelHandler(AbstractModelHandler): - @override - @property - def family(self) -> str: - return "lstm" - - @override - @property - def kls(self) -> type[Module]: - return LaserLstmEncoder - - @override - def _create_model(self, config: object) -> Module: - config = cast(Laser2Config, config) - - return LaserLstmEncoder( - num_embeddings=config.vocabulary_size, - padding_idx=config.pad_idx, - embed_dim=config.model_dim, - hidden_size=config.hidden_size, - num_layers=config.num_layers, - bidirectional=config.bidirectional, - padding_value=config.padding_value, - ) - - -@final -class Laser2TokenizerHandler(AbstractTextTokenizerHandler): - @override - @property - def family(self) -> str: - return "lstm" - - @override - def _load_tokenizer(self, path: Path, card: AssetCard) -> TextTokenizer: - return Laser2Tokenizer(path) diff --git a/sonar/models/laser2_text/tokenizer.py b/sonar/models/laser2_text/tokenizer.py index f7ba963..4198494 100644 --- a/sonar/models/laser2_text/tokenizer.py +++ b/sonar/models/laser2_text/tokenizer.py @@ -4,57 +4,60 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + from pathlib import Path from typing import Optional, final import torch -from fairseq2.data.text.tokenizers import ( - AbstractTextTokenizer, - TextTokenDecoder, - TextTokenEncoder, +from fairseq2.data.tokenizers import ( + TokenDecoder, + TokenEncoder, + Tokenizer, + TokenizerModelError, + VocabularyInfo, ) -from fairseq2.data.text.tokenizers.sentencepiece import ( +from fairseq2.data.tokenizers.sentencepiece import ( SentencePieceDecoder, SentencePieceEncoder, SentencePieceModel, - vocab_info_from_sentencepiece, + get_sentencepiece_vocabulary_info, ) -from fairseq2.typing import Device, override +from fairseq2.device import Device +from fairseq2.error import OperationalError from torch import Tensor from typing_extensions import NoReturn @final -class Laser2Encoder(TextTokenEncoder): +class Laser2Encoder(TokenEncoder): def __init__(self, spm_encoder: SentencePieceEncoder) -> None: self.spm_encoder: SentencePieceEncoder = spm_encoder - @override def __call__(self, sentence: str) -> torch.Tensor: out = self.spm_encoder(sentence) return torch.where(out >= 3, out + 4, out) - @override def encode_as_tokens(self, text: str) -> NoReturn: raise RuntimeError("not implemented!") @property - @override def prefix_indices(self) -> Optional[Tensor]: return self.spm_encoder.prefix_indices @property - @override def suffix_indices(self) -> Optional[Tensor]: return self.spm_encoder.suffix_indices @final -class Laser2Tokenizer(AbstractTextTokenizer): +class Laser2Tokenizer(Tokenizer): """Represents the tokenizer used by S2T Transformer models.""" model: SentencePieceModel + _vocab_info: VocabularyInfo + # breaking styleguide to implement the vocab_info abstract property interface def __init__(self, path: Path) -> None: """ @@ -62,12 +65,8 @@ def __init__(self, path: Path) -> None: The pathname of the SentencePiece model file. 
""" self.model = SentencePieceModel(path, [""]) + self._vocab_info = get_sentencepiece_vocabulary_info(self.model) - vocab_info = vocab_info_from_sentencepiece(self.model) - - super().__init__(vocab_info) - - @override def create_encoder( self, *, @@ -86,12 +85,30 @@ def create_encoder( ) ) - @override def create_raw_encoder( self, *, device: Optional[Device] = None, pin_memory: bool = False - ) -> TextTokenEncoder: + ) -> TokenEncoder: return SentencePieceEncoder(self.model, device=device, pin_memory=pin_memory) - @override - def create_decoder(self) -> TextTokenDecoder: + def create_decoder(self, *, skip_special_tokens: bool = False) -> TokenDecoder: return SentencePieceDecoder(self.model) + + @property + def vocab_info(self) -> VocabularyInfo: + return self._vocab_info + + +def _load_laser2_tokenizer(path: Path, config: None = None) -> Tokenizer: + try: + model = Laser2Tokenizer(path) + except OSError as ex: + raise OperationalError( + f"A system error has occurred while reading the '{path}' tokenizer model. See the nested exception for details." + ) from ex + except RuntimeError as ex: + raise TokenizerModelError( + path, + f"The '{path}' tokenizer model cannot be loaded. See the nested exception for details.", # fmt: skip + ) from ex + + return model diff --git a/sonar/models/mutox/__init__.py b/sonar/models/mutox/__init__.py index fe91503..68acc7b 100644 --- a/sonar/models/mutox/__init__.py +++ b/sonar/models/mutox/__init__.py @@ -4,13 +4,18 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
-from sonar.models.mutox.config import MutoxConfig, register_mutox_configs -from sonar.models.mutox.factory import create_mutox_model -from sonar.models.mutox.handler import MutoxModelHandler -from sonar.models.mutox.model import MutoxClassifier +from sonar.models.mutox.checkpoint import ( + _convert_mutox_checkpoint as _convert_mutox_checkpoint, +) +from sonar.models.mutox.config import MUTOX_FAMILY as MUTOX_FAMILY +from sonar.models.mutox.config import MutoxConfig as MutoxConfig +from sonar.models.mutox.config import _register_mutox_configs as _register_mutox_configs +from sonar.models.mutox.factory import _create_mutox_model as _create_mutox_model # isort: split from fairseq2.models import ModelHubAccessor -get_mutox_model_hub = ModelHubAccessor(MutoxClassifier, MutoxConfig) +from sonar.models.mutox.model import MutoxClassifier as MutoxClassifier + +get_mutox_model_hub = ModelHubAccessor(MUTOX_FAMILY, MutoxClassifier, MutoxConfig) diff --git a/sonar/models/mutox/checkpoint.py b/sonar/models/mutox/checkpoint.py new file mode 100644 index 0000000..984ce1c --- /dev/null +++ b/sonar/models/mutox/checkpoint.py @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+from typing import cast + +from sonar.models.mutox.config import MutoxConfig + + +def _convert_mutox_checkpoint( + state_dict: dict[str, object], config: MutoxConfig +) -> dict[str, object]: + # fairseq2 does not use a top-level "model" keyword anymore (v0.5+) + try: + state_dict = cast(dict[str, object], state_dict["model"]) + except KeyError: + pass + + new_dict = {} + for key in state_dict: + if key.startswith("model_all."): + new_dict[key] = state_dict[key] + return new_dict diff --git a/sonar/models/mutox/config.py b/sonar/models/mutox/config.py index db51624..607a02f 100644 --- a/sonar/models/mutox/config.py +++ b/sonar/models/mutox/config.py @@ -5,8 +5,12 @@ # MIT_LICENSE file in the root directory of this source tree. from dataclasses import dataclass +from typing import Final -from fairseq2.context import RuntimeContext +from fairseq2.runtime.config_registry import ConfigRegistrar +from fairseq2.runtime.dependency import DependencyContainer + +MUTOX_FAMILY: Final = "mutox_classifier" @dataclass @@ -17,10 +21,8 @@ class MutoxConfig: input_size: int -def register_mutox_configs(context: RuntimeContext) -> None: - registry = context.get_config_registry(MutoxConfig) - - arch = registry.decorator +def _register_mutox_configs(container: DependencyContainer) -> None: + arch = ConfigRegistrar(container, MutoxConfig) @arch("mutox") def _base_mutox() -> MutoxConfig: diff --git a/sonar/models/mutox/factory.py b/sonar/models/mutox/factory.py index 3f2f3cd..2dc20ec 100644 --- a/sonar/models/mutox/factory.py +++ b/sonar/models/mutox/factory.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from dataclasses import asdict from torch import nn @@ -12,7 +11,7 @@ from sonar.models.mutox.model import MutoxClassifier -def create_mutox_model(config: MutoxConfig) -> MutoxClassifier: +def _create_mutox_model(config: MutoxConfig) -> MutoxClassifier: # TODO: refactor the model and the config to make this more flexible model_h1 = nn.Sequential( nn.Dropout(0.01), diff --git a/sonar/models/mutox/handler.py b/sonar/models/mutox/handler.py deleted file mode 100644 index dd95332..0000000 --- a/sonar/models/mutox/handler.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import cast, final - -from fairseq2.models import AbstractModelHandler -from torch.nn import Module -from typing_extensions import override - -from sonar.models.mutox.config import MutoxConfig -from sonar.models.mutox.factory import create_mutox_model -from sonar.models.mutox.model import MutoxClassifier - - -@final -class MutoxModelHandler(AbstractModelHandler): - @override - @property - def family(self) -> str: - return "mutox_classifier" - - @override - @property - def kls(self) -> type[Module]: - return MutoxClassifier - - @override - def _create_model(self, config: object) -> Module: - config = cast(MutoxConfig, config) - - return create_mutox_model(config) - - @override - def _convert_checkpoint( - self, checkpoint: dict[str, object], config: object - ) -> dict[str, object]: - new_dict = {} - for key in checkpoint: - if key.startswith("model_all."): - new_dict[key] = checkpoint[key] - return {"model": new_dict} diff --git a/sonar/models/mutox/loader.py b/sonar/models/mutox/loader.py index a432088..3ca75bc 100644 --- a/sonar/models/mutox/loader.py +++ b/sonar/models/mutox/loader.py @@ -4,8 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the 
root directory of this source tree. -import fairseq2 - from sonar.models.mutox import get_mutox_model_hub from sonar.models.mutox.model import MutoxClassifier @@ -13,9 +11,8 @@ def load_mutox_model(model_name: str, device=None, dtype=None) -> MutoxClassifier: """ This file exists purely for backward compatibility of the package interface! - Normally, the user is encouraged to call `setup_fairseq2` and `get_blaser_model_hub` on their own. + Normally, the user is encouraged to call `get_mutox_model_hub` on their own. """ - fairseq2.setup_fairseq2() model_hub = get_mutox_model_hub() - model = model_hub.load(model_name).to(device=device, dtype=dtype) + model = model_hub.load_model(model_name).to(device=device, dtype=dtype) return model diff --git a/sonar/models/sonar_speech/__init__.py b/sonar/models/sonar_speech/__init__.py index 8facd04..38de37a 100644 --- a/sonar/models/sonar_speech/__init__.py +++ b/sonar/models/sonar_speech/__init__.py @@ -4,26 +4,31 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from sonar.models.sonar_speech.checkpoint import ( + _convert_sonar_speech_checkpoint as _convert_sonar_speech_checkpoint, +) +from sonar.models.sonar_speech.config import SONAR_SPEECH_FAMILY as SONAR_SPEECH_FAMILY from sonar.models.sonar_speech.config import ( SonarSpeechEncoderConfig as SonarSpeechEncoderConfig, ) from sonar.models.sonar_speech.config import ( - register_sonar_speech_encoder_configs as register_sonar_speech_encoder_configs, + _register_sonar_speech_encoder_configs as _register_sonar_speech_encoder_configs, ) from sonar.models.sonar_speech.factory import ( SonarSpeechEncoderFactory as SonarSpeechEncoderFactory, ) -from sonar.models.sonar_speech.handler import ( - SonarSpeechEncoderHandler as SonarSpeechEncoderHandler, -) -from sonar.models.sonar_speech.model import ( - SonarSpeechEncoderModel as SonarSpeechEncoderModel, +from sonar.models.sonar_speech.factory import ( + _create_sonar_speech_encoder_model as _create_sonar_speech_encoder_model, ) # isort: split from fairseq2.models import ModelHubAccessor +from sonar.models.sonar_speech.model import ( + SonarSpeechEncoderModel as SonarSpeechEncoderModel, +) + get_sonar_speech_encoder_hub = ModelHubAccessor( - SonarSpeechEncoderModel, SonarSpeechEncoderConfig + SONAR_SPEECH_FAMILY, SonarSpeechEncoderModel, SonarSpeechEncoderConfig ) diff --git a/sonar/models/sonar_speech/checkpoint.py b/sonar/models/sonar_speech/checkpoint.py new file mode 100644 index 0000000..edfb452 --- /dev/null +++ b/sonar/models/sonar_speech/checkpoint.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Any, cast + +from fairseq2.models.utils.checkpoint import convert_fairseq_state_dict + +from sonar.models.sonar_speech.config import SonarSpeechEncoderConfig + + +def _convert_sonar_speech_checkpoint( + state_dict: dict[str, Any], config: SonarSpeechEncoderConfig +) -> dict[str, Any]: + # fairseq2 does not use a top-level "model" keyword anymore (v0.5+) + try: + state_dict = cast(dict[str, object], state_dict["model"]) + except KeyError: + pass + + # Check if we have a fairseq2 checkpoint. + if "encoder_frontend.model_dim_proj" in state_dict: + return state_dict + + # assuming pre fs2:v0.5 formatting with top-level "model" key + # state_dict = checkpoint["model"] + if "encoder.w2v_model.mask_emb" in state_dict: + del state_dict["encoder.w2v_model.mask_emb"] + + if "encoder.w2v_model.encoder.pos_conv.0.bias" in state_dict: + del state_dict["encoder.w2v_model.encoder.pos_conv.0.bias"] + del state_dict["encoder.w2v_model.encoder.pos_conv.0.weight_g"] + del state_dict["encoder.w2v_model.encoder.pos_conv.0.weight_v"] + + key_map = { + # fmt: off + # encoder + r"^encoder.w2v_model.layer_norm\.": r"encoder_frontend.post_extract_layer_norm.", + r"^encoder.w2v_model.post_extract_proj\.": r"encoder_frontend.model_dim_proj.", + r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.conv_module\.batch_norm\.": r"encoder.layers.\1.conv.batch_norm.", + r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.": r"encoder.layers.\1.conv.depthwise_conv.", + r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm\.": r"encoder.layers.\1.conv_layer_norm.", + r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": r"encoder.layers.\1.conv.pointwise_conv1.", + r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": r"encoder.layers.\1.conv.pointwise_conv2.", + r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.": r"encoder.layers.\1.ffn\2_layer_norm.", + 
r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.": r"encoder.layers.\1.ffn\2.inner_proj.", + r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.": r"encoder.layers.\1.ffn\2.output_proj.", + r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"encoder.layers.\1.self_attn_layer_norm.", + r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.self_attn\.linear_q\.": r"encoder.layers.\1.self_attn.q_proj.", + r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.self_attn\.linear_k\.": r"encoder.layers.\1.self_attn.k_proj.", + r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.self_attn\.linear_v\.": r"encoder.layers.\1.self_attn.v_proj.", + r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.self_attn\.linear_out\.": r"encoder.layers.\1.self_attn.output_proj.", + r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.self_attn\.linear_pos\.": r"encoder.layers.\1.self_attn.sdpa.r_proj.", + r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_u": r"encoder.layers.\1.self_attn.sdpa.u_bias", + r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_v": r"encoder.layers.\1.self_attn.sdpa.v_bias", + r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.final_layer_norm\.": r"encoder.layers.\1.layer_norm.", + r"^encoder.w2v_model.encoder\.layer_norm\.": r"encoder.layer_norm.", + r"^decoder\.embed_tokens\.": r"encoder_pooler.decoder_frontend.embed.", + r"^decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"encoder_pooler.decoder.layers.\1.self_attn_layer_norm.", + r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"encoder_pooler.decoder.layers.\1.self_attn.output_proj.", + r"^decoder\.layers\.([0-9]+)\.self_attn\.": r"encoder_pooler.decoder.layers.\1.self_attn.", + r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"encoder_pooler.decoder.layers.\1.encoder_decoder_attn_layer_norm.", + r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.": r"encoder_pooler.decoder.layers.\1.encoder_decoder_attn.output_proj.", + 
r"^decoder\.layers\.([0-9]+)\.encoder_attn\.": r"encoder_pooler.decoder.layers.\1.encoder_decoder_attn.", + r"^decoder\.layers\.([0-9]+)\.fc1\.": r"encoder_pooler.decoder.layers.\1.ffn.inner_proj.", + r"^decoder\.layers\.([0-9]+)\.fc2\.": r"encoder_pooler.decoder.layers.\1.ffn.output_proj.", + r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.": r"encoder_pooler.decoder.layers.\1.ffn_layer_norm.", + r"^decoder\.embed_out": r"encoder_pooler.projection_out.weight", + # fmt: on + } + + # In normal circumstances, we should never encounter a `LayerNorm` when + # `use_conformer` is `True`. Unfortunately, the w2v-BERT pretraining in + # fairseq was accidentally run with a pre-LN encoder, and ended up with + # a redundant `LayerNorm` right after the Conformer blocks. We mitigate + # that issue here by moving that `LayerNorm` to the sonar block. + if config.w2v2_encoder_config.use_conformer: + key_map.update({r"^encoder.w2v_model.encoder\.layer_norm\.": r"layer_norm."}) + + return convert_fairseq_state_dict(state_dict, key_map) diff --git a/sonar/models/sonar_speech/config.py b/sonar/models/sonar_speech/config.py index 319d662..f2a1be5 100644 --- a/sonar/models/sonar_speech/config.py +++ b/sonar/models/sonar_speech/config.py @@ -5,12 +5,15 @@ # LICENSE file in the root directory of this source tree. 
from dataclasses import dataclass -from typing import Optional +from typing import Final, Optional -from fairseq2.context import RuntimeContext +from fairseq2.models.transformer import TransformerNormOrder from fairseq2.models.w2vbert import W2VBertConfig from fairseq2.models.wav2vec2 import Wav2Vec2EncoderConfig -from fairseq2.nn.transformer import TransformerNormOrder +from fairseq2.runtime.config_registry import ConfigRegistrar, get_config +from fairseq2.runtime.dependency import DependencyContainer, DependencyResolver + +SONAR_SPEECH_FAMILY: Final = "sonar_speech" @dataclass @@ -51,16 +54,12 @@ class SonarSpeechEncoderConfig: """The dropout probability in Transformer layers.""" -def register_sonar_speech_encoder_configs(context: RuntimeContext) -> None: - registry = context.get_config_registry(SonarSpeechEncoderConfig) - - arch = registry.decorator - - w2vbert_registry = context.get_config_registry(W2VBertConfig) +def _register_sonar_speech_encoder_configs(container: DependencyContainer) -> None: + arch = ConfigRegistrar(container, SonarSpeechEncoderConfig) - @arch("english") - def basic() -> SonarSpeechEncoderConfig: - w2vbert_config = w2vbert_registry.get("600m") + @arch("english", advanced=True) + def basic(resolver: DependencyResolver) -> SonarSpeechEncoderConfig: + w2vbert_config = get_config(resolver, W2VBertConfig, "600m") return SonarSpeechEncoderConfig( w2v2_encoder_config=w2vbert_config.w2v2_config.encoder_config, @@ -76,9 +75,9 @@ def basic() -> SonarSpeechEncoderConfig: dropout_p=0.1, ) - @arch("non_english") - def multilingual() -> SonarSpeechEncoderConfig: - w2vbert_config = w2vbert_registry.get("600m") + @arch("non_english", advanced=True) + def multilingual(resolver: DependencyResolver) -> SonarSpeechEncoderConfig: + w2vbert_config = get_config(resolver, W2VBertConfig, "600m") return SonarSpeechEncoderConfig( w2v2_encoder_config=w2vbert_config.w2v2_config.encoder_config, diff --git a/sonar/models/sonar_speech/factory.py 
b/sonar/models/sonar_speech/factory.py index 5592e76..e39c846 100644 --- a/sonar/models/sonar_speech/factory.py +++ b/sonar/models/sonar_speech/factory.py @@ -7,8 +7,19 @@ from typing import Optional from fairseq2.models.transformer import ( + FeedForwardNetwork, + IdentityBias, + MultiheadAttention, + StandardFeedForwardNetwork, + StandardMultiheadAttention, + StandardTransformerDecoder, + StandardTransformerDecoderLayer, + TransformerDecoder, + TransformerDecoderLayer, TransformerEmbeddingFrontend, + TransformerEncoder, TransformerFrontend, + create_default_sdpa, ) from fairseq2.models.wav2vec2 import Wav2Vec2EncoderFactory, Wav2Vec2Frontend from fairseq2.nn import ( @@ -18,27 +29,21 @@ PositionEncoder, SinusoidalPositionEncoder, StandardEmbedding, + StandardLayerNorm, init_scaled_embedding, ) -from fairseq2.nn.transformer import ( - FeedForwardNetwork, - MultiheadAttention, - StandardFeedForwardNetwork, - StandardMultiheadAttention, - StandardTransformerDecoder, - StandardTransformerDecoderLayer, - TransformerDecoder, - TransformerDecoderLayer, - TransformerEncoder, - create_default_sdpa, - create_standard_layer_norm, -) from sonar.models.sonar_speech.config import SonarSpeechEncoderConfig from sonar.models.sonar_speech.model import SonarSpeechEncoderModel from sonar.nn.encoder_pooler import AttentionEncoderOutputPooler, EncoderOutputPooler +def _create_sonar_speech_encoder_model( + config: SonarSpeechEncoderConfig, +) -> SonarSpeechEncoderModel: + return SonarSpeechEncoderFactory(config).create_model() + + class SonarSpeechEncoderFactory: config: SonarSpeechEncoderConfig @@ -80,6 +85,7 @@ def create_attention_pooler(self) -> EncoderOutputPooler: def create_decoder_frontend(self) -> TransformerFrontend: return TransformerEmbeddingFrontend( + self.config.model_dim, self.create_embedding(), self.create_pos_encoder(), dropout_p=self.config.dropout_p, @@ -94,7 +100,7 @@ def create_pos_encoder(self) -> PositionEncoder: def create_embedding(self) -> Embedding: return 
StandardEmbedding( num_embeddings=self.config.w2v2_encoder_config.model_dim, - embedding_dim=self.config.model_dim, + embed_dim=self.config.model_dim, pad_idx=self.config.pad_idx, init_fn=init_scaled_embedding, ) @@ -103,24 +109,24 @@ def create_decoder(self) -> TransformerDecoder: num_layers = self.config.num_decoder_layers layers = [self.create_decoder_layer() for _ in range(num_layers)] - return StandardTransformerDecoder( - layers, - norm_order=self.config.decoder_norm_order, - ) + return StandardTransformerDecoder(layers) def create_decoder_layer(self) -> TransformerDecoderLayer: num_heads = self.config.num_decoder_attn_heads return StandardTransformerDecoderLayer( - self.create_attention(num_heads), - self.create_attention(num_heads), - self.create_ffn(), + self_attn=self.create_attention(num_heads), + self_attn_layer_norm=self.create_layer_norm(), + encoder_decoder_attn=self.create_attention(num_heads), + encoder_decoder_attn_layer_norm=self.create_layer_norm(), + ffn=self.create_ffn(), + ffn_layer_norm=self.create_layer_norm(), dropout_p=self.config.dropout_p, norm_order=self.config.decoder_norm_order, ) def create_attention(self, num_heads: int) -> MultiheadAttention: - sdpa = create_default_sdpa(attn_dropout_p=self.config.dropout_p) + sdpa = create_default_sdpa(bias=IdentityBias(), dropout_p=self.config.dropout_p) return StandardMultiheadAttention( self.config.model_dim, @@ -133,16 +139,16 @@ def create_ffn(self) -> FeedForwardNetwork: self.config.model_dim, self.config.ffn_inner_dim, bias=True, - norm_order=self.config.decoder_norm_order, ) + def create_layer_norm(self) -> LayerNorm: + model_dim = self.config.model_dim + return StandardLayerNorm(model_dim, bias=True) + def create_w2v2_final_layer_norm(self) -> Optional[LayerNorm]: if not self.config.w2v2_encoder_config.use_conformer: return None - - return create_standard_layer_norm( - self.config.w2v2_encoder_config.model_dim, - ) + return StandardLayerNorm(self.config.w2v2_encoder_config.model_dim, 
bias=True) def create_projection_out(self) -> Linear: return Linear( diff --git a/sonar/models/sonar_speech/handler.py b/sonar/models/sonar_speech/handler.py deleted file mode 100644 index 977f101..0000000 --- a/sonar/models/sonar_speech/handler.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Any, cast, final - -from fairseq2.models import AbstractModelHandler -from fairseq2.models.utils.checkpoint import convert_fairseq_checkpoint -from torch.nn import Module -from typing_extensions import override - -from sonar.models.sonar_speech.config import SonarSpeechEncoderConfig -from sonar.models.sonar_speech.factory import SonarSpeechEncoderFactory -from sonar.models.sonar_speech.model import SonarSpeechEncoderModel - - -@final -class SonarSpeechEncoderHandler(AbstractModelHandler): - @override - @property - def family(self) -> str: - return "sonar_speech" - - @override - @property - def kls(self) -> type[Module]: - return SonarSpeechEncoderModel - - @override - def _create_model(self, config: object) -> Module: - config = cast(SonarSpeechEncoderConfig, config) - - return SonarSpeechEncoderFactory(config).create_model() - - @override - def _convert_checkpoint( - self, checkpoint: dict[str, object], config: object - ) -> dict[str, object]: - config = cast(SonarSpeechEncoderConfig, config) - - return convert_sonar_speech_checkpoint(checkpoint, config) - - -def convert_sonar_speech_checkpoint( - checkpoint: dict[str, Any], config: SonarSpeechEncoderConfig -) -> dict[str, Any]: - state_dict = checkpoint["model"] - - # Check if we have a fairseq2 checkpoint. 
- if "encoder_frontend.model_dim_proj" in state_dict: - return checkpoint - - if "encoder.w2v_model.mask_emb" in state_dict: - del state_dict["encoder.w2v_model.mask_emb"] - - if "encoder.w2v_model.encoder.pos_conv.0.bias" in state_dict: - del state_dict["encoder.w2v_model.encoder.pos_conv.0.bias"] - del state_dict["encoder.w2v_model.encoder.pos_conv.0.weight_g"] - del state_dict["encoder.w2v_model.encoder.pos_conv.0.weight_v"] - - key_map = { - # fmt: off - # encoder - r"^encoder.w2v_model.layer_norm\.": r"encoder_frontend.post_extract_layer_norm.", - r"^encoder.w2v_model.post_extract_proj\.": r"encoder_frontend.model_dim_proj.", - r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.conv_module\.batch_norm\.": r"encoder.layers.\1.conv.batch_norm.", - r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.": r"encoder.layers.\1.conv.depthwise_conv.", - r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm\.": r"encoder.layers.\1.conv_layer_norm.", - r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": r"encoder.layers.\1.conv.pointwise_conv1.", - r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": r"encoder.layers.\1.conv.pointwise_conv2.", - r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.": r"encoder.layers.\1.ffn\2_layer_norm.", - r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.": r"encoder.layers.\1.ffn\2.inner_proj.", - r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.": r"encoder.layers.\1.ffn\2.output_proj.", - r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"encoder.layers.\1.self_attn_layer_norm.", - r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.self_attn\.linear_q\.": r"encoder.layers.\1.self_attn.q_proj.", - r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.self_attn\.linear_k\.": r"encoder.layers.\1.self_attn.k_proj.", - 
r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.self_attn\.linear_v\.": r"encoder.layers.\1.self_attn.v_proj.", - r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.self_attn\.linear_out\.": r"encoder.layers.\1.self_attn.output_proj.", - r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.self_attn\.linear_pos\.": r"encoder.layers.\1.self_attn.sdpa.r_proj.", - r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_u": r"encoder.layers.\1.self_attn.sdpa.u_bias", - r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_v": r"encoder.layers.\1.self_attn.sdpa.v_bias", - r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.final_layer_norm\.": r"encoder.layers.\1.layer_norm.", - r"^encoder.w2v_model.encoder\.layer_norm\.": r"encoder.layer_norm.", - - r"^decoder\.embed_tokens\.": r"encoder_pooler.decoder_frontend.embed.", - r"^decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"encoder_pooler.decoder.layers.\1.self_attn_layer_norm.", - r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"encoder_pooler.decoder.layers.\1.self_attn.output_proj.", - r"^decoder\.layers\.([0-9]+)\.self_attn\.": r"encoder_pooler.decoder.layers.\1.self_attn.", - r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"encoder_pooler.decoder.layers.\1.encoder_decoder_attn_layer_norm.", - r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.": r"encoder_pooler.decoder.layers.\1.encoder_decoder_attn.output_proj.", - r"^decoder\.layers\.([0-9]+)\.encoder_attn\.": r"encoder_pooler.decoder.layers.\1.encoder_decoder_attn.", - r"^decoder\.layers\.([0-9]+)\.fc1\.": r"encoder_pooler.decoder.layers.\1.ffn.inner_proj.", - r"^decoder\.layers\.([0-9]+)\.fc2\.": r"encoder_pooler.decoder.layers.\1.ffn.output_proj.", - r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.": r"encoder_pooler.decoder.layers.\1.ffn_layer_norm.", - - r"^decoder\.embed_out": r"encoder_pooler.projection_out.weight", - # fmt: on - } - - # In normal circumstances, we should never encounter a `LayerNorm` when - # 
`use_conformer` is `True`. Unfortunately, the w2v-BERT pretraining in - # fairseq was accidentally run with a pre-LN encoder, and ended up with - # a redundant `LayerNorm` right after the Conformer blocks. We mitigate - # that issue here by moving that `LayerNorm` to the sonar block. - if config.w2v2_encoder_config.use_conformer: - key_map.update({r"^encoder.w2v_model.encoder\.layer_norm\.": r"layer_norm."}) - - return convert_fairseq_checkpoint(checkpoint, key_map) diff --git a/sonar/models/sonar_speech/model.py b/sonar/models/sonar_speech/model.py index f3a4ea7..e22a5ee 100644 --- a/sonar/models/sonar_speech/model.py +++ b/sonar/models/sonar_speech/model.py @@ -4,15 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional, Tuple +from typing import Optional -from fairseq2.models.sequence import SequenceBatch -from fairseq2.models.transformer import TransformerFrontend -from fairseq2.nn import LayerNorm -from fairseq2.nn.padding import PaddingMask -from fairseq2.nn.transformer import TransformerEncoder +from fairseq2.models.transformer import TransformerEncoder, TransformerFrontend +from fairseq2.nn import BatchLayout, LayerNorm from torch import Tensor from torch.nn import Dropout +from typing_extensions import override from sonar.models.encoder_model import SonarEncoderModel, SonarEncoderOutput from sonar.nn.encoder_pooler import EncoderOutputPooler @@ -20,7 +18,7 @@ class SonarSpeechEncoderModel(SonarEncoderModel): """Represents a SONAR speech encoder model as described in - # TODO add correct paper cite :cite:t`URL`.""" + :cite:t`https://doi.org/10.48550/arXiv.2308.11466`.""" encoder_frontend: TransformerFrontend encoder: TransformerEncoder @@ -48,7 +46,7 @@ def __init__( :param encoder_pooler: Encoder output pooler. 
""" - super().__init__(encoder.model_dim) + super().__init__() self.encoder_frontend = encoder_frontend self.encoder = encoder @@ -56,9 +54,10 @@ def __init__( self.layer_norm = layer_norm self.encoder_pooler = encoder_pooler - def forward(self, batch: SequenceBatch) -> SonarEncoderOutput: - seqs, padding_mask = self.encoder_frontend(batch.seqs, batch.padding_mask) - encoder_output, encoder_padding_mask = self.encoder(seqs, padding_mask) + @override + def forward(self, seqs: Tensor, seqs_layout: BatchLayout) -> SonarEncoderOutput: + seqs, seqs_layout = self.encoder_frontend(seqs, seqs_layout) + encoder_output = self.encoder(seqs, seqs_layout) # This is the workaround for the pre-LN issue of redundant LayerNorm. # We call here, to avoid fiddling with wav2vec2's model and config. @@ -66,21 +65,10 @@ def forward(self, batch: SequenceBatch) -> SonarEncoderOutput: encoder_output = self.layer_norm(encoder_output) encoder_output = self.final_dropout(encoder_output) - encoder_output_pooled = self.encoder_pooler( - encoder_output, encoder_padding_mask - ) + encoder_output_pooled = self.encoder_pooler(encoder_output, seqs_layout) return SonarEncoderOutput( encoded_seqs=encoder_output, sentence_embeddings=encoder_output_pooled, - padding_mask=padding_mask, - ) - - def encode( - self, seqs: Tensor, padding_mask: Optional[PaddingMask] - ) -> Tuple[Tensor, Optional[Tensor]]: - sonar_output_encoder = self.encoder(seqs, padding_mask) - return ( - sonar_output_encoder.sentence_embeddings.unsqueeze(1), - None, + encoded_seqs_layout=seqs_layout, ) diff --git a/sonar/models/sonar_text/__init__.py b/sonar/models/sonar_text/__init__.py index 5223983..c402af5 100644 --- a/sonar/models/sonar_text/__init__.py +++ b/sonar/models/sonar_text/__init__.py @@ -4,6 +4,20 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from fairseq2.models import ModelHubAccessor + +from sonar.models.sonar_text.checkpoint import ( + _convert_sonar_text_decoder_checkpoint as _convert_sonar_text_decoder_checkpoint, +) +from sonar.models.sonar_text.checkpoint import ( + _convert_sonar_text_encoder_checkpoint as _convert_sonar_text_encoder_checkpoint, +) +from sonar.models.sonar_text.config import ( + SONAR_TEXT_DECODER_FAMILY as SONAR_TEXT_DECODER_FAMILY, +) +from sonar.models.sonar_text.config import ( + SONAR_TEXT_ENCODER_FAMILY as SONAR_TEXT_ENCODER_FAMILY, +) from sonar.models.sonar_text.config import ( SonarTextDecoderConfig as SonarTextDecoderConfig, ) @@ -11,10 +25,10 @@ SonarTextEncoderConfig as SonarTextEncoderConfig, ) from sonar.models.sonar_text.config import ( - register_sonar_text_decoder_configs as register_sonar_text_decoder_configs, + _register_sonar_text_decoder_configs as _register_sonar_text_decoder_configs, ) from sonar.models.sonar_text.config import ( - register_sonar_text_encoder_configs as register_sonar_text_encoder_configs, + _register_sonar_text_encoder_configs as _register_sonar_text_encoder_configs, ) from sonar.models.sonar_text.factory import ( SonarTextDecoderFactory as SonarTextDecoderFactory, @@ -22,27 +36,23 @@ from sonar.models.sonar_text.factory import ( SonarTextEncoderFactory as SonarTextEncoderFactory, ) -from sonar.models.sonar_text.handler import ( - SonarTextDecoderHandler as SonarTextDecoderHandler, +from sonar.models.sonar_text.factory import ( + _create_sonar_text_decoder_model as _create_sonar_text_decoder_model, ) -from sonar.models.sonar_text.handler import ( - SonarTextEncoderHandler as SonarTextEncoderHandler, +from sonar.models.sonar_text.factory import ( + _create_sonar_text_encoder_model as _create_sonar_text_encoder_model, ) from sonar.models.sonar_text.model import ( SonarTextTransformerEncoderModel as SonarTextTransformerEncoderModel, ) - -# isort: split - -from fairseq2.models import ModelHubAccessor - from 
sonar.nn.conditional_decoder_model import ConditionalTransformerDecoderModel get_sonar_text_encoder_hub = ModelHubAccessor( - SonarTextTransformerEncoderModel, SonarTextEncoderConfig + SONAR_TEXT_ENCODER_FAMILY, SonarTextTransformerEncoderModel, SonarTextEncoderConfig ) - get_sonar_text_decoder_hub = ModelHubAccessor( - ConditionalTransformerDecoderModel, SonarTextDecoderConfig + SONAR_TEXT_DECODER_FAMILY, + ConditionalTransformerDecoderModel, + SonarTextDecoderConfig, ) diff --git a/sonar/models/sonar_text/handler.py b/sonar/models/sonar_text/checkpoint.py similarity index 59% rename from sonar/models/sonar_text/handler.py rename to sonar/models/sonar_text/checkpoint.py index 69a81f5..f4ad1aa 100644 --- a/sonar/models/sonar_text/handler.py +++ b/sonar/models/sonar_text/checkpoint.py @@ -4,60 +4,31 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, cast, final +from typing import Any, Dict, cast import torch -from fairseq2.models import AbstractModelHandler -from fairseq2.models.utils.checkpoint import convert_fairseq_checkpoint -from torch.nn import Module -from typing_extensions import override +from fairseq2.models.utils.checkpoint import convert_fairseq_state_dict from sonar.models.sonar_text.config import ( SonarTextDecoderConfig, SonarTextEncoderConfig, ) -from sonar.models.sonar_text.factory import ( - SonarTextDecoderFactory, - SonarTextEncoderFactory, -) -from sonar.models.sonar_text.model import SonarTextTransformerEncoderModel -from sonar.nn.conditional_decoder_model import ConditionalTransformerDecoderModel - - -@final -class SonarTextEncoderHandler(AbstractModelHandler): - @override - @property - def family(self) -> str: - return "transformer_encoder" - - @override - @property - def kls(self) -> type[Module]: - return SonarTextTransformerEncoderModel - @override - def _create_model(self, config: object) -> Module: - config = 
cast(SonarTextEncoderConfig, config) - - return SonarTextEncoderFactory(config).create_model() - - @override - def _convert_checkpoint( - self, checkpoint: dict[str, object], config: object - ) -> dict[str, object]: - return convert_sonar_text_encoder_checkpoint(checkpoint) +def _convert_sonar_text_encoder_checkpoint( + state_dict: Dict[str, Any], config: SonarTextEncoderConfig +) -> Dict[str, Any]: + # fairseq2 does not use a top-level "model" keyword anymore (v0.5+) + try: + state_dict = cast(dict[str, object], state_dict["model"]) + except KeyError: + pass -def convert_sonar_text_encoder_checkpoint(checkpoint: dict[str, Any]) -> dict[str, Any]: # Return directly if found fairseq2 attribute in state dict - if ( - "model" in checkpoint.keys() - and "encoder_frontend.embed.weight" in checkpoint["model"].keys() - ): - return checkpoint + if "encoder_frontend.embed.weight" in state_dict.keys(): + return state_dict - state_dict = checkpoint["state_dict"] + state_dict = state_dict["state_dict"] try: del state_dict["version"] @@ -81,9 +52,9 @@ def convert_sonar_text_encoder_checkpoint(checkpoint: dict[str, Any]) -> dict[st # fmt: on } - out_checkpoint = convert_fairseq_checkpoint(out_checkpoint, key_map) + out_checkpoint = convert_fairseq_state_dict(out_checkpoint, key_map) # type: ignore - embeds = checkpoint["embed_tokens"].weight + embeds = state_dict["embed_tokens"].weight # # The embedding positions of the control tokens do not match the # # SentencePiece model of the tokenizer. 
with torch.inference_mode(): @@ -94,47 +65,28 @@ def convert_sonar_text_encoder_checkpoint(checkpoint: dict[str, Any]) -> dict[st return out_checkpoint -@final -class SonarTextDecoderHandler(AbstractModelHandler): - @override - @property - def family(self) -> str: - return "transformer_decoder" - - @override - @property - def kls(self) -> type[Module]: - return ConditionalTransformerDecoderModel - - @override - def _create_model(self, config: object) -> Module: - config = cast(SonarTextDecoderConfig, config) - - return SonarTextDecoderFactory(config).create_model() - - @override - def _convert_checkpoint( - self, checkpoint: dict[str, object], config: object - ) -> dict[str, object]: - return convert_sonar_text_decoder_checkpoint(checkpoint) - +def _convert_sonar_text_decoder_checkpoint( + state_dict: dict[str, Any], config: SonarTextDecoderConfig +) -> dict[str, Any]: + # fairseq2 does not use a top-level "model" keyword anymore (v0.5+) + try: + state_dict = cast(dict[str, object], state_dict["model"]) + except KeyError: + pass -def convert_sonar_text_decoder_checkpoint(checkpoint: dict[str, Any]) -> dict[str, Any]: # Return directly if found fairseq2 attribute in state dict - if ( - "model" in checkpoint.keys() - and "decoder_frontend.embed.weight" in checkpoint["model"].keys() - ): - return checkpoint + if "decoder_frontend.embed.weight" in state_dict.keys(): + return state_dict - state_dict = checkpoint["state_dict"] + # assuming pre fs2:v0.5 formatting with top-level "model" key + state_dict = state_dict["state_dict"] try: del state_dict["version"] del state_dict["embed_positions._float_tensor"] except: pass - out_checkpoint = {"model": state_dict} + out_checkpoint = state_dict key_map = { r"layers\.([0-9]+)\.self_attn\.k_proj\.": r"decoder.layers.\1.self_attn.k_proj.", @@ -158,15 +110,15 @@ def convert_sonar_text_decoder_checkpoint(checkpoint: dict[str, Any]) -> dict[st r"layer_norm.": r"decoder.layer_norm.", } - out_checkpoint = 
convert_fairseq_checkpoint(out_checkpoint, key_map) + out_checkpoint = convert_fairseq_state_dict(out_checkpoint, key_map) out_checkpoint = cast(dict[str, Any], out_checkpoint) - embeds = out_checkpoint["model"]["decoder_frontend.embed.weight"] + embeds = out_checkpoint["decoder_frontend.embed.weight"] # # The embedding positions of the control tokens do not match the # # SentencePiece model of the tokenizer. with torch.inference_mode(): # (BOS, PAD, EOS, UNK) -> (PAD, UNK, BOS, EOS) embeds[[0, 1, 2, 3]] = embeds[[1, 3, 0, 2]] - out_checkpoint["model"]["decoder_frontend.embed.weight"] = embeds + out_checkpoint["decoder_frontend.embed.weight"] = embeds return out_checkpoint diff --git a/sonar/models/sonar_text/config.py b/sonar/models/sonar_text/config.py index 6a49f4a..aafdf45 100644 --- a/sonar/models/sonar_text/config.py +++ b/sonar/models/sonar_text/config.py @@ -5,10 +5,13 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass -from typing import Optional +from typing import Final, Optional -from fairseq2.context import RuntimeContext -from fairseq2.data import VocabularyInfo +from fairseq2.data.tokenizers import VocabularyInfo +from fairseq2.runtime.config_registry import ConfigRegistrar +from fairseq2.runtime.dependency import DependencyContainer + +SONAR_TEXT_ENCODER_FAMILY: Final = "transformer_encoder" @dataclass @@ -56,9 +59,6 @@ class SonarTextEncoderConfig: activation_fn: str = "ReLU" """ activation function to use in FeedForward network of Transformers; None corresponds to ReLu""" - layernorm_embedding: bool = False - """ If True, apply LayerNorm on sequence embeddings""" - no_scale_embedding: bool = False """if False, multiply sequence embeddings by sqrt(model_dim) before positional encoding""" @@ -84,10 +84,8 @@ class SonarTextEncoderConfig: """if True, do max_seq_len += pad_idx + 1 for retro-compatibgiility with fairseq trained models""" -def register_sonar_text_encoder_configs(context: RuntimeContext) -> 
None: - registry = context.get_config_registry(SonarTextEncoderConfig) - - arch = registry.decorator +def _register_sonar_text_encoder_configs(container: DependencyContainer) -> None: + arch = ConfigRegistrar(container, SonarTextEncoderConfig) @arch("basic") def basic() -> SonarTextEncoderConfig: @@ -104,7 +102,6 @@ def basic() -> SonarTextEncoderConfig: max_seq_len=512, pooling="mean", no_token_positional_embeddings=False, - layernorm_embedding=False, activation_fn="ReLU", normalize_before=False, num_encoder_layers=24, @@ -127,6 +124,9 @@ def small(vocab_size=32005, depth=6, hidden_dim=1024 * 4) -> SonarTextEncoderCon return config +SONAR_TEXT_DECODER_FAMILY: Final = "transformer_decoder" + + @dataclass class SonarTextDecoderConfig: """Holds the configuration of an SonarDecoder model.""" @@ -145,9 +145,6 @@ class SonarTextDecoderConfig: activation_fn: str """ activation function to use in FeedForward network of Transformers; None corresponds to ReLu""" - layernorm_embedding: bool - """ If True, apply LayerNorm on sequence embeddings""" - no_scale_embedding: bool """if False, multiply sequence embeddings by sqrt(model_dim) before positional encoding""" @@ -189,10 +186,8 @@ class SonarTextDecoderConfig: """The dimensionality of the input. 
If None, model_dim is used instead.""" -def register_sonar_text_decoder_configs(context: RuntimeContext) -> None: - registry = context.get_config_registry(SonarTextDecoderConfig) - - arch = registry.decorator +def _register_sonar_text_decoder_configs(container: DependencyContainer) -> None: + arch = ConfigRegistrar(container, SonarTextDecoderConfig) @arch("basic") def basic() -> SonarTextDecoderConfig: @@ -208,7 +203,6 @@ def basic() -> SonarTextDecoderConfig: attention_dropout_p=0.1, activation_dropout_p=0.1, no_token_positional_embeddings=False, - layernorm_embedding=False, activation_fn="ReLU", normalize_before=True, num_encoder_layers=24, @@ -244,7 +238,6 @@ def toy() -> SonarTextDecoderConfig: attention_dropout_p=0.1, activation_dropout_p=0.1, no_token_positional_embeddings=False, - layernorm_embedding=False, activation_fn="ReLU", normalize_before=True, num_encoder_layers=2, diff --git a/sonar/models/sonar_text/factory.py b/sonar/models/sonar_text/factory.py index 400f028..536c73d 100644 --- a/sonar/models/sonar_text/factory.py +++ b/sonar/models/sonar_text/factory.py @@ -8,21 +8,10 @@ import torch.nn from fairseq2.models.transformer import ( - TransformerEmbeddingFrontend, - TransformerFrontend, -) -from fairseq2.nn import ( - LearnedPositionEncoder, - Linear, - PositionEncoder, - SinusoidalPositionEncoder, - StandardEmbedding, - StandardLayerNorm, - TiedProjection, - init_scaled_embedding, -) -from fairseq2.nn.transformer import ( + AttentionBias, + CausalAttentionBias, FeedForwardNetwork, + IdentityBias, MultiheadAttention, StandardFeedForwardNetwork, StandardMultiheadAttention, @@ -32,10 +21,23 @@ StandardTransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer, + TransformerEmbeddingFrontend, TransformerEncoderLayer, + TransformerFrontend, TransformerNormOrder, create_default_sdpa, ) +from fairseq2.nn import ( + LayerNorm, + LearnedPositionEncoder, + Linear, + PositionEncoder, + SinusoidalPositionEncoder, + StandardEmbedding, + 
StandardLayerNorm, + TiedProjection, + init_scaled_embedding, +) from torch.nn import Parameter from sonar.models.sonar_text.config import ( @@ -47,6 +49,12 @@ from sonar.nn.encoder_pooler import AttentionEncoderOutputPooler, EncoderOutputPooler +def _create_sonar_text_encoder_model( + config: SonarTextEncoderConfig, +) -> SonarTextTransformerEncoderModel: + return SonarTextEncoderFactory(config).create_model() + + class SonarTextEncoderFactory: config: SonarTextEncoderConfig @@ -72,7 +80,7 @@ def embedding_dim(self) -> int: def create_model(self) -> SonarTextTransformerEncoderModel: embed = StandardEmbedding( num_embeddings=self.config.vocab_info.size, - embedding_dim=self.config.model_dim, + embed_dim=self.config.model_dim, pad_idx=self.config.vocab_info.pad_idx, init_fn=init_scaled_embedding, ) @@ -92,19 +100,17 @@ def create_model(self) -> SonarTextTransformerEncoderModel: ) embedding_frontend = TransformerEmbeddingFrontend( - embed, - pos_encoder, + model_dim=self.config.model_dim, + embed=embed, + pos_encoder=pos_encoder, no_scale=self.config.no_scale_embedding, - layer_norm=self.config.layernorm_embedding, dropout_p=self.config.emb_dropout_p, ) transformer_layers = [ self.create_encoder_layer() for _ in range(self.config.num_encoder_layers) ] - encoder = StandardTransformerEncoder( - transformer_layers, norm_order=self.transformer_normalize_order - ) + encoder = StandardTransformerEncoder(transformer_layers) pooling = getattr(Pooling, self.config.pooling.upper()) if pooling == Pooling.ATTENTION: pooler = self.create_attention_pooler() @@ -114,15 +120,18 @@ def create_model(self) -> SonarTextTransformerEncoderModel: return SonarTextTransformerEncoderModel( encoder_frontend=embedding_frontend, encoder=encoder, - layer_norm=StandardLayerNorm(self.config.model_dim, bias=True), + layer_norm=self.create_layer_norm(dim=self.config.model_dim), pooling=pooling, pooler=pooler, + max_source_seq_len=self.config.max_seq_len, ) def create_encoder_layer(self) -> 
TransformerEncoderLayer: return StandardTransformerEncoderLayer( self_attn=self.create_attention(), + self_attn_layer_norm=self.create_layer_norm(dim=self.config.model_dim), ffn=self.create_ffn(), + ffn_layer_norm=self.create_layer_norm(dim=self.config.model_dim), dropout_p=self.config.attention_dropout_p, norm_order=TransformerNormOrder.PRE, ) @@ -137,7 +146,9 @@ def create_attention( model_dim=model_dim or self.config.model_dim, kv_dim=kv_dim or self.config.model_dim, num_heads=num_heads or self.config.num_encoder_attn_heads, - sdpa=create_default_sdpa(attn_dropout_p=self.config.attention_dropout_p), + sdpa=create_default_sdpa( + bias=IdentityBias(), dropout_p=self.config.attention_dropout_p + ), ) def create_ffn( @@ -149,7 +160,6 @@ def create_ffn( bias=True, inner_activation=getattr(torch.nn, self.config.activation_fn)(), inner_dropout_p=self.config.activation_dropout_p, - norm_order=self.transformer_normalize_order, ) def create_attention_pooler(self) -> EncoderOutputPooler: @@ -160,12 +170,15 @@ def create_attention_pooler(self) -> EncoderOutputPooler: bos_idx=0, ) + def create_layer_norm(self, dim: int) -> LayerNorm: + return StandardLayerNorm(dim, bias=True) + # This method, and all methods below, refer only to the attention pooler building. 
# The "decoder" is used for pooling the encoder representations in a smarter way def create_decoder_frontend(self) -> TransformerFrontend: - embedding = StandardEmbedding( + embed = StandardEmbedding( num_embeddings=1, - embedding_dim=self.embedding_dim, + embed_dim=self.embedding_dim, pad_idx=0, init_fn=init_scaled_embedding, ) @@ -174,18 +187,26 @@ def create_decoder_frontend(self) -> TransformerFrontend: max_seq_len=1, ) return TransformerEmbeddingFrontend( - embed=embedding, + model_dim=self.config.model_dim, + embed=embed, pos_encoder=pos_encoder, dropout_p=self.config.emb_dropout_p, ) def create_decoder(self) -> TransformerDecoder: - num_layers = self.config.num_decoder_layers - layers = [self.create_decoder_layer() for _ in range(num_layers)] + if self.transformer_normalize_order == TransformerNormOrder.PRE: + layer_norm = self.create_layer_norm( + dim=self.config.embedding_dim or self.config.model_dim + ) + else: + layer_norm = None return StandardTransformerDecoder( - layers, - norm_order=self.transformer_normalize_order, + layers=[ + self.create_decoder_layer() + for _ in range(self.config.num_decoder_layers) + ], + layer_norm=layer_norm, ) def create_decoder_layer(self) -> TransformerDecoderLayer: @@ -198,15 +219,24 @@ def create_decoder_layer(self) -> TransformerDecoderLayer: model_dim=self.embedding_dim, kv_dim=self.embedding_dim, ), + self_attn_layer_norm=self.create_layer_norm( + dim=self.config.embedding_dim or self.config.model_dim + ), encoder_decoder_attn=self.create_attention( num_heads=num_heads, model_dim=self.embedding_dim, kv_dim=self.config.model_dim, ), + encoder_decoder_attn_layer_norm=self.create_layer_norm( + dim=self.config.embedding_dim or self.config.model_dim + ), ffn=self.create_ffn( model_dim=self.embedding_dim, inner_dim=self.config.decoder_ffn_inner_dim, ), + ffn_layer_norm=self.create_layer_norm( + dim=self.config.embedding_dim or self.config.model_dim + ), dropout_p=self.config.attention_dropout_p, 
norm_order=self.transformer_normalize_order, ) @@ -226,6 +256,12 @@ def create_projection_in(self) -> Linear: ) +def _create_sonar_text_decoder_model( + config: SonarTextDecoderConfig, +) -> ConditionalTransformerDecoderModel: + return SonarTextDecoderFactory(config).create_model() + + class SonarTextDecoderFactory: config: SonarTextDecoderConfig @@ -241,7 +277,7 @@ def __init__(self, config: SonarTextDecoderConfig) -> None: def create_decoder_frontend(self) -> TransformerFrontend: embed = StandardEmbedding( num_embeddings=self.config.vocab_info.size, - embedding_dim=self.config.model_dim, + embed_dim=self.config.model_dim, pad_idx=self.config.vocab_info.pad_idx, init_fn=init_scaled_embedding, ) @@ -251,34 +287,48 @@ def create_decoder_frontend(self) -> TransformerFrontend: _legacy_pad_idx=self.config.vocab_info.pad_idx, ) return TransformerEmbeddingFrontend( - embed, - pos_encoder, + model_dim=self.config.model_dim, + embed=embed, + pos_encoder=pos_encoder, no_scale=self.config.no_scale_embedding, - layer_norm=self.config.layernorm_embedding, dropout_p=self.config.emb_dropout_p, ) + def create_layer_norm(self, dim: int) -> LayerNorm: + return StandardLayerNorm(dim, bias=True) + def create_decoder_layer(self) -> TransformerDecoderLayer: - self_attn = self.create_attention(kv_dim=self.config.model_dim) + self_attn = self.create_attention( + bias=CausalAttentionBias(), kv_dim=self.config.model_dim + ) - encoder_decoder_attn = self.create_attention(kv_dim=self.config.input_dim) + encoder_decoder_attn = self.create_attention( + bias=IdentityBias(), kv_dim=self.config.input_dim + ) ffn = self.create_ffn() return StandardTransformerDecoderLayer( - self_attn, - encoder_decoder_attn, - ffn, + self_attn=self_attn, + self_attn_layer_norm=self.create_layer_norm(dim=self.config.model_dim), + encoder_decoder_attn=encoder_decoder_attn, + encoder_decoder_attn_layer_norm=self.create_layer_norm( + dim=self.config.model_dim + ), + ffn=ffn, + 
ffn_layer_norm=self.create_layer_norm(dim=self.config.model_dim), dropout_p=self.config.attention_dropout_p, norm_order=TransformerNormOrder.PRE, ) - def create_attention(self, kv_dim=None) -> MultiheadAttention: + def create_attention(self, bias: AttentionBias, kv_dim=None) -> MultiheadAttention: return StandardMultiheadAttention( self.config.model_dim, self.config.num_encoder_attn_heads, kv_dim=kv_dim or self.config.model_dim, - sdpa=create_default_sdpa(attn_dropout_p=self.config.attention_dropout_p), + sdpa=create_default_sdpa( + bias=bias, dropout_p=self.config.attention_dropout_p + ), ) def create_ffn(self) -> FeedForwardNetwork: @@ -288,16 +338,17 @@ def create_ffn(self) -> FeedForwardNetwork: bias=True, inner_activation=getattr(torch.nn, self.config.activation_fn)(), inner_dropout_p=self.config.activation_dropout_p, - norm_order=self.transformer_normalize_order, ) def create_decoder(self) -> TransformerDecoder: return StandardTransformerDecoder( - [ + layers=[ self.create_decoder_layer() for _ in range(self.config.num_decoder_layers) ], - norm_order=TransformerNormOrder.PRE, + layer_norm=self.create_layer_norm( + dim=self.config.model_dim + ), # equivalent to TransformerNormOrder.PRE for ConditionalTransformerDecoderModel: @@ -307,9 +358,8 @@ def create_model(self) -> ConditionalTransformerDecoderModel: final_proj = TiedProjection(weight=param, bias=None) return ConditionalTransformerDecoderModel( - decoder_frontend, - decoder, - final_proj, - self.config.max_seq_len, - self.config.vocab_info, + decoder_frontend=decoder_frontend, + decoder=decoder, + final_proj=final_proj, + max_target_seq_len=self.config.max_seq_len, ) diff --git a/sonar/models/sonar_text/model.py b/sonar/models/sonar_text/model.py index c875f24..dd3e7d4 100644 --- a/sonar/models/sonar_text/model.py +++ b/sonar/models/sonar_text/model.py @@ -8,13 +8,11 @@ from typing import Optional, final import torch -from fairseq2.models.sequence import SequenceBatch -from fairseq2.models.transformer 
import TransformerFrontend +from fairseq2.models.transformer import TransformerEncoder, TransformerFrontend from fairseq2.nn import LayerNorm -from fairseq2.nn.padding import PaddingMask, apply_padding_mask -from fairseq2.nn.transformer import TransformerEncoder -from fairseq2.typing import override +from fairseq2.nn.batch_layout import BatchLayout from torch import Tensor +from typing_extensions import override from sonar.models.encoder_model import SonarEncoderModel, SonarEncoderOutput from sonar.nn.encoder_pooler import EncoderOutputPooler @@ -36,6 +34,7 @@ def __init__( self, encoder_frontend: TransformerFrontend, encoder: TransformerEncoder, + max_source_seq_len: int, layer_norm: Optional[LayerNorm] = None, pooling: Pooling = Pooling.LAST, pooler: Optional[EncoderOutputPooler] = None, @@ -45,21 +44,23 @@ def __init__( The encoder frontend. :param encoder: The encoder. + :param max_source_seq_len: + The maximum sequence length the encoder can ingest. :param layer_norm: optional LayerNorm that is applied on encoder output """ - super().__init__(encoder.model_dim) - if encoder_frontend.model_dim != encoder.model_dim: - raise ValueError( - f"`model_dim` of `encoder_frontend` and `model_dim` of `encoder` must be equal, but are {encoder_frontend.model_dim} and {encoder.model_dim} instead." - ) - if ( - layer_norm is not None - and layer_norm.normalized_shape[0] != encoder.model_dim - ): - raise ValueError( - f"`model_dim` of `encoder` and `normalized_shape` of `layer_norm` must be equal, but are {encoder_frontend.model_dim} and {layer_norm.normalized_shape} instead." - ) + super().__init__() + # if encoder_frontend.model_dim != encoder.model_dim: + # raise ValueError( + # f"`model_dim` of `encoder_frontend` and `model_dim` of `encoder` must be equal, but are {encoder_frontend.model_dim} and {encoder.model_dim} instead." 
+ # ) + # if ( + # layer_norm is not None + # and layer_norm.normalized_shape[0] != encoder.model_dim + # ): + # raise ValueError( + # f"`model_dim` of `encoder` and `normalized_shape` of `layer_norm` must be equal, but are {encoder_frontend.model_dim} and {layer_norm.normalized_shape} instead." + # ) self.encoder_frontend = encoder_frontend self.encoder = encoder self.layer_norm = layer_norm @@ -67,7 +68,7 @@ def __init__( self.pooler = pooler def pool( - self, seqs: Tensor, padding_mask: Optional[PaddingMask], pooling: Pooling + self, seqs: Tensor, seqs_layout: BatchLayout | None, pooling: Pooling ) -> Tensor: """Apply determininstic or trainable pooling""" if pooling == Pooling.ATTENTION: @@ -75,17 +76,17 @@ def pool( self.pooler is not None ), "Cannot use trainable pooling without a pooler in the model" sentence_embedding = self.pooler( - encoder_output=seqs, encoder_padding_mask=padding_mask + encoder_output=seqs, encoder_output_layout=seqs_layout ) else: sentence_embedding = self.static_pooling( - seqs=seqs, padding_mask=padding_mask, pooling=pooling + seqs=seqs, seqs_layout=seqs_layout, pooling=pooling ) return sentence_embedding @staticmethod def static_pooling( - seqs: Tensor, padding_mask: Optional[PaddingMask], pooling: Pooling + seqs: Tensor, seqs_layout: BatchLayout | None, pooling: Pooling ) -> Tensor: """Deterministic pooling along sequence dimension to get a sentence representation. In the future, some SONAR text encoders may have a trainable pooler instead. 
@@ -97,27 +98,32 @@ def static_pooling( Returns: Tensor: bs x model_dim """ + if pooling == Pooling.LAST: - if padding_mask is None: + if seqs_layout is None or (seqs_layout and not seqs_layout.padded): sentence_embedding = seqs[:, -1] else: - seq_lens = padding_mask.seq_lens + seq_lens = seqs_layout.seq_lens_pt sentence_embedding = seqs[ [torch.arange(seq_lens.shape[0]), (seq_lens - 1).clip_(0)] ] elif pooling == Pooling.MAX: - seqs = apply_padding_mask(seqs, padding_mask, pad_value=-torch.inf) + seqs = SonarTextTransformerEncoderModel.replace_padded_values( + seqs, seqs_layout, pad_value=-torch.inf + ) sentence_embedding = seqs.max(dim=1).values elif pooling == Pooling.MEAN: - seqs = apply_padding_mask(seqs, padding_mask, pad_value=0) + seqs = SonarTextTransformerEncoderModel.replace_padded_values( + seqs, seqs_layout, pad_value=0 + ) sentence_embedding = seqs.sum(dim=1) - if padding_mask is None: + if seqs_layout is None or not seqs_layout.padded: weights = 1.0 / (seqs.size(1) + 1e-7) sentence_embedding = sentence_embedding * weights else: weights = 1.0 / ( - padding_mask.seq_lens.to(sentence_embedding.dtype) + 1e-7 + seqs_layout.seq_lens_pt.to(sentence_embedding.dtype) + 1e-7 ) sentence_embedding = torch.einsum( "i...,i->i...", sentence_embedding, weights @@ -127,17 +133,50 @@ def static_pooling( return sentence_embedding + @staticmethod + def replace_padded_values( + seqs: Tensor, + seqs_layout: BatchLayout | None, + pad_value: int | float | Tensor = 0, + ) -> Tensor: + """Replace the padded values in ``seqs`` with `pad_value`. + + :param seqs: + The sequences to mask. *Shape:* :math:`(N,S,*)` or :math:`(B,*)` for packed, + where :math:`N` is the batch size, :math:`S` is the sequence length, and + :math:`*` is any number of sequence-specific dimensions including none. + :param seqs_layout: + The batch layout to apply. If None or not padded, returns seqs unchanged. + :param pad_value: + The value for padded positions. 
+ + :returns: + The input sequences with mask applied. *Shape:* Same as ``seqs``. + """ + if seqs_layout is None or not seqs_layout.padded: + return seqs + + # True for valid positions, False for padding + mask = seqs_layout.position_indices >= 0 + + # Handle broadcasting for higher-dimensional tensors + for _ in range(seqs.ndim - mask.ndim): + mask = mask.unsqueeze(-1) + + return seqs.where(mask, pad_value) + @override - def forward(self, batch: SequenceBatch) -> SonarEncoderOutput: - embed_seqs, padding_mask = self.encoder_frontend(batch.seqs, batch.padding_mask) + def forward(self, seqs: Tensor, seqs_layout: BatchLayout) -> SonarEncoderOutput: + embed_seqs, embed_seqs_layout = self.encoder_frontend(seqs, seqs_layout) - encoded_seqs, _ = self.encoder(embed_seqs, padding_mask) + encoded_seqs = self.encoder(embed_seqs, embed_seqs_layout) + # encoded_seqs_layout = BatchLayout.of(encoded_seqs) if self.layer_norm is not None: encoded_seqs = self.layer_norm(encoded_seqs) - sentence_embeddings = self.pool(encoded_seqs, padding_mask, self.pooling) + sentence_embeddings = self.pool(encoded_seqs, embed_seqs_layout, self.pooling) return SonarEncoderOutput( encoded_seqs=encoded_seqs, sentence_embeddings=sentence_embeddings, - padding_mask=padding_mask, + encoded_seqs_layout=embed_seqs_layout, ) diff --git a/sonar/models/sonar_translation/__init__.py b/sonar/models/sonar_translation/__init__.py index b19bf20..61999c9 100644 --- a/sonar/models/sonar_translation/__init__.py +++ b/sonar/models/sonar_translation/__init__.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+ from sonar.models.sonar_translation.factory import ( create_sonar_speech_to_text_model as create_sonar_speech_to_text_model, ) diff --git a/sonar/models/sonar_translation/factory.py b/sonar/models/sonar_translation/factory.py index 697402d..f03b90e 100644 --- a/sonar/models/sonar_translation/factory.py +++ b/sonar/models/sonar_translation/factory.py @@ -6,7 +6,8 @@ from typing import Optional -from fairseq2.typing import DataType, Device +from fairseq2.data_type import DataType +from fairseq2.device import Device from sonar.models.sonar_speech import ( SonarSpeechEncoderConfig, diff --git a/sonar/models/sonar_translation/model.py b/sonar/models/sonar_translation/model.py index a9ebb92..4196549 100644 --- a/sonar/models/sonar_translation/model.py +++ b/sonar/models/sonar_translation/model.py @@ -4,11 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional, Tuple, final +from typing import final -from fairseq2.models.encoder_decoder import EncoderDecoderModel -from fairseq2.models.sequence import SequenceBatch, SequenceModelOutput -from fairseq2.nn.padding import PaddingMask +from fairseq2.models.seq2seq import Seq2SeqModel +from fairseq2.models.transformer.model import _TransformerModelState +from fairseq2.nn import BatchLayout, IncrementalStateBag from torch import Tensor from sonar.models.encoder_model import SonarEncoderModel, SonarEncoderOutput @@ -16,11 +16,20 @@ @final -class SonarEncoderDecoderModel(EncoderDecoderModel): - """Sonar translation model. +class SonarEncoderDecoderModel(Seq2SeqModel): + """Sonar translation model supporting two distinct usage patterns: - This is a generic model that can be used for speech any combination of speech,text - translation by combining Speech/Text Encoder/Decoder components. + 1. Sequence(Speech/Text)-to-Text (S2T): Real encoder transforms token sequences to sentence embeddings + 2. 
Embedding-to-Text (E2T): DummyEncoder passes pre-computed embeddings through + + Both patterns must produce identical encoder outputs for consistent decoder behavior. + The encode() method ensures compatibility by reshaping 2D sentence embeddings + [batch, embed_dim] to 3D [batch, 1, embed_dim] to match decoder expectations. + + Note (cirquit): This class is subclass of Seq2Seq but does not completely fit its mold, but is the closest we have to the previously + deprecated EncoderDecoderModel implementation ( None: - super().__init__( - encoder.model_dim, decoder.max_target_seq_len, decoder.target_vocab_info - ) - if encoder.model_dim != decoder.model_dim: - raise ValueError( - f"`model_dim` of `encoder` and `model_dim` of `decoder` must be equal, but are {encoder.model_dim} and {decoder.model_dim} instead." - ) + super().__init__(0, decoder.max_target_seq_len) # see note self.encoder = encoder self.decoder = decoder @@ -45,51 +48,93 @@ def __init__( def dtype(self): return next(self.parameters()).dtype - def encode( - self, seqs: Tensor, padding_mask: Optional[PaddingMask] - ) -> Tuple[Tensor, Optional[PaddingMask]]: - batch = SequenceBatch(seqs, padding_mask) - sonar_output_encoder = self.encoder(batch) - return (sonar_output_encoder.sentence_embeddings.unsqueeze(1), None) + def encode(self, seqs: Tensor, seqs_layout: BatchLayout) -> Tensor: + """Convert input sequences to decoder-ready embeddings. + + Transforms variable inputs (tokens or embeddings) to standardized output format: + - Input: [batch, seq_len] tokens or [batch, embed_dim] embeddings + - Output: [batch, 1, embed_dim] embeddings with sequence dimension + + The unsqueeze(1) operation creates a sequence dimension that the decoder + interprets as a single timestep, ensuring consistent behavior across pipelines. 
+ """ + encoder_output = self.encoder(seqs, seqs_layout) + return encoder_output.sentence_embeddings.unsqueeze(1) def decode( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], + seqs_layout: BatchLayout, encoder_output: Tensor, - encoder_padding_mask: Optional[PaddingMask], + encoder_output_layout: BatchLayout, state_bag=None, - ) -> Tuple[Tensor, Optional[PaddingMask]]: - seqs, padding_mask = self.decoder.decoder_frontend( - seqs, padding_mask, state_bag=state_bag + ) -> Tensor: + return self.decoder.decode( # type: ignore[no-any-return] + seqs, + seqs_layout, + encoder_output, + encoder_output_layout, + state_bag=state_bag, ) - return self.decoder.decoder( # type: ignore[no-any-return] - seqs, - padding_mask, + def project(self, decoder_output: Tensor) -> Tensor: + return self.decoder.project(decoder_output) + + # TODO: figure out how typing should work with overload + def forward( # type: ignore + self, + source_seqs: Tensor, + source_seqs_layout: BatchLayout, + target_seqs: Tensor, + target_seqs_layout: BatchLayout, + *, + state_bag: IncrementalStateBag | None = None, + ) -> Tensor: + # Incremental decoding needs to be handled on a model-level since fs2:v0.5 + if not self.training and state_bag is not None: + state = state_bag.maybe_get_state(self, _TransformerModelState) + else: + state = None + + if state is None: + encoder_output = self.encode(source_seqs, source_seqs_layout) + encoder_output_layout = BatchLayout.of(encoder_output) + + if not self.training and state_bag is not None: + state = _TransformerModelState(encoder_output, encoder_output_layout) + + state_bag.set_state(self, state) + else: + encoder_output = state.encoder_output + + encoder_output_layout = state.encoder_output_layout + + del source_seqs + + decoder_output = self.decode( + target_seqs, + target_seqs_layout, encoder_output, - encoder_padding_mask, + encoder_output_layout, state_bag=state_bag, ) - def project( - self, decoder_output: Tensor, decoder_padding_mask: 
Optional[PaddingMask] - ) -> SequenceModelOutput: - return self.decoder.project(decoder_output, decoder_padding_mask) + del target_seqs + + return self.project(decoder_output) class DummyEncoderModel(SonarEncoderModel): - """Abstract class for both speech and text SONAR encoder models which does not modify its inputs.""" + """Passthrough encoder enabling architecture reuse for pre-computed embeddings. - def forward(self, batch: SequenceBatch) -> SonarEncoderOutput: - """ - :param batch: - The batch of sequences to process. - :returns: - SonarEncoderOutput - """ + Allows SonarEncoderDecoderModel to handle embedding-to-text generation + without architectural changes. Returns input embeddings unchanged, relying + on the parent encode() method for proper shape formatting for `sentence_embeddings`. + """ + + def forward(self, seqs: Tensor, seqs_layout: BatchLayout) -> SonarEncoderOutput: return SonarEncoderOutput( - encoded_seqs=batch.seqs, - sentence_embeddings=batch.seqs, # reduce in dim 1 - padding_mask=batch.padding_mask, + encoded_seqs=seqs, + sentence_embeddings=seqs, # see SonarEncoderDecoderModel note on the shape dimension expectation + encoded_seqs_layout=seqs_layout, ) diff --git a/sonar/nn/__init__.py b/sonar/nn/__init__.py index fff2a45..4c74665 100644 --- a/sonar/nn/__init__.py +++ b/sonar/nn/__init__.py @@ -3,3 +3,13 @@ # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
+ +from sonar.nn.conditional_decoder_model import ( + ConditionalTransformerDecoderModel as ConditionalTransformerDecoderModel, +) +from sonar.nn.encoder_pooler import ( + AttentionEncoderOutputPooler as AttentionEncoderOutputPooler, +) +from sonar.nn.encoder_pooler import EncoderOutputPooler as EncoderOutputPooler +from sonar.nn.laser_lstm_encoder import LaserLstmEncoder as LaserLstmEncoder +from sonar.nn.sequence import SequenceModelOutput as SequenceModelOutput diff --git a/sonar/nn/conditional_decoder_model.py b/sonar/nn/conditional_decoder_model.py index d26fa6e..0c492ae 100644 --- a/sonar/nn/conditional_decoder_model.py +++ b/sonar/nn/conditional_decoder_model.py @@ -8,27 +8,30 @@ Fairseq2 does not have a suitable model, because: - fairseq2.models.transformer.model.TransformerModel imperatively includes a transformer encoder. - fairseq2.models.decoder.DecoderModel does not expect any additional inputs. -ConditionalTransformerDecoderModel inherits from EncoderDecoderModel, so it is a sibling class to TransformerModel. +ConditionalTransformerDecoderModel inherits from Seq2SeqModel, so it is a sibling class to TransformerModel. + +After fs2:v0.5 upgrade: +ConditionalTransformerDecoderModel inherited from EncoderDecoderModel (removed from fs2) and is replaced by Seq2Seq. +This is unconventional as a Seq2Seq model holds both encoder and decoder, while this is only the decoder implementation. +A custom solution might be required (similar to the encoder in sonar/models/encoder_model.py). 
""" -from typing import Optional, Tuple +from typing import Optional -from fairseq2.data import VocabularyInfo -from fairseq2.models.encoder_decoder import EncoderDecoderModel -from fairseq2.models.sequence import SequenceModelOutput -from fairseq2.models.transformer import TransformerFrontend -from fairseq2.nn import IncrementalStateBag, Projection -from fairseq2.nn.padding import PaddingMask -from fairseq2.nn.transformer import TransformerDecoder +import torch +from fairseq2.models.seq2seq import Seq2SeqModel +from fairseq2.models.transformer import TransformerDecoder, TransformerFrontend +from fairseq2.nn import BatchLayout, IncrementalStateBag, Projection from torch import Tensor -class ConditionalTransformerDecoderModel(EncoderDecoderModel): +class ConditionalTransformerDecoderModel(Seq2SeqModel): """Represents a Transformer-based decoder model conditional on the inputs from the encoder.""" decoder_frontend: TransformerFrontend decoder: TransformerDecoder final_proj: Projection + post_sentemb_proj: Projection | None def __init__( self, @@ -36,7 +39,8 @@ def __init__( decoder: TransformerDecoder, final_proj: Projection, max_target_seq_len: int, - target_vocab_info: VocabularyInfo, + normalize_emb: bool = False, + post_sentemb_proj: Projection | None = None, ) -> None: """ :param decoder_frontend: @@ -47,48 +51,81 @@ def __init__( The projection to apply to decoder outputs. :param max_target_seq_len: The maximum length of sequences produced by the model. - :param target_vocab_info: - The vocabulary information of sequences produced by the model. + :param normalize_emb: + Whether to normalize the embedding before passing it to the decoder. + :param post_sentemb_proj: + The projection to apply to the sentence embedding. 
""" - super().__init__(decoder.model_dim, max_target_seq_len, target_vocab_info) - + super().__init__(max_source_seq_len=0, max_target_seq_len=max_target_seq_len) + # NOTE: max_source_seq_len = 0 is a workaround due to Seq2Seq requiring *both* an encoder/decoder model and this is the wrong subclass self.decoder_frontend = decoder_frontend self.decoder = decoder - + self.post_sentemb_proj = post_sentemb_proj + self.normalize_emb = normalize_emb self.final_proj = final_proj - def encode( - self, seqs: Tensor, padding_mask: Optional[PaddingMask] - ) -> Tuple[Tensor, Optional[PaddingMask]]: + def encode(self, seqs: Tensor, seqs_layout: BatchLayout) -> Tensor: """The encoding just returns the inputs as is.""" - return seqs, padding_mask + if self.normalize_emb: + if seqs.dtype != torch.float32: + original_dtype = seqs.dtype + norm = torch.norm(seqs.float(), dim=-1, keepdim=True) + norm = torch.clamp(norm, min=1e-6) + seqs = seqs / norm.to(original_dtype) + else: + seqs = seqs / seqs.norm(dim=-1, keepdim=True) + return seqs def decode( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], + seqs_layout: BatchLayout, encoder_output: Tensor, - encoder_padding_mask: Optional[PaddingMask], + encoder_output_layout: BatchLayout, *, state_bag: Optional[IncrementalStateBag] = None, - ) -> Tuple[Tensor, Optional[PaddingMask]]: + ) -> Tensor: """Decoding is exactly the same as with fairseq2 TransformerModel""" - seqs, padding_mask = self.decoder_frontend( - seqs, padding_mask, state_bag=state_bag + seqs, seqs_layout = self.decoder_frontend( + seqs, seqs_layout, state_bag=state_bag ) + if self.post_sentemb_proj is not None: + encoder_output = self.post_sentemb_proj(encoder_output) + return self.decoder( # type: ignore[no-any-return] seqs, - padding_mask, + seqs_layout, encoder_output, - encoder_padding_mask, + encoder_output_layout, state_bag=state_bag, ) - def project( - self, decoder_output: Tensor, decoder_padding_mask: Optional[PaddingMask] - ) -> SequenceModelOutput: + 
def project(self, decoder_output: Tensor) -> Tensor: """Projection is exactly the same as with fairseq2 TransformerModel""" - logits = self.final_proj(decoder_output) + return self.final_proj(decoder_output) + + def forward( # type: ignore + self, + source_seqs: Tensor, + source_seqs_layout: BatchLayout, + target_seqs: Tensor, + target_seqs_layout: BatchLayout, + *, + state_bag: IncrementalStateBag | None = None, + ) -> Tensor: + """Reference implementation from fs2:v0.4.3 EncoderDecoderModel using BatchLayout + The decoder frontend is not used here, c.f. https://github.com/facebookresearch/fairseq2/blob/v0.4.3/src/fairseq2/models/encoder_decoder.py#L42 + + Reasoning behind the API change from Seq2SeqBatch to more flat types to help torch.compile to trace + """ + encoder_output = self.encode(source_seqs, source_seqs_layout) + + decoder_output = self.decode( + target_seqs, + target_seqs_layout, + encoder_output, + source_seqs_layout, # encoder does not change padding if it existed and needs to be forwarded + ) - return SequenceModelOutput(logits, pad_idx=self.target_vocab_info.pad_idx) + return self.project(decoder_output) diff --git a/sonar/nn/encoder_pooler.py b/sonar/nn/encoder_pooler.py index e7c5e6e..b6c662d 100644 --- a/sonar/nn/encoder_pooler.py +++ b/sonar/nn/encoder_pooler.py @@ -5,16 +5,15 @@ # LICENSE file in the root directory of this source tree. 
from abc import abstractmethod -from typing import Optional import torch -from fairseq2.models.transformer import TransformerFrontend +from fairseq2.device import Device +from fairseq2.models.transformer import TransformerDecoder, TransformerFrontend from fairseq2.nn import Linear -from fairseq2.nn.padding import PaddingMask -from fairseq2.nn.transformer import TransformerDecoder -from fairseq2.typing import Device, override +from fairseq2.nn.batch_layout import BatchLayout from torch import Tensor from torch.nn import Module +from typing_extensions import override class EncoderOutputPooler(Module): @@ -22,9 +21,7 @@ class EncoderOutputPooler(Module): @abstractmethod def __call__( - self, - encoder_output: Tensor, - encoder_padding_mask: Optional[PaddingMask], + self, encoder_output: Tensor, encoder_output_layout: BatchLayout | None ) -> Tensor: """Apply pooling on encoder_output @@ -33,10 +30,6 @@ def __call__( :math:`(N,S_{enc},M)`, where :math:`N` is the batch size, :math:`S_{enc}` is the encoder output sequence length, and :math:`M` is the dimensionality of the model. - :param encoder_padding_mask: - The float padding mask of ``encoder_output``. *Shape:* - :math:`(N,S_{enc})`, where :math:`N` is the batch size and - :math:`S_{enc}` is the encoder output sequence length. :returns: The pooler output. 
*Shape:* :math:`(N,M)`, where :math:`N` is the @@ -68,16 +61,17 @@ def __init__( @override def __call__( - self, - encoder_output: Tensor, - encoder_padding_mask: Optional[PaddingMask], + self, encoder_output: Tensor, encoder_output_layout: BatchLayout | None ) -> Tensor: seqs = self._get_pooling_tokens(encoder_output.shape[0], encoder_output.device) + seqs_layout = BatchLayout.of(seqs) + if encoder_output_layout is None: + encoder_output_layout = BatchLayout.of(encoder_output) - seqs, padding_mask = self.decoder_frontend(seqs, None) + seqs, seqs_layout = self.decoder_frontend(seqs, seqs_layout) - decoder_out, _ = self.decoder( - seqs, padding_mask, encoder_output, encoder_padding_mask + decoder_out = self.decoder( + seqs, seqs_layout, encoder_output, encoder_output_layout ) return self.projection_out(decoder_out).squeeze(1) diff --git a/sonar/nn/sequence.py b/sonar/nn/sequence.py new file mode 100644 index 0000000..7930ebb --- /dev/null +++ b/sonar/nn/sequence.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +This class is added to keep functionality identical from fairseq2:v0.4.3. Upgrade at the time fairseq2:v0.5. +The loss calculation has been moved to the forward pass for more fusion potential adding a bit of performance. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import torch +from torch import Tensor +from torch.nn.functional import log_softmax, nll_loss + + +@dataclass +class SequenceModelOutput: + """Holds the output of a sequence model.""" + + logits: Tensor + """The logits for next-step prediction. 
*Shape:* :math:`(N,S,T)`, where + :math:`N` is the batch size, :math:`S` is the sequence length, and :math:`T` + is the size of the vocabulary.""" + + pad_idx: int | None + """The index of the PAD symbols in the vocabulary.""" + + def compute_loss( + self, + targets: Tensor, + *, + loss_mask: Tensor | None = None, + ignore_prefix_size: int = 0, + label_smoothing: float = 0.0, + ) -> Tensor: + """Compute the NLL (negative log-likelihood) loss. + + :param targets: + The target indices. *Shape:* :math:`(N,S)`, where :math:`N` is the + batch size and :math:`S` is the sequence length. + :param loss_mask: + The loss mask that specifies the elements in ``targets`` that should + be used in the loss computation. All non-masked elements will be + ignored. *Shape:* Same as ``targets``. + :param ignore_prefix_size: + The number of steps from the beginning of the sequence that should + be ignored in the loss computation. + :param label_smoothing: + The amount of label smoothing to apply while computing the loss. + + :returns: + A scalar tensor representing the summed NLL loss. + """ + if ignore_prefix_size > 0: + logits = self.logits[:, ignore_prefix_size:, :] + else: + logits = self.logits + + if ignore_prefix_size > 0: + targets = targets[:, ignore_prefix_size:] + + # For numerical stability run in single precision. 
+ # (N, S, T) + lprobs = log_softmax(logits, dim=-1, dtype=torch.float32) + + # sum: (), none: (N, S) + loss = nll_loss( + input=lprobs, + target=targets, + ignore_index=-100 if self.pad_idx is None else self.pad_idx, + reduction="sum" if loss_mask is None else "none", + ) + # TODO: support label_smoothing (nll_loss no longer supports it) + + if loss_mask is None: + return loss + + if ignore_prefix_size > 0: + loss_mask = loss_mask[:, ignore_prefix_size:] + + # () + return (loss * loss_mask).sum() diff --git a/tests/conftest.py b/tests/conftest.py index d33389e..3b9ad73 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,9 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from fairseq2 import setup_fairseq2 +from fairseq2 import init_fairseq2 from pytest import Session def pytest_sessionstart(session: Session) -> None: - setup_fairseq2() + init_fairseq2() diff --git a/tests/integration_tests/data/audio_files/README.md b/tests/integration_tests/data/audio_files/README.md index a6f3f95..ed36241 100644 --- a/tests/integration_tests/data/audio_files/README.md +++ b/tests/integration_tests/data/audio_files/README.md @@ -11,4 +11,4 @@ You can access the FLEURS paper at https://arxiv.org/abs/2205.12446. 
Please cite url = {https://arxiv.org/abs/2205.12446}, year = {2022}, } -``` \ No newline at end of file +``` diff --git a/tests/integration_tests/test_blaser.py b/tests/integration_tests/test_blaser.py index 9694ee3..a40a124 100644 --- a/tests/integration_tests/test_blaser.py +++ b/tests/integration_tests/test_blaser.py @@ -13,7 +13,7 @@ def test_blaser2_ref(): """Compare predictions of a specific reference-based model with hardcoded expected values""" model_hub = get_blaser_model_hub() - blaser = model_hub.load("blaser_2_0_ref") + blaser = model_hub.load_model("blaser_2_0_ref") blaser.eval() emb = torch.zeros([1, 1024]) + 1 / 32 pred = blaser(src=emb, mt=emb, ref=emb).item() @@ -29,7 +29,7 @@ def test_blaser2_ref(): def test_blaser2_qe(): """Compare predictions of a specific referenceless model with hardcoded expected values""" model_hub = get_blaser_model_hub() - blaser = model_hub.load("blaser_2_0_qe") + blaser = model_hub.load_model("blaser_2_0_qe") blaser.eval() emb = torch.zeros([1, 1024]) + 1 / 32 pred = blaser(src=emb, mt=emb).item() diff --git a/tests/integration_tests/test_laser2_text.py b/tests/integration_tests/test_laser2_text.py index 6abc630..e83d8db 100644 --- a/tests/integration_tests/test_laser2_text.py +++ b/tests/integration_tests/test_laser2_text.py @@ -8,13 +8,12 @@ from pathlib import Path import torch -from fairseq2.data import Collater +from fairseq2.data.data_pipeline import Collater from fairseq2.data.text import read_text -from fairseq2.data.text.tokenizers import get_text_tokenizer_hub +from fairseq2.data.tokenizers import load_tokenizer +from fairseq2.models import load_model from torch.testing import assert_close -from sonar.models.laser2_text import get_laser2_model_hub - device = torch.device("cpu") sentences = [ @@ -26,13 +25,11 @@ def test_load_laser2_text() -> None: - model_hub = get_laser2_model_hub() - model = model_hub.load("laser2_text_encoder", device=device) + model = load_model("laser2_text_encoder", device=device) 
model.eval() - tokenizer_hub = get_text_tokenizer_hub() - tokenizer = tokenizer_hub.load("laser2_text_encoder") + tokenizer = load_tokenizer("laser2_text_encoder") encoder = tokenizer.create_encoder() @@ -49,7 +46,7 @@ def test_load_laser2_text() -> None: tokenized_sentences = next(iter(pipeline)) embed_sentences = model( - tokenized_sentences["seqs"], tokenized_sentences["seq_lens"] + tokenized_sentences["seqs"], torch.Tensor(tokenized_sentences["seq_lens"]) ) embed_sentences_norm = torch.nn.functional.normalize(embed_sentences) actual_sim = torch.matmul(embed_sentences_norm, embed_sentences_norm.T) diff --git a/tests/integration_tests/test_mutox.py b/tests/integration_tests/test_mutox.py index 9b98d16..9b85f79 100644 --- a/tests/integration_tests/test_mutox.py +++ b/tests/integration_tests/test_mutox.py @@ -49,7 +49,7 @@ def test_sonar_mutox_classifier_integration(input_texts, source_lang, expected_o ) hub = get_mutox_model_hub() - classifier = hub.load("sonar_mutox", device=device, dtype=dtype).eval() + classifier = hub.load_model("sonar_mutox", device=device, dtype=dtype).eval() with torch.inference_mode(): embeddings = t2vec_model.predict(input_texts, source_lang=source_lang) @@ -111,7 +111,7 @@ def test_sonar_mutox_classifier_probability_integration( ) hub = get_mutox_model_hub() - classifier = hub.load("sonar_mutox", device=device, dtype=dtype).eval() + classifier = hub.load_model("sonar_mutox", device=device, dtype=dtype).eval() for text, lang, expected_prob in zip( input_texts, [source_lang] * len(input_texts), expected_probabilities diff --git a/tests/integration_tests/test_sonar_speech_encoder.py b/tests/integration_tests/test_sonar_speech_encoder.py index 21324b0..23eae9b 100644 --- a/tests/integration_tests/test_sonar_speech_encoder.py +++ b/tests/integration_tests/test_sonar_speech_encoder.py @@ -7,7 +7,7 @@ from pathlib import Path import torch -from fairseq2.data.text.tokenizers import get_text_tokenizer_hub +from fairseq2.models.nllb.hub import 
get_nllb_tokenizer_hub from torch.testing import assert_close from sonar.inference_pipelines import ( @@ -25,14 +25,15 @@ class TestSonarTextClass: encoder_hub = get_sonar_speech_encoder_hub() - encoder = encoder_hub.load("sonar_speech_encoder_eng", device=DEVICE) + encoder = encoder_hub.load_model("sonar_speech_encoder_eng", device=DEVICE) encoder.eval() - - tokenizer_hub = get_text_tokenizer_hub() - tokenizer = tokenizer_hub.load("text_sonar_basic_encoder") + # from sonar import setup_logging_here + # setup_logging_here() + tokenizer_hub = get_nllb_tokenizer_hub() + tokenizer = tokenizer_hub.load_tokenizer("text_sonar_basic_encoder") decoder_hub = get_sonar_text_decoder_hub() - decoder = decoder_hub.load("text_sonar_basic_decoder", device=DEVICE) + decoder = decoder_hub.load_model("text_sonar_basic_decoder", device=DEVICE) decoder.eval() params = SpeechInferenceParams( diff --git a/tests/integration_tests/test_text_sonar.py b/tests/integration_tests/test_text_sonar.py index 016440b..6fe3e2e 100644 --- a/tests/integration_tests/test_text_sonar.py +++ b/tests/integration_tests/test_text_sonar.py @@ -8,7 +8,7 @@ import pytest import torch -from fairseq2.models.sequence import SequenceBatch +from fairseq2.nn import BatchLayout from torch.testing import assert_close # type: ignore from sonar.inference_pipelines.text import ( @@ -28,7 +28,7 @@ class TestSonarTextClass: text2text = TextToTextModelPipeline( "text_sonar_basic_encoder", "text_sonar_basic_decoder", - "text_sonar_basic_encoder", + "text_sonar_basic_encoder", # name of the tokenizer ) vec2text = EmbeddingToTextModelPipeline( "text_sonar_basic_decoder", # name of the decoder @@ -60,45 +60,51 @@ def test_encode_long_text(self) -> None: @torch.inference_mode() def test_text_decoder_sonar(self) -> None: - eng_tokenizer_encoder = self.text2text.tokenizer.create_encoder(lang="eng_Latn") + eng_tokenizer_encoder = self.text2text.decoder_tokenizer.create_encoder( + lang="eng_Latn" + ) tokenized_seq = 
eng_tokenizer_encoder(self.eng_sentences[0]).unsqueeze(0) - batch = SequenceBatch(tokenized_seq, None) - encoded_vec = self.text2text.model.encoder(batch) + tokenized_seq_layout = BatchLayout.of(tokenized_seq) + encoded_vec = self.text2text.model.encoder(tokenized_seq, tokenized_seq_layout) decoder = self.text2text.model.decoder dummy_prev_output_tokens = torch.Tensor([[3, 333]]).int() - seqs, padding_mask = decoder.decoder_frontend( - dummy_prev_output_tokens, padding_mask=None + dummy_prev_output_layout = BatchLayout.of(dummy_prev_output_tokens) + seqs, seqs_layout = decoder.decoder_frontend( + dummy_prev_output_tokens, dummy_prev_output_layout ) - - decoder_output, decoder_padding_mask = decoder.decoder( + encoder_output = encoded_vec.sentence_embeddings.unsqueeze(1) + encoder_output_layout = BatchLayout.of(encoder_output) + decoder_output = decoder.decoder( seqs, - padding_mask, - encoder_output=encoded_vec.sentence_embeddings.unsqueeze(1), + seqs_layout, + encoder_output=encoder_output, + encoder_output_layout=encoder_output_layout, ) - decoder_output = decoder.project(decoder_output, decoder_padding_mask) - out = decoder_output.logits + + decoder_output = decoder.project(decoder_output) + assert_close( - out[0, 0, :4], + decoder_output[0, 0, :4], torch.Tensor([-1.4572, -2.7325, -1.0546, 0.7818]), rtol=1e-4, atol=1e-4, ) assert_close( - out[0, 0, -3:], + decoder_output[0, 0, -3:], torch.Tensor([0.8982, 0.4996, -0.1487]), rtol=1e-4, atol=1e-4, ) assert_close( - out[0, 1, :4], + decoder_output[0, 1, :4], torch.Tensor([2.4092, 6.9624, 3.6308, 9.4825]), rtol=1e-4, atol=1e-4, ) assert_close( - out[0, 1, -4:], + decoder_output[0, 1, -4:], torch.Tensor([3.8826, 3.8777, 3.2820, 3.3275]), rtol=1e-4, atol=1e-4, diff --git a/tests/unit_tests/test_blaser_inference.py b/tests/unit_tests/test_blaser_inference.py index e1e52e0..5c8ca9d 100644 --- a/tests/unit_tests/test_blaser_inference.py +++ b/tests/unit_tests/test_blaser_inference.py @@ -8,7 +8,7 @@ import torch from 
torch.testing import assert_close -from sonar.models.blaser import BlaserConfig, create_blaser_model +from sonar.models.blaser import BlaserConfig, _create_blaser_model @pytest.mark.parametrize("embedding_dim", [32, 1024]) @@ -16,7 +16,7 @@ def test_blaser_qe(embedding_dim, batch_size): """Testing that a BLASER-QE model can be created and runs""" config = BlaserConfig(input_form="QE", embedding_dim=embedding_dim) - blaser = create_blaser_model(config).eval() + blaser = _create_blaser_model(config).eval() embedding = torch.zeros([batch_size, embedding_dim]) # test that the forward method produces an expected shape @@ -33,7 +33,7 @@ def test_blaser_qe(embedding_dim, batch_size): def test_blaser_ref(embedding_dim, batch_size): """Testing that a model can be created and that forward returns a right shape""" config = BlaserConfig(input_form="COMET", embedding_dim=embedding_dim) - blaser = create_blaser_model(config) + blaser = _create_blaser_model(config).eval() embedding = torch.zeros([batch_size, embedding_dim]) # test that the forward method produces an expected shape @@ -50,7 +50,7 @@ def test_blaser_ref(embedding_dim, batch_size): def test_input_form(input_form, embedding_dim): """Testing that BLASER inputs are processed correctlyb""" config = BlaserConfig(input_form=input_form, embedding_dim=embedding_dim) - blaser = create_blaser_model(config) + blaser = _create_blaser_model(config).eval() # the input vectors are arbitrary; we are checking only how they are concatenated src = torch.arange(0, embedding_dim).unsqueeze(0) / embedding_dim mt = torch.cos(src) diff --git a/tests/unit_tests/test_low_dimension_text_models.py b/tests/unit_tests/test_low_dimension_text_models.py index e2d0e23..356b764 100644 --- a/tests/unit_tests/test_low_dimension_text_models.py +++ b/tests/unit_tests/test_low_dimension_text_models.py @@ -5,9 +5,9 @@ # LICENSE file in the root directory of this source tree. 
import torch -from fairseq2.context import get_runtime_context -from fairseq2.models.seq2seq import Seq2SeqBatch -from fairseq2.models.sequence import SequenceBatch +from fairseq2.nn import BatchLayout +from fairseq2.runtime.config_registry import get_config +from fairseq2.runtime.dependency import get_dependency_resolver from sonar.models.sonar_text import ( SonarTextDecoderConfig, @@ -19,14 +19,12 @@ def test_low_dim_encoder(): """Test that an encoder with a hidden dimension lower than the embedding dimension can be created and called.""" - context = get_runtime_context() - - config_registry = context.get_config_registry(SonarTextEncoderConfig) + resolver = get_dependency_resolver() + cfg = get_config(resolver, SonarTextEncoderConfig, "basic") embed_dim = 256 batch_size = 3 - cfg = config_registry.get("basic") cfg.model_dim = 32 cfg.embedding_dim = embed_dim cfg.num_encoder_layers = 5 @@ -35,39 +33,34 @@ def test_low_dim_encoder(): model = SonarTextEncoderFactory(cfg).create_model() tokens = torch.tensor([[0, 1, 2, 3, 4]] * batch_size) - batch = SequenceBatch( - seqs=tokens, - padding_mask=None, - ) + tokens_layout = BatchLayout.of(tokens) with torch.inference_mode(): - output = model(batch) + output = model(tokens, tokens_layout) print(output.sentence_embeddings.shape) assert output.sentence_embeddings.shape == (batch_size, embed_dim) def test_low_dim_decoder(): """Test that a decoder with a hidden dimension lower than the embedding dimension can be created and called.""" - context = get_runtime_context() + resolver = get_dependency_resolver() - config_registry = context.get_config_registry(SonarTextDecoderConfig) + cfg = get_config(resolver, SonarTextDecoderConfig, "toy") embed_dim = 256 batch_size = 3 - cfg = config_registry.get("toy") cfg.model_dim = 32 cfg.input_dim = embed_dim model = SonarTextDecoderFactory(cfg).create_model() embeds = torch.rand([batch_size, 1, embed_dim]) prefix = torch.tensor([[0, 1, 2, 3, 4]] * batch_size) - batch = Seq2SeqBatch( - 
source_seqs=embeds, - source_padding_mask=None, - target_seqs=prefix, - target_padding_mask=None, - ) with torch.inference_mode(): - output = model(batch) + output = model( + source_seqs=embeds, + source_seqs_layout=BatchLayout.of(embeds), + target_seqs=prefix, + target_seqs_layout=BatchLayout.of(prefix), + ) - assert output.logits.shape == (batch_size, 5, cfg.vocab_info.size) + assert output.shape == (batch_size, 5, cfg.vocab_info.size) diff --git a/tests/unit_tests/test_mutox.py b/tests/unit_tests/test_mutox.py index b135376..1cb3949 100644 --- a/tests/unit_tests/test_mutox.py +++ b/tests/unit_tests/test_mutox.py @@ -13,8 +13,8 @@ from sonar.models.mutox import ( MutoxClassifier, MutoxConfig, - MutoxModelHandler, - create_mutox_model, + _convert_mutox_checkpoint, + _create_mutox_model, ) # Builder tests @@ -26,7 +26,7 @@ def test_mutox_classifier_builder(input_size, device, dtype): """Test MutoxClassifierBuilder initializes a model with correct configuration and dtype.""" config = MutoxConfig(input_size=input_size) - model = create_mutox_model(config).to(device=device, dtype=dtype) + model = _create_mutox_model(config).to(device=device, dtype=dtype) # Check if model layers are correctly initialized with shapes assert isinstance(model, nn.Module), "Model should be an instance of nn.Module" @@ -43,7 +43,7 @@ def test_mutox_classifier_builder(input_size, device, dtype): def test_create_mutox_model(input_size): """Test create_mutox_model function to confirm it creates a model with the specified config.""" config = MutoxConfig(input_size=input_size) - model = create_mutox_model(config).to(device=torch.device("cpu")) + model = _create_mutox_model(config).to(device=torch.device("cpu")) # Check if the created model has the expected structure and behavior test_input = torch.zeros((3, input_size)) @@ -114,14 +114,9 @@ def test_convert_mutox_checkpoint(): "non_model_key": torch.tensor([3.0]), } config = MutoxConfig(input_size=1024) - converted: Dict[str, torch.Tensor] = 
MutoxModelHandler._convert_checkpoint(None, checkpoint, config) # type: ignore + converted: Dict[str, torch.Tensor] = _convert_mutox_checkpoint(checkpoint, config) # type: ignore # Verify only 'model_all.' keys are retained in the converted dictionary - assert "model" in converted, "Converted checkpoint should contain a 'model' key" - assert ( - "model_all.layer1.weight" in converted["model"] - ), "Expected 'model_all.layer1.weight'" - assert ( - "model_all.layer1.bias" in converted["model"] - ), "Expected 'model_all.layer1.bias'" - assert "non_model_key" not in converted["model"], "Unexpected 'non_model_key'" + assert "model_all.layer1.weight" in converted, "Expected 'model_all.layer1.weight'" + assert "model_all.layer1.bias" in converted, "Expected 'model_all.layer1.bias'" + assert "non_model_key" not in converted, "Unexpected 'non_model_key'" diff --git a/tests/unit_tests/test_sonar_pooling.py b/tests/unit_tests/test_sonar_pooling.py index 873318e..6fbef39 100644 --- a/tests/unit_tests/test_sonar_pooling.py +++ b/tests/unit_tests/test_sonar_pooling.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
import torch -from fairseq2.nn.padding import PaddingMask +from fairseq2.nn.batch_layout import BatchLayout from torch.testing import assert_close # type: ignore from sonar.models.sonar_text.model import Pooling, SonarTextTransformerEncoderModel @@ -14,41 +14,44 @@ def test_pooling_max() -> None: - padding_mask = PaddingMask(torch.tensor([2, 1]), batch_seq_len=3) + # padding_mask = PaddingMask(torch.tensor([2, 1]), batch_seq_len=3) seqs = torch.Tensor( [[[7, 2], [3, 4], [10, 20]], [[-1, -2], [100, 1000], [-10, -20]]] ) + seqs_layout = BatchLayout.of(seqs, seq_lens=[2, 1]) expected = torch.Tensor([[7.0, 4.0], [-1.0, -2.0]]) - actual = pooling_method(seqs, padding_mask, Pooling.MAX) + actual = pooling_method(seqs, seqs_layout, Pooling.MAX) assert_close(expected, actual) - actual_extra = pooling_method(seqs.unsqueeze(3), padding_mask, Pooling.MAX) + actual_extra = pooling_method(seqs.unsqueeze(3), seqs_layout, Pooling.MAX) assert_close(expected.unsqueeze(2), actual_extra) def test_pooling_mean() -> None: - padding_mask = PaddingMask(torch.tensor([2, 1]), batch_seq_len=3) + # padding_mask = PaddingMask(torch.tensor([2, 1]), batch_seq_len=3) seqs = torch.Tensor( [[[7, 2], [3, 4], [10, 20]], [[-1, -2], [100, 1000], [-10, -20]]] ) + seqs_layout = BatchLayout.of(seqs, seq_lens=[2, 1]) expected = torch.Tensor([[5.0, 3.0], [-1.0, -2.0]]) - actual = pooling_method(seqs, padding_mask, Pooling.MEAN) + actual = pooling_method(seqs, seqs_layout, Pooling.MEAN) assert_close(expected, actual) - actual_extra = pooling_method(seqs.unsqueeze(3), padding_mask, Pooling.MEAN) + actual_extra = pooling_method(seqs.unsqueeze(3), seqs_layout, Pooling.MEAN) assert_close(expected.unsqueeze(2), actual_extra) def test_pooling_last() -> None: - padding_mask = PaddingMask(torch.tensor([2, 1]), batch_seq_len=3) + # padding_mask = PaddingMask(torch.tensor([2, 1]), batch_seq_len=3) seqs = torch.Tensor( [[[7, 2], [3, 4], [10, 20]], [[-1, -2], [100, 1000], [-10, -20]]] ) + seqs_layout = 
BatchLayout.of(seqs, seq_lens=[2, 1]) expected = torch.Tensor([[3.0, 4.0], [-1.0, -2.0]]) - actual = pooling_method(seqs, padding_mask, Pooling.LAST) + actual = pooling_method(seqs, seqs_layout, Pooling.LAST) assert_close(expected, actual) - actual_extra = pooling_method(seqs.unsqueeze(3), padding_mask, Pooling.LAST) + actual_extra = pooling_method(seqs.unsqueeze(3), seqs_layout, Pooling.LAST) assert_close(expected.unsqueeze(2), actual_extra) diff --git a/tests/unit_tests/test_tied_weights.py b/tests/unit_tests/test_tied_weights.py index c768b87..d01c6e4 100644 --- a/tests/unit_tests/test_tied_weights.py +++ b/tests/unit_tests/test_tied_weights.py @@ -8,10 +8,17 @@ from pathlib import Path import torch -from fairseq2.assets import AssetCard, InProcAssetMetadataLoader, StandardAssetStore -from fairseq2.context import get_runtime_context +from fairseq2.assets import ( + AssetCard, + StandardAssetStore, + get_asset_store, + load_in_memory_asset_metadata, +) +from fairseq2.runtime.config_registry import get_config +from fairseq2.runtime.dependency import get_dependency_resolver from sonar.models.sonar_text import ( + SONAR_TEXT_DECODER_FAMILY, SonarTextDecoderConfig, SonarTextDecoderFactory, get_sonar_text_decoder_hub, @@ -32,21 +39,19 @@ def create_model_card( "model_arch": model_arch, "checkpoint": "file://" + checkpoint_path.as_posix(), } - metadata_loader = InProcAssetMetadataLoader([model_card_info]) - asset_store.metadata_providers.append(metadata_loader.load()) + metadata_provider = load_in_memory_asset_metadata("memory", [model_card_info]) + asset_store._metadata_providers.append(metadata_provider) return asset_store.retrieve_card(model_name) def test_tied_weight(): """Testing that the decoder input and output embeddings are tied after creating the model and after loading""" - context = get_runtime_context() - - config_registry = context.get_config_registry(SonarTextDecoderConfig) - config = config_registry.get("toy") + resolver = get_dependency_resolver() +
config = get_config(resolver, SonarTextDecoderConfig, "toy") model = SonarTextDecoderFactory(config).create_model() - assert model.decoder_frontend.embed.weight is model.final_proj.weight + assert model.decoder_frontend.embed.weight is model.final_proj.weight # type: ignore # counting the parameters total_params = sum(p.numel() for p in model.parameters()) @@ -64,15 +69,16 @@ def test_tied_weight(): # now load the model using a standard loader, based on a card card = create_model_card( - context.asset_store, + get_asset_store(), # type: ignore checkpoint_path=filename, - model_type="transformer_decoder", + model_type=SONAR_TEXT_DECODER_FAMILY, model_arch="toy", ) + print(card) decoder_hub = get_sonar_text_decoder_hub() - model_new = decoder_hub.load(card) + model_new = decoder_hub.load_model(card) # test that the newly loaded model has the same weight tying as the original one total_params_new = sum(p.numel() for p in model_new.parameters()) assert total_params_new == total_params - assert model_new.decoder_frontend.embed.weight is model_new.final_proj.weight + assert model_new.decoder_frontend.embed.weight is model_new.final_proj.weight # type: ignore