diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 0e27f594f..afa8adecd 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -782,6 +782,16 @@ impl PyTokenizer {
         self.tokenizer.get_vocab(with_added_tokens)
     }
 
+    /// Get the special tokens mapping
+    ///
+    /// Returns:
+    ///     :obj:`Dict[str, Set[str]]`: The special tokens mapping
+    #[pyo3(signature = ())]
+    #[pyo3(text_signature = "(self)")]
+    fn get_special_tokens_mapping(&self) -> Option<&HashMap<String, BTreeSet<String>>> {
+        self.tokenizer.get_special_tokens_mapping()
+    }
+
     /// Get the underlying vocabulary
     ///
     /// Returns:
@@ -1848,6 +1858,22 @@ impl PyTokenizer {
     fn set_decoder(&mut self, decoder: Option<PyRef<PyDecoder>>) {
         self.tokenizer.with_decoder(decoder.map(|d| d.clone()));
     }
+
+    /// The optional `eos_token` in use by the Tokenizer
+    #[getter]
+    fn get_eos_token(&self, py: Python<'_>) -> Option<BTreeSet<String>> {
+        self.tokenizer
+            .get_special_tokens_mapping()
+            .and_then(|mapping| mapping.get("eos_token"))
+            // into_pyobject -> Bound. Turn that into PyObject.
+            .map(|v| v.clone())
+    }
+
+    /// Set the `eos_token`
+    #[setter]
+    fn set_eos_token(&mut self, new_eos_token: Option<String>) {
+        self.tokenizer.with_special_tokens_mapping();
+    }
 }
 
 #[cfg(test)]
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index cedabeebc..931577437 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -20,9 +20,9 @@ use std::{
 use serde::de::DeserializeOwned;
 use serde::{Deserialize, Serialize};
 
-use crate::utils::iter::ResultShunt;
 use crate::utils::parallelism::*;
 use crate::utils::progress::{ProgressBar, ProgressStyle};
+use crate::{special_tokens_mapping::SpecialTokensMapping, utils::iter::ResultShunt};
 
 mod added_vocabulary;
 mod encoding;
@@ -30,6 +30,7 @@ pub mod normalizer;
 pub mod pattern;
 pub mod pre_tokenizer;
 mod serialization;
+pub mod special_tokens_mapping;
 
 // Re-export wrappers
 pub use crate::decoders::DecoderWrapper;
@@ -293,6 +294,7 @@ pub struct TokenizerBuilder<M, N, PT, PP, D> {
 
     truncation: Option<TruncationParams>,
     padding: Option<PaddingParams>,
+    special_tokens_mapping: Option<SpecialTokensMapping>,
 }
 
 impl<M, N, PT, PP, D> Default for TokenizerBuilder<M, N, PT, PP, D>
@@ -327,6 +329,7 @@ where
             added_vocabulary: AddedVocabulary::new(),
             truncation: None,
             padding: None,
+            special_tokens_mapping: None,
         }
     }
 
@@ -347,6 +350,7 @@ where
             added_vocabulary: self.added_vocabulary,
             truncation: self.truncation,
             padding: self.padding,
+            special_tokens_mapping: self.special_tokens_mapping,
         })
     }
 
@@ -404,6 +408,14 @@ where
         self.padding = padding;
         self
     }
+
+    pub fn with_special_tokens_mapping(
+        mut self,
+        special_tokens_mapping: Option<SpecialTokensMapping>,
+    ) -> Self {
+        self.special_tokens_mapping = special_tokens_mapping;
+        self
+    }
 }
 
 #[derive(Serialize, Deserialize, Debug, Clone)]
@@ -480,6 +492,7 @@ where
             added_vocabulary: t.added_vocabulary,
             padding: t.padding,
             truncation: t.truncation,
+            special_tokens_mapping: t.special_tokens_mapping,
         })
     }
 }
@@ -524,6 +537,7 @@ pub struct TokenizerImpl<M, N, PT, PP, D> {
     // General processing parameters
     truncation: Option<TruncationParams>,
     padding: Option<PaddingParams>,
+    special_tokens_mapping: Option<SpecialTokensMapping>,
 }
 
 impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
@@ -547,6 +561,7 @@ where
 
             truncation: None,
             padding: None,
+            special_tokens_mapping: None,
         }
     }
 
@@ -654,6 +669,25 @@ where
         self.padding.as_ref()
     }
 
+    /// Set the special tokens mapping
+    pub fn with_special_tokens_mapping(
+        &mut self,
+        special_tokens_mapping: Option<SpecialTokensMapping>,
+    ) -> &mut Self {
+        self.special_tokens_mapping = special_tokens_mapping;
+        self
+    }
+
+    /// Get the currently set special tokens mapping
+    pub fn get_special_tokens_mapping(&self) -> Option<&SpecialTokensMapping> {
+        self.special_tokens_mapping.as_ref()
+    }
+
+    /// Get a mutable reference to the currently set special tokens mapping
+    pub fn get_extra_token_muts(&mut self) -> Option<&mut SpecialTokensMapping> {
+        self.special_tokens_mapping.as_mut()
+    }
+
     /// Get a mutable reference to the currently set padding parameters
     pub fn get_padding_mut(&mut self) -> Option<&mut PaddingParams> {
         self.padding.as_mut()
diff --git a/tokenizers/src/tokenizer/serialization.rs b/tokenizers/src/tokenizer/serialization.rs
index 7075bed8f..9849ea6e5 100644
--- a/tokenizers/src/tokenizer/serialization.rs
+++ b/tokenizers/src/tokenizer/serialization.rs
@@ -42,6 +42,7 @@ where
         tokenizer.serialize_field("post_processor", &self.post_processor)?;
         tokenizer.serialize_field("decoder", &self.decoder)?;
         tokenizer.serialize_field("model", &self.model)?;
+        tokenizer.serialize_field("special_tokens_mapping", &self.special_tokens_mapping)?;
         tokenizer.end()
     }
 }
@@ -63,6 +64,7 @@ where
             "Tokenizer",
             &[
                 "version",
+                "special_tokens_mapping",
                 "truncation",
                 "padding",
                 "added_tokens",
@@ -143,6 +145,9 @@ where
                     "post_processor" => {
                         builder = builder.with_post_processor(map.next_value()?);
                     }
+                    "special_tokens_mapping" => {
+                        builder = builder.with_special_tokens_mapping(map.next_value()?);
+                    }
                     _ => {}
                 };
             }
@@ -221,7 +226,8 @@ mod tests {
             "continuing_subword_prefix": "",
             "max_input_chars_per_word": 100,
             "vocab": {}
-        }
+        },
+        "special_tokens_mapping": null
     }"#;
 
         let tokenizer = Tokenizer::from_str(tok_json).unwrap();
diff --git a/tokenizers/src/tokenizer/special_tokens_mapping.rs b/tokenizers/src/tokenizer/special_tokens_mapping.rs
new file mode 100644
index 000000000..3c9fdc8b5
--- /dev/null
+++ b/tokenizers/src/tokenizer/special_tokens_mapping.rs
@@ -0,0 +1,24 @@
+use std::collections::{BTreeMap, BTreeSet};
+
+use serde::{Deserialize, Serialize};
+
+/// A struct that represents the mapping from standard special token names like
+/// `eos_token`, `bos_token` or `my_token` to the corresponding string tokens.
+///
+/// We choose `BTreeMap` and `BTreeSet` for ordered serialization and fast membership checks.
+/// It should support updating a single entry as well as replacing the whole mapping.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct SpecialTokensMapping {
+    inner: BTreeMap<String, BTreeSet<String>>,
+}
+
+impl SpecialTokensMapping {
+    pub fn new(inner: BTreeMap<String, BTreeSet<String>>) -> Self {
+        Self { inner }
+    }
+
+    /// Get the tokens registered under `name`, if any.
+    pub fn get(&self, name: &str) -> Option<&BTreeSet<String>> {
+        self.inner.get(name)
+    }
+}
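
Below is a minimal usage sketch of the new API, not part of the diff itself. It assumes the `BTreeMap<String, BTreeSet<String>>`-backed mapping and the small `get` accessor as drafted in special_tokens_mapping.rs, an import path of `tokenizers::tokenizer::special_tokens_mapping` (this diff adds no re-export), and that `Tokenizer` keeps dereferencing to `TokenizerImpl`, which makes `with_special_tokens_mapping` and `get_special_tokens_mapping` callable on it:

use std::collections::{BTreeMap, BTreeSet};

use tokenizers::models::bpe::BPE;
use tokenizers::tokenizer::special_tokens_mapping::SpecialTokensMapping; // assumed path
use tokenizers::Tokenizer;

fn main() {
    // Any model works here; an empty BPE keeps the sketch small.
    let mut tokenizer = Tokenizer::new(BPE::default());

    // Build a name -> tokens mapping: "eos_token" -> {"</s>"}.
    let mut inner: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
    inner.insert("eos_token".to_string(), BTreeSet::from(["</s>".to_string()]));

    // Attach the mapping, then read it back through the new accessors
    // (Tokenizer derefs to TokenizerImpl, so the new methods are in scope).
    tokenizer.with_special_tokens_mapping(Some(SpecialTokensMapping::new(inner)));
    let eos = tokenizer
        .get_special_tokens_mapping()
        .and_then(|mapping| mapping.get("eos_token"));
    assert_eq!(eos.map(|set| set.len()), Some(1));
}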