4 changes: 2 additions & 2 deletions bindings/node/src/models.rs
@@ -4,7 +4,7 @@ use crate::trainers::Trainer;
 use napi::bindgen_prelude::*;
 use napi_derive::napi;
 use serde::{Deserialize, Serialize};
-use std::collections::HashMap;
+use rustc_hash::FxHashMap;
 use std::path::{Path, PathBuf};
 use std::sync::{Arc, RwLock};
 use tokenizers as tk;
@@ -95,7 +95,7 @@ impl tk::Model for Model {
         self.model.as_ref()?.read().unwrap().id_to_token(id)
     }
 
-    fn get_vocab(&self) -> HashMap<String, u32> {
+    fn get_vocab(&self) -> FxHashMap<String, u32> {
         self
             .model
             .as_ref()
4 changes: 2 additions & 2 deletions bindings/node/src/tokenizer.rs
@@ -6,7 +6,7 @@ use crate::pre_tokenizers::PreTokenizer;
 use crate::processors::Processor;
 use crate::tasks::tokenizer::{DecodeBatchTask, DecodeTask, EncodeBatchTask, EncodeTask};
 use crate::trainers::Trainer;
-use std::collections::HashMap;
+use rustc_hash::FxHashMap;
 use tokenizers::Model as ModelTrait;
 
 use napi::bindgen_prelude::*;
@@ -433,7 +433,7 @@ impl Tokenizer {
     }
 
     #[napi]
-    pub fn get_vocab(&self, with_added_tokens: Option<bool>) -> HashMap<String, u32> {
+    pub fn get_vocab(&self, with_added_tokens: Option<bool>) -> FxHashMap<String, u32> {
        let with_added_tokens = with_added_tokens.unwrap_or(true);
        self.tokenizer.read().unwrap().get_vocab(with_added_tokens)
    }
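Both Node bindings above swap `std::collections::HashMap` for `rustc_hash::FxHashMap` in their `get_vocab` signatures, and no call sites need to change: `FxHashMap<K, V>` is a type alias for std's `HashMap` with the faster (but not DoS-resistant) `FxHasher`, so the entire map API carries over. A minimal sketch, not part of the PR, of the one construction difference:

```rust
// FxHashMap is std's HashMap with a custom hasher, so it has no `new()`;
// construct it through Default (or FromIterator) instead.
use rustc_hash::FxHashMap;

fn toy_vocab() -> FxHashMap<String, u32> {
    let mut vocab = FxHashMap::default();
    vocab.insert("hello".to_string(), 0);
    vocab.insert("world".to_string(), 1);
    vocab
}

fn main() {
    // Lookups, iteration, and the rest of the HashMap API are unchanged.
    assert_eq!(toy_vocab().get("hello"), Some(&0));
}
```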
2 changes: 2 additions & 0 deletions bindings/python/Cargo.toml
@@ -18,6 +18,8 @@ pyo3 = { version = "0.23", features = ["abi3", "abi3-py39", "py-clone"] }
 numpy = "0.23"
 ndarray = "0.16"
 itertools = "0.12"
+rustc-hash = "2.1.1"
+compact_str = { version = "0.8.1", features = ["serde"] }
 
 [dependencies.tokenizers]
 path = "../../tokenizers"
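The two new Python-binding dependencies mirror the substitutions made throughout this diff: `rustc-hash` supplies `FxHashMap`, and `compact_str` supplies `CompactString`, which stores short strings inline (up to 24 bytes on 64-bit targets — a property of the crate, not of this PR) and only heap-allocates beyond that; the `serde` feature lets it stand in for `String` in (de)serialized structs. A hedged sketch of the basics:

```rust
use compact_str::{CompactString, ToCompactString};

fn main() {
    let short: CompactString = "##".to_compact_string(); // stored inline, no heap allocation
    let long = CompactString::from("a".repeat(40)); // spills to the heap past the inline capacity

    // CompactString derefs to &str, so the usual string API applies.
    assert!(short.starts_with('#'));
    assert_eq!(long.len(), 40);

    // Converting back to an owned String, e.g. for a PyO3 return value:
    let owned: String = short.to_string();
    assert_eq!(owned, "##");
}
```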
94 changes: 68 additions & 26 deletions bindings/python/src/decoders.rs
@@ -3,6 +3,7 @@ use std::sync::{Arc, RwLock};
 use crate::pre_tokenizers::from_string;
 use crate::tokenizer::PyTokenizer;
 use crate::utils::PyPattern;
+use compact_str::ToCompactString;
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
@@ -91,7 +92,10 @@ impl PyDecoder {
 }
 
 impl Decoder for PyDecoder {
-    fn decode_chain(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
+    fn decode_chain<T: ToCompactString>(
+        &self,
+        tokens: Vec<T>,
+    ) -> tk::Result<Vec<impl ToCompactString>> {
         self.decoder.decode_chain(tokens)
     }
 }
@@ -139,7 +143,12 @@ impl PyDecoder {
     /// :obj:`str`: The decoded string
     #[pyo3(text_signature = "(self, tokens)")]
     fn decode(&self, tokens: Vec<String>) -> PyResult<String> {
-        ToPyResult(self.decoder.decode(tokens)).into()
+        ToPyResult(
+            self.decoder
+                .decode(tokens)
+                .map(|t| t.to_compact_string().to_string()),
+        )
+        .into()
     }
 
     fn __repr__(&self) -> PyResult<String> {
@@ -235,12 +244,12 @@ pub struct PyWordPieceDec {}
 impl PyWordPieceDec {
     #[getter]
     fn get_prefix(self_: PyRef<Self>) -> String {
-        getter!(self_, WordPiece, prefix.clone())
+        getter!(self_, WordPiece, prefix.clone().to_string())
     }
 
     #[setter]
     fn set_prefix(self_: PyRef<Self>, prefix: String) {
-        setter!(self_, WordPiece, prefix, prefix);
+        setter!(self_, WordPiece, prefix, prefix.to_compact_string());
     }
 
     #[getter]
@@ -256,7 +265,10 @@ impl PyWordPieceDec {
     #[new]
     #[pyo3(signature = (prefix = String::from("##"), cleanup = true), text_signature = "(self, prefix=\"##\", cleanup=True)")]
     fn new(prefix: String, cleanup: bool) -> (Self, PyDecoder) {
-        (PyWordPieceDec {}, WordPiece::new(prefix, cleanup).into())
+        (
+            PyWordPieceDec {},
+            WordPiece::new(prefix.to_compact_string(), cleanup).into(),
+        )
     }
 }
 
@@ -412,12 +424,12 @@ pub struct PyBPEDecoder {}
 impl PyBPEDecoder {
     #[getter]
     fn get_suffix(self_: PyRef<Self>) -> String {
-        getter!(self_, BPE, suffix.clone())
+        getter!(self_, BPE, suffix.to_string())
     }
 
     #[setter]
     fn set_suffix(self_: PyRef<Self>, suffix: String) {
-        setter!(self_, BPE, suffix, suffix);
+        setter!(self_, BPE, suffix, suffix.into());
     }
 
     #[new]
@@ -443,22 +455,27 @@ pub struct PyCTCDecoder {}
 impl PyCTCDecoder {
     #[getter]
     fn get_pad_token(self_: PyRef<Self>) -> String {
-        getter!(self_, CTC, pad_token.clone())
+        getter!(self_, CTC, pad_token.to_string())
     }
 
     #[setter]
     fn set_pad_token(self_: PyRef<Self>, pad_token: String) {
-        setter!(self_, CTC, pad_token, pad_token);
+        setter!(self_, CTC, pad_token, pad_token.into());
     }
 
     #[getter]
     fn get_word_delimiter_token(self_: PyRef<Self>) -> String {
-        getter!(self_, CTC, word_delimiter_token.clone())
+        getter!(self_, CTC, word_delimiter_token.clone()).to_string()
     }
 
     #[setter]
     fn set_word_delimiter_token(self_: PyRef<Self>, word_delimiter_token: String) {
-        setter!(self_, CTC, word_delimiter_token, word_delimiter_token);
+        setter!(
+            self_,
+            CTC,
+            word_delimiter_token,
+            word_delimiter_token.into()
+        );
     }
 
     #[getter]
@@ -526,22 +543,33 @@ impl CustomDecoder {
 }
 
 impl Decoder for CustomDecoder {
-    fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
+    fn decode<T: ToCompactString>(&self, tokens: Vec<T>) -> tk::Result<impl ToCompactString> {
+        let tokens: Vec<String> = tokens
+            .into_iter()
+            .map(|t| t.to_compact_string().to_string())
+            .collect();
         Python::with_gil(|py| {
             let decoded = self
                 .inner
                 .call_method(py, "decode", (tokens,), None)?
-                .extract(py)?;
+                .extract::<String>(py)?;
             Ok(decoded)
         })
     }
 
-    fn decode_chain(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
+    fn decode_chain<T: ToCompactString>(
+        &self,
+        tokens: Vec<T>,
+    ) -> tk::Result<Vec<impl ToCompactString>> {
+        let tokens: Vec<String> = tokens
+            .into_iter()
+            .map(|t| t.to_compact_string().to_string())
+            .collect();
         Python::with_gil(|py| {
             let decoded = self
                 .inner
                 .call_method(py, "decode_chain", (tokens,), None)?
-                .extract(py)?;
+                .extract::<Vec<String>>(py)?;
             Ok(decoded)
         })
     }
@@ -595,10 +623,21 @@ where
 }
 
 impl Decoder for PyDecoderWrapper {
-    fn decode_chain(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
+    fn decode_chain<T: ToCompactString>(
+        &self,
+        tokens: Vec<T>,
+    ) -> tk::Result<Vec<impl ToCompactString>> {
         match self {
-            PyDecoderWrapper::Wrapped(inner) => inner.read().unwrap().decode_chain(tokens),
-            PyDecoderWrapper::Custom(inner) => inner.read().unwrap().decode_chain(tokens),
+            PyDecoderWrapper::Wrapped(inner) => inner
+                .read()
+                .unwrap()
+                .decode_chain(tokens)
+                .map(|v| v.into_iter().map(|t| t.to_compact_string()).collect()),
+            PyDecoderWrapper::Custom(inner) => inner
+                .read()
+                .unwrap()
+                .decode_chain(tokens)
+                .map(|v| v.into_iter().map(|t| t.to_compact_string()).collect()),
         }
     }
 }
@@ -663,14 +702,17 @@ impl PyDecodeStream {
 
     #[pyo3(signature = (tokenizer, id), text_signature = "(self, tokenizer, id)")]
     fn step(&mut self, tokenizer: &PyTokenizer, id: u32) -> PyResult<Option<String>> {
-        ToPyResult(tk::tokenizer::step_decode_stream(
-            &tokenizer.tokenizer,
-            id,
-            self.skip_special_tokens,
-            &mut self.ids,
-            &mut self.prefix,
-            &mut self.prefix_index,
-        ))
+        ToPyResult(
+            tk::tokenizer::step_decode_stream(
+                &tokenizer.tokenizer,
+                id,
+                self.skip_special_tokens,
+                &mut self.ids,
+                &mut self.prefix.to_compact_string(),
+                &mut self.prefix_index,
+            )
+            .map(|o| o.map(|s| s.to_string())),
+        )
         .into()
     }
 }
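The recurring shape in this file: `decode_chain` is now generic over `T: ToCompactString` and returns `tk::Result<Vec<impl ToCompactString>>`, so each implementor picks its own concrete token type and converts only at the boundary. A hedged sketch of what an implementor looks like against that signature, with the trait stubbed down to the one method (the real `tokenizers` trait carries more than this):

```rust
use compact_str::{CompactString, ToCompactString};

// Local stand-ins for tk::Result and the Decoder trait shape from this PR.
type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

trait Decoder {
    fn decode_chain<T: ToCompactString>(&self, tokens: Vec<T>)
        -> Result<Vec<impl ToCompactString>>;
}

/// Strips a WordPiece-style "##" continuation prefix from each token.
struct StripContinuation;

impl Decoder for StripContinuation {
    fn decode_chain<T: ToCompactString>(
        &self,
        tokens: Vec<T>,
    ) -> Result<Vec<impl ToCompactString>> {
        Ok(tokens
            .into_iter()
            .map(|t| {
                let t = t.to_compact_string();
                // strip_prefix comes for free through Deref<Target = str>.
                if let Some(rest) = t.strip_prefix("##") {
                    CompactString::new(rest)
                } else {
                    t
                }
            })
            .collect::<Vec<CompactString>>())
    }
}

fn main() {
    let out = StripContinuation
        .decode_chain(vec!["hug", "##ging"])
        .unwrap();
    let mut decoded = String::new();
    for t in &out {
        decoded.push_str(&t.to_compact_string());
    }
    assert_eq!(decoded, "hugging");
}
```

Note the `impl ToCompactString` return position inside a trait requires Rust 1.75+; callers only see the opaque type, which is why the wrapper above re-maps results with `to_compact_string()`.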
6 changes: 5 additions & 1 deletion bindings/python/src/encoding.rs
@@ -127,7 +127,11 @@ impl PyEncoding {
     /// :obj:`List[str]`: The list of tokens
     #[getter]
     fn get_tokens(&self) -> Vec<String> {
-        self.encoding.get_tokens().to_vec()
+        self.encoding
+            .get_tokens()
+            .iter()
+            .map(|x| x.to_string())
+            .collect()
     }
 
     /// The generated word indices.
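The getter can no longer simply `to_vec()` the token slice: assuming the core `Encoding` now stores its tokens as `CompactString` (implied but not shown by this hunk), each token is copied into a `String`, the type PyO3 maps to a Python `list[str]` element out of the box. The conversion in isolation, as a hedged sketch:

```rust
use compact_str::CompactString;

// Assumed shape: the core crate hands back &[CompactString].
fn tokens_for_python(tokens: &[CompactString]) -> Vec<String> {
    // One fresh String per token; to_string works through Display via Deref<str>.
    tokens.iter().map(|t| t.to_string()).collect()
}

fn main() {
    let toks = [CompactString::new("hello"), CompactString::new("world")];
    assert_eq!(tokens_for_python(&toks), vec!["hello", "world"]);
}
```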