Skip to content

Commit fdd831d

Browse files
committed
Raise error in the index init if regex and vocab are incompatible
1 parent 4f92f5c commit fdd831d

File tree

2 files changed

+84
-2
lines changed

2 files changed

+84
-2
lines changed

src/error.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@ pub enum Error {
6969
InvalidRefecencePath(Box<str>),
7070
#[error("Ref recusion limit reached: {0}")]
7171
RefRecursionLimitReached(usize),
72+
#[error("The vocabulary provided is incompatible with the regex '{regex}'. Found no transitions from state {error_state}, missing tokens corresponding to at least one of the following characters: {missing_tokens:?}. This may be due to an encoding issue in your vocabulary.")]
73+
IncompatibleVocabulary {
74+
regex: String,
75+
error_state: u32,
76+
missing_tokens: Vec<String>,
77+
},
7278
}
7379

7480
impl Error {

src/index.rs

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,11 @@ impl Index {
116116
let mut next_states: Vec<AutomataStateId> = vec![start_state];
117117

118118
while let Some(current_state) = next_states.pop() {
119+
let mut has_valid_transitions = false;
120+
119121
if dfa.is_match_state(dfa.next_eoi_state(current_state)) {
120122
final_states.insert(current_state.as_u32());
123+
has_valid_transitions = true;
121124
}
122125

123126
'token_loop: for (token, ids) in vocabulary.tokens().iter() {
@@ -136,6 +139,7 @@ impl Index {
136139
let is_intermediate_state = !dfa.is_match_state(next_state);
137140
let is_full_match_state = dfa.is_match_state(dfa.next_eoi_state(next_state));
138141
if is_intermediate_state || is_full_match_state {
142+
has_valid_transitions = true;
139143
for token_id in ids {
140144
transitions
141145
.entry(current_state.as_u32())
@@ -148,6 +152,28 @@ impl Index {
148152
next_states.push(next_state);
149153
}
150154
}
155+
156+
// If the current state has no valid transitions and is not a match state,
157+
// it means the vocabulary is incompatible with the regex.
158+
if !has_valid_transitions && !dfa.is_match_state(current_state) {
159+
let mut valid_characters = Vec::new();
160+
for byte in 0..=255u8 {
161+
let test_state = dfa.next_state(current_state, byte);
162+
if !dfa.is_dead_state(test_state) && !dfa.is_quit_state(test_state) {
163+
if byte.is_ascii() {
164+
valid_characters.push(char::from(byte).to_string());
165+
} else {
166+
valid_characters.push(format!("\\x{:02x}", byte));
167+
}
168+
}
169+
}
170+
171+
return Err(Error::IncompatibleVocabulary {
172+
regex: regex.to_string(),
173+
error_state: current_state.as_u32(),
174+
missing_tokens: valid_characters,
175+
});
176+
}
151177
}
152178

153179
// Populate `transitions` with mappings from `final_states` to `eos_token_id`
@@ -290,7 +316,7 @@ mod tests {
290316
.expect("Insert failed");
291317
}
292318
for (token, token_id) in [
293-
(vec![32, 240, 159, 152], 7),
319+
(vec![32, 240, 159, 152, 136], 7),
294320
(vec![32, 240, 159, 152, 141], 6),
295321
(vec![240, 159, 152, 141], 4),
296322
] {
@@ -309,10 +335,60 @@ mod tests {
309335
),
310336
(
311337
80,
312-
HashMap::from_iter([(2, 128), (7, 192), (5, 208), (6, 208)]),
338+
HashMap::from_iter([(2, 128), (7, 208), (5, 208), (6, 208)]),
313339
),
314340
(128, HashMap::from_iter([(8, 128)])),
315341
]);
316342
assert_eq!(index.transitions(), &expected);
317343
}
344+
345+
// Checks that `Index::new` rejects a vocabulary that cannot complete the regex "0 1".
// The vocabulary holds only the tokens "0", "0 " and "1"; the test expects an
// `Error::IncompatibleVocabulary` whose `missing_tokens` includes a single space.
// NOTE(review): presumably the failing state is the one reached after the token "0",
// where no token starts with the required " " — confirm against the `Index::new` loop.
#[test]
346+
fn index_incompatible_vocabulary_error() {
347+
let regex = "0 1";
348+
let mut vocabulary = Vocabulary::new(3);
349+
for (token, token_id) in [("0", 0), ("0 ", 1), ("1", 2)] {
350+
vocabulary
351+
.try_insert(token, token_id as u32)
352+
.expect("Insert failed");
353+
}
354+
355+
let result = Index::new(regex, &vocabulary);
356+
// Building the index must fail rather than silently produce a dead-end state.
assert!(result.is_err());
357+
358+
// Only the `missing_tokens` payload is inspected; the other fields are ignored.
if let Err(Error::IncompatibleVocabulary {
359+
regex: _,
360+
missing_tokens,
361+
..
362+
}) = result
363+
{
364+
// The space is an ASCII byte, so it is expected to be reported as the literal character.
assert!(missing_tokens.contains(&" ".to_string()));
365+
} else {
366+
// Any other error variant counts as a test failure.
panic!("Expected IncompatibleVocabulary error");
367+
}
368+
}
369+
370+
// Checks the incompatible-vocabulary error for a regex made of multi-byte characters.
// The regex is "😈😍" but the vocabulary only offers "😈", " " and "b", so nothing can
// continue the match after "😈". The test expects `missing_tokens` to contain the
// escaped byte string "\xf0".
// NOTE(review): 0xF0 is the first UTF-8 byte of "😍" (F0 9F 98 8D); presumably that is
// the byte the DFA needs next — confirm against the `Index::new` byte scan.
#[test]
371+
fn index_incompatible_vocabulary_error_non_ascii() {
372+
let regex = "😈😍";
373+
let mut vocabulary = Vocabulary::new(3);
374+
for (token, token_id) in [("😈", 0), (" ", 1), ("b", 2)] {
375+
vocabulary
376+
.try_insert(token, token_id as u32)
377+
.expect("Insert failed");
378+
}
379+
380+
let result = Index::new(regex, &vocabulary);
381+
// Building the index must fail because the second emoji can never be produced.
assert!(result.is_err());
382+
383+
// Only the `missing_tokens` payload is inspected; the other fields are ignored.
if let Err(Error::IncompatibleVocabulary {
384+
regex: _,
385+
missing_tokens,
386+
..
387+
}) = result
388+
{
389+
// The expected entry is the four-character text `\xf0` (backslash, 'x', 'f', '0'),
// not the raw byte 0xF0.
assert!(missing_tokens.contains(&"\\xf0".to_string()));
390+
} else {
391+
// Any other error variant counts as a test failure.
panic!("Expected IncompatibleVocabulary error");
392+
}
393+
}
318394
}

0 commit comments

Comments
 (0)