From 4c6b8cb722eefd073e2e1025b5f2141b4b995041 Mon Sep 17 00:00:00 2001 From: mattiasarro Date: Tue, 30 May 2023 22:10:11 +0300 Subject: [PATCH 1/2] ReTokenFilter.is_valid_token: partial=False Matching a regex partially can lead to generating a token which causes the whole generated sequence to be invalid. --- rellm/re_token_filter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rellm/re_token_filter.py b/rellm/re_token_filter.py index ba48c39..2dd525b 100644 --- a/rellm/re_token_filter.py +++ b/rellm/re_token_filter.py @@ -16,7 +16,7 @@ def build_decoded_tokens_cache(tokenizer: PreTrainedTokenizer) -> Dict[int, str] def is_valid_token(self, token_id: int, partial_completion: str, patterns: List[regex.Pattern]) -> bool: decoded_token = self.decoded_tokens_cache[token_id] - return any(pattern.fullmatch(partial_completion + decoded_token, partial=True) for pattern in patterns) + return any(pattern.fullmatch(partial_completion + decoded_token) for pattern in patterns) def filter_tokens(self, partial_completion: str, patterns: Union[regex.Pattern, List[regex.Pattern]]) -> Set[int]: if isinstance(patterns, regex.Pattern): From a9da7529fcd952b5ca8fab53548abe775c165b21 Mon Sep 17 00:00:00 2001 From: mattiasarro Date: Wed, 31 May 2023 10:39:13 +0300 Subject: [PATCH 2/2] fix handling of partial tokens When using partial=True, we ensure we don't generate invalid output, but also this makes it impossible to generate certain output sequences. Therefore it's necessary to allow generating tokens which match only partially, and then take the substring of that token which matches the regex. --- rellm/re_token_filter.py | 2 +- rellm/rellm.py | 20 +++++++++++++------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/rellm/re_token_filter.py b/rellm/re_token_filter.py index 2dd525b..ba48c39 100644 --- a/rellm/re_token_filter.py +++ b/rellm/re_token_filter.py @@ -16,7 +16,7 @@ def build_decoded_tokens_cache(tokenizer: PreTrainedTokenizer) -> Dict[int, str] def is_valid_token(self, token_id: int, partial_completion: str, patterns: List[regex.Pattern]) -> bool: decoded_token = self.decoded_tokens_cache[token_id] - return any(pattern.fullmatch(partial_completion + decoded_token) for pattern in patterns) + return any(pattern.fullmatch(partial_completion + decoded_token, partial=True) for pattern in patterns) def filter_tokens(self, partial_completion: str, patterns: Union[regex.Pattern, List[regex.Pattern]]) -> Set[int]: if isinstance(patterns, regex.Pattern): diff --git a/rellm/rellm.py b/rellm/rellm.py index 8668515..1842495 100644 --- a/rellm/rellm.py +++ b/rellm/rellm.py @@ -40,16 +40,22 @@ def complete_re(prompt:str, pattern: regex.Pattern | List[regex.Pattern], tokeni ) new_token_ids = output_ids[0, prompt_length:] output_text = tokenizer.decode(new_token_ids, skip_special_tokens=True) - partial_completion += output_text - prompt_plus_completion = prompt_plus_completion + output_text + + for output_char in output_text: + partial_completion += output_char + prompt_plus_completion = prompt_plus_completion + output_char + + if stop_after_match: + for p in pattern: + m = p.match(partial_completion) + if m and m.start() == 0: + if debug: + print("step={} completion={}".format(gen_tokens, partial_completion)) + return m[0] + if debug: print("step={} completion={}".format(gen_tokens, partial_completion)) - if stop_after_match: - for p in pattern: - m = p.match(partial_completion) - if m and m.start() == 0: - return m[0] gen_tokens += 1 return partial_completion