diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index f738ab9bb..5e43a781f 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -24,6 +24,7 @@
 )
 
 import jinja2
+from jinja2.ext import Extension
 from jinja2.sandbox import ImmutableSandboxedEnvironment
 
 import numpy as np
@@ -190,6 +191,27 @@ def __call__(
         **kwargs: Any,
     ) -> ChatFormatterResponse: ...
 
 
+class GenerationTagIgnore(Extension):
+    """Ignore the ``generation`` and ``endgeneration`` tags in Jinja templates.
+
+    Each tag is registered as a standalone tag: when one is encountered it is
+    consumed and contributes nothing to the template, while the text between
+    the two tags is parsed like any other template content.
+    """
+
+    tags = {"generation", "endgeneration"}
+
+    def parse(self, parser):
+        """Consume the matched tag name and return no nodes for it."""
+        # The stream is positioned on the tag's name token; step past it so
+        # the parser finds the closing block delimiter next.
+        parser.stream.skip(1)
+        # Emit nothing: an empty list adds no statements to the template body.
+        # A bare expression node such as Const is not a statement, so the
+        # compiler may write it onto the previous line of generated code.
+        return []
+
+
 class Jinja2ChatFormatter(ChatFormatter):
     def __init__(
@@ -213,6 +235,7 @@ def __init__(
             loader=jinja2.BaseLoader(),
             trim_blocks=True,
             lstrip_blocks=True,
+            extensions=[GenerationTagIgnore],
         ).from_string(self.template)
 
     @staticmethod