Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions bindings/node/src/normalizers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ impl tk::Normalizer for Normalizer {
}
}

/// Build a `Normalizer` backed by the core `Append` normalizer, which
/// appends the given string to each (non-empty) input during normalization.
#[napi]
pub fn append_normalizer(append: String) -> Normalizer {
    // Construct the core normalizer first, then wrap it for shared,
    // thread-safe access from the JS side.
    let appender = tk::normalizers::append::Append::new(append);
    Normalizer {
        normalizer: Some(Arc::new(RwLock::new(appender.into()))),
    }
}

#[napi]
pub fn prepend_normalizer(prepend: String) -> Normalizer {
Normalizer {
Expand Down
1 change: 1 addition & 0 deletions bindings/python/py_src/tokenizers/normalizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# Aliases re-exporting the normalizer classes at this module's top level.
NFKC = normalizers.NFKC
Sequence = normalizers.Sequence
Lowercase = normalizers.Lowercase
Append = normalizers.Append
Prepend = normalizers.Prepend
Strip = normalizers.Strip
StripAccents = normalizers.StripAccents
Expand Down
41 changes: 41 additions & 0 deletions bindings/python/py_src/tokenizers/normalizers/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,47 @@ class Precompiled(Normalizer):
"""
pass

class Append(Normalizer):
    """
    Append normalizer

    Appends the given string to the end of each input string during
    normalization.

    Args:
        append (:obj:`str`):
            The string to append
    """
    def __init__(self, append):
        pass

    def normalize(self, normalized):
        """
        Normalize a :class:`~tokenizers.NormalizedString` in-place

        This method allows to modify a :class:`~tokenizers.NormalizedString` to
        keep track of the alignment information. If you just want to see the result
        of the normalization on a raw string, you can use
        :meth:`~tokenizers.normalizers.Normalizer.normalize_str`

        Args:
            normalized (:class:`~tokenizers.NormalizedString`):
                The normalized string on which to apply this
                :class:`~tokenizers.normalizers.Normalizer`
        """
        pass

    def normalize_str(self, sequence):
        """
        Normalize the given string

        This method provides a way to visualize the effect of a
        :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment
        information. If you need to get/convert offsets, you can use
        :meth:`~tokenizers.normalizers.Normalizer.normalize`

        Args:
            sequence (:obj:`str`):
                A string to normalize

        Returns:
            :obj:`str`: A string after normalization
        """
        pass

class Prepend(Normalizer):
"""
Prepend normalizer
Expand Down
29 changes: 28 additions & 1 deletion bindings/python/src/normalizers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::utils::{PyNormalizedString, PyNormalizedStringRefMut, PyPattern};
use serde::ser::SerializeStruct;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use tk::normalizers::{
BertNormalizer, ByteLevel, Lowercase, Nmt, NormalizerWrapper, Precompiled, Prepend, Replace,
BertNormalizer, ByteLevel, Lowercase, Nmt, NormalizerWrapper, Precompiled, Append, Prepend, Replace,
Strip, StripAccents, NFC, NFD, NFKC, NFKD,
};
use tk::{NormalizedString, Normalizer};
Expand Down Expand Up @@ -82,6 +82,10 @@ impl PyNormalizer {
.into_pyobject(py)?
.into_any()
.into(),
NormalizerWrapper::Append(_) => Py::new(py, (PyAppend {}, base))?
.into_pyobject(py)?
.into_any()
.into(),
NormalizerWrapper::Prepend(_) => Py::new(py, (PyPrepend {}, base))?
.into_pyobject(py)?
.into_any()
Expand Down Expand Up @@ -512,6 +516,28 @@ impl PyStrip {
}
}

/// Append normalizer
///
/// Python-facing wrapper around the core `Append` normalizer; appends the
/// configured string to each input.
///
/// Args:
///     append (:obj:`str`):
///         The string to append
#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "Append")]
pub struct PyAppend {}
#[pymethods]
impl PyAppend {
    // Getter for the Python `append` attribute: reads the string from the
    // wrapped core normalizer.
    #[getter]
    fn get_append(self_: PyRef<Self>) -> String {
        getter!(self_, Append, append)
    }

    // Setter for the Python `append` attribute: replaces the string on the
    // wrapped core normalizer in place.
    #[setter]
    fn set_append(self_: PyRef<Self>, append: String) {
        setter!(self_, Append, append, append)
    }

    // Constructor; defaults to appending "▁" (U+2581, the SentencePiece
    // word-boundary marker) when no argument is given.
    #[new]
    #[pyo3(signature = (append="▁".to_string()), text_signature = "(self, append)")]
    fn new(append: String) -> (Self, PyNormalizer) {
        (PyAppend {}, Append::new(append).into())
    }
}

/// Prepend normalizer
#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "Prepend")]
pub struct PyPrepend {}
Expand Down Expand Up @@ -807,6 +833,7 @@ pub fn normalizers(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyLowercase>()?;
m.add_class::<PyStrip>()?;
m.add_class::<PyStripAccents>()?;
m.add_class::<PyAppend>()?;
m.add_class::<PyPrepend>()?;
m.add_class::<PyByteLevel>()?;
m.add_class::<PyNmt>()?;
Expand Down
40 changes: 40 additions & 0 deletions tokenizers/src/normalizers/append.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use crate::tokenizer::{NormalizedString, Normalizer, Result};
use serde::{Deserialize, Serialize};

/// Normalizer that appends a fixed string to the end of each input string.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(tag = "type")]
pub struct Append {
    // The string appended during normalization.
    pub append: String,
}

impl Append {
pub fn new(append: String) -> Self {
Self { append }
}
}

impl Normalizer for Append {
    /// Appends the configured string to `normalized` in place.
    ///
    /// Empty strings are left untouched so the suffix is never emitted on
    /// its own.
    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
        if normalized.is_empty() {
            // Nothing to append to.
            return Ok(());
        }
        normalized.append(&self.append);
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_append() {
        let input = "Hello";
        let expected = "Hello▁";
        // Sanity check: the expected output really differs from the input.
        assert_ne!(input, expected);

        let mut normalized = NormalizedString::from(input);
        let appender = Append::new("▁".to_string());
        appender.normalize(&mut normalized).unwrap();
        assert_eq!(normalized.get(), expected);
    }
}
18 changes: 18 additions & 0 deletions tokenizers/src/normalizers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Submodule declarations, kept in alphabetical order (the new `append`
// module was previously inserted out of order).
pub mod append;
pub mod bert;
pub mod byte_level;
pub mod precompiled;
pub mod prepend;
pub mod replace;
pub mod strip;
Expand All @@ -9,6 +10,7 @@ pub mod utils;
// Public re-exports, kept in alphabetical order by module path (the new
// `append` re-export was previously inserted out of order).
pub use crate::normalizers::append::Append;
pub use crate::normalizers::bert::BertNormalizer;
pub use crate::normalizers::byte_level::ByteLevel;
pub use crate::normalizers::precompiled::Precompiled;
pub use crate::normalizers::prepend::Prepend;
pub use crate::normalizers::replace::Replace;
pub use crate::normalizers::strip::{Strip, StripAccents};
Expand All @@ -34,6 +36,7 @@ pub enum NormalizerWrapper {
Nmt(Nmt),
Precompiled(Precompiled),
Replace(Replace),
Append(Append),
Prepend(Prepend),
ByteLevel(ByteLevel),
}
Expand Down Expand Up @@ -64,6 +67,7 @@ impl<'de> Deserialize<'de> for NormalizerWrapper {
Nmt,
Precompiled,
Replace,
Append,
Prepend,
ByteLevel,
}
Expand All @@ -90,6 +94,7 @@ impl<'de> Deserialize<'de> for NormalizerWrapper {
Nmt(Nmt),
Precompiled(Precompiled),
Replace(Replace),
Append(Append),
Prepend(Prepend),
ByteLevel(ByteLevel),
}
Expand Down Expand Up @@ -145,6 +150,9 @@ impl<'de> Deserialize<'de> for NormalizerWrapper {
EnumType::Replace => NormalizerWrapper::Replace(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::Append => NormalizerWrapper::Append(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::Prepend => NormalizerWrapper::Prepend(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
Expand Down Expand Up @@ -173,6 +181,7 @@ impl<'de> Deserialize<'de> for NormalizerWrapper {
NormalizerUntagged::Nmt(bpe) => NormalizerWrapper::Nmt(bpe),
NormalizerUntagged::Precompiled(bpe) => NormalizerWrapper::Precompiled(bpe),
NormalizerUntagged::Replace(bpe) => NormalizerWrapper::Replace(bpe),
NormalizerUntagged::Append(bpe) => NormalizerWrapper::Append(bpe),
NormalizerUntagged::Prepend(bpe) => NormalizerWrapper::Prepend(bpe),
NormalizerUntagged::ByteLevel(bpe) => NormalizerWrapper::ByteLevel(bpe),
}
Expand All @@ -196,6 +205,7 @@ impl Normalizer for NormalizerWrapper {
Self::Nmt(lc) => lc.normalize(normalized),
Self::Precompiled(lc) => lc.normalize(normalized),
Self::Replace(lc) => lc.normalize(normalized),
Self::Append(lc) => lc.normalize(normalized),
Self::Prepend(lc) => lc.normalize(normalized),
Self::ByteLevel(lc) => lc.normalize(normalized),
}
Expand All @@ -214,6 +224,7 @@ impl_enum_from!(Lowercase, NormalizerWrapper, Lowercase);
impl_enum_from!(Nmt, NormalizerWrapper, Nmt);
impl_enum_from!(Precompiled, NormalizerWrapper, Precompiled);
impl_enum_from!(Replace, NormalizerWrapper, Replace);
impl_enum_from!(Append, NormalizerWrapper, Append);
impl_enum_from!(Prepend, NormalizerWrapper, Prepend);
impl_enum_from!(ByteLevel, NormalizerWrapper, ByteLevel);

Expand All @@ -239,6 +250,13 @@ mod tests {
_ => panic!("Expected an error here"),
}

let json = r#"{"append":"a"}"#;
let reconstructed = serde_json::from_str::<NormalizerWrapper>(json);
assert!(matches!(
reconstructed.unwrap(),
NormalizerWrapper::Append(_)
));

let json = r#"{"prepend":"a"}"#;
let reconstructed = serde_json::from_str::<NormalizerWrapper>(json);
assert!(matches!(
Expand Down
Loading