diff --git a/rellm/rellm.py b/rellm/rellm.py index 8668515..1842495 100644 --- a/rellm/rellm.py +++ b/rellm/rellm.py @@ -40,16 +40,22 @@ def complete_re(prompt:str, pattern: regex.Pattern | List[regex.Pattern], tokeni ) new_token_ids = output_ids[0, prompt_length:] output_text = tokenizer.decode(new_token_ids, skip_special_tokens=True) - partial_completion += output_text - prompt_plus_completion = prompt_plus_completion + output_text + + for output_char in output_text: + partial_completion += output_char + prompt_plus_completion = prompt_plus_completion + output_char + + if stop_after_match: + for p in pattern: + m = p.match(partial_completion) + if m and m.start() == 0: + if debug: + print("step={} completion={}".format(gen_tokens, partial_completion)) + return m[0] + if debug: print("step={} completion={}".format(gen_tokens, partial_completion)) - if stop_after_match: - for p in pattern: - m = p.match(partial_completion) - if m and m.start() == 0: - return m[0] gen_tokens += 1 return partial_completion