Commit af4a0ec

Backtrack tmp.
1 parent 0cddfce commit af4a0ec

7 files changed: +821 −3 lines changed

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
import os
import argparse
import datetime
from datasets import load_dataset
from tokenizers import Tokenizer
from typing import Tuple

MODEL_ID = "meta-llama/Meta-Llama-3.1-8B"
DATASET = "facebook/xnli"
DATASET_CONFIG = "all_languages"
DEFAULT_THREADS = [2**i for i in range(8) if 2**i <= os.cpu_count()]


def format_byte_size(num_bytes: int) -> Tuple[str, str]:
    """Convert bytes to a human-readable format (KB, MB, GB)."""
    num_bytes_f = float(num_bytes)
    for unit in ["B", "KB", "MB", "GB", "TB"]:
        if num_bytes_f < 1024:
            return f"{num_bytes_f:.2f} {unit}", unit
        num_bytes_f /= 1024
    return f"{num_bytes_f:.2f} PB", "PB"


def test(model: str, dataset: str, dataset_config: str):
    dataset_xnli = load_dataset(dataset, dataset_config)
    # Same model loaded twice: `tokenizer` is the default encoder,
    # `tokenizer2` has the new backtracking path enabled.
    tokenizer = Tokenizer.from_pretrained(model)
    tokenizer2 = Tokenizer.from_pretrained(model)
    tokenizer2.enable_backtrack()

    # Quick sanity check on a couple of short strings before the full run.
    for easy in ["1880", " cream"]:
        encoded = tokenizer.encode(easy)
        encoded2 = tokenizer2.encode(easy)
        if encoded.ids != encoded2.ids:
            import ipdb

            ipdb.set_trace()
        assert encoded.ids == encoded2.ids

    sentences = []
    en_sentences = []
    for _i, item in enumerate(dataset_xnli["train"]):
        # sentence = item["premise"]["en"]
        # sentences.append(sentence)
        for lang, sentence in item["premise"].items():
            if lang == "en":
                en_sentences.append(sentence)
            sentences.append(sentence)
    sentences = en_sentences + sentences

    start = datetime.datetime.now()
    encoded = tokenizer.encode_batch_fast(sentences)
    print(f"Took {datetime.datetime.now() - start}")

    start = datetime.datetime.now()
    encoded2 = tokenizer2.encode_batch_fast(sentences)
    print(f"Took {datetime.datetime.now() - start}")

    assert len(encoded) == len(encoded2)
    assert len(encoded) == len(sentences)
    total = 0
    correct = 0
    for enc, enc2, sentence in zip(encoded, encoded2, sentences):
        # if enc.ids != enc2.ids:
        #     print(enc.ids)
        #     print(enc2.ids)
        if enc.ids == enc2.ids:
            correct += 1
        total += 1
        assert enc.ids == enc2.ids, f"{enc.ids} != {enc2.ids} (Source: {sentence})"
    print(f"{correct} / {total} ({correct / total * 100:.2f}%)")
    # print("All good !")


def main():
    parser = argparse.ArgumentParser(
        prog="bench_tokenizer",
        description="Getting a feel for speed when tokenizing",
    )
    parser.add_argument("-m", "--model", default=MODEL_ID, type=str)
    parser.add_argument("-d", "--dataset", default=DATASET, type=str)
    parser.add_argument("-ds", "--dataset-config", default=DATASET_CONFIG, type=str)
    args = parser.parse_args()
    test(args.model, args.dataset, args.dataset_config)


# Call the function to run the benchmark
if __name__ == "__main__":
    main()

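For a quick spot check outside the full XNLI run, the same default-vs-backtracking comparison can be reproduced on a handful of strings. This is a minimal sketch, assuming the enable_backtrack() binding added in this commit is present in the local build; the model id is just the benchmark's default and the sample strings are arbitrary.

from tokenizers import Tokenizer

baseline = Tokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
backtrack = Tokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
backtrack.enable_backtrack()  # binding added in this commit (BPE models only)

for text in ["1880", " cream", "This is a premise sentence."]:
    a = baseline.encode(text).ids
    b = backtrack.encode(text).ids
    print(repr(text), "match" if a == b else f"mismatch: {a} vs {b}")

Any divergence in the ids here is exactly the case the benchmark's assert is meant to surface.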
bindings/python/src/tokenizer.rs

Lines changed: 15 additions & 0 deletions
@@ -1,6 +1,8 @@
 use serde::Serialize;
 use std::collections::{hash_map::DefaultHasher, HashMap};
 use std::hash::{Hash, Hasher};
+use tk::pre_tokenizers::byte_level::ByteLevel;
+use tk::ModelWrapper;
 
 use numpy::{npyffi, PyArray1, PyArrayMethods};
 use pyo3::class::basic::CompareOp;
@@ -1118,6 +1120,19 @@ impl PyTokenizer {
             .into()
         })
     }
+    ///
+    #[pyo3(signature = ())]
+    #[pyo3(text_signature = "(self)")]
+    fn enable_backtrack(&mut self) -> PyResult<()> {
+        // self.tokenizer.with_pre_tokenizer(None::<ByteLevel>);
+        let model = self.tokenizer.get_model();
+        let mut model = model.model.write().unwrap();
+        let ModelWrapper::BPE(ref mut model) = *model else {
+            todo!();
+        };
+        model.enable_backtrack();
+        Ok(())
+    }
 
     /// Decode the given list of ids back to a string
     ///

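The new binding grabs the tokenizer's model, takes a write lock, and only handles the ModelWrapper::BPE arm; any other model type currently falls into todo!() and would panic across the FFI boundary. Below is a hedged Python-side guard, assuming only the Tokenizer.model property and the models.BPE class that already exist in the released bindings.

from tokenizers import Tokenizer
from tokenizers.models import BPE

tok = Tokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")

# enable_backtrack() is only wired up for BPE models in this commit; guard before
# calling so a non-BPE tokenizer does not hit the todo!() in the Rust binding.
if isinstance(tok.model, BPE):
    tok.enable_backtrack()
else:
    print("Backtracking is not available for", type(tok.model).__name__)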
tokenizers/Cargo.toml

Lines changed: 2 additions & 0 deletions
@@ -69,6 +69,8 @@ monostate = "0.1.12"
 ahash = { version = "0.8.11", features = ["serde"] }
 dary_heap = { version = "0.3.6", features = ["serde"] }
 compact_str = { version = "0.9", features = ["serde"] }
+fnv = "1.0.7"
+aneubeck-daachorse = "1.1.1"
 
 [features]
 default = ["progressbar", "onig", "esaxx_fast"]
