diff --git a/README.md b/README.md
index a734002..2fff27e 100644
--- a/README.md
+++ b/README.md
@@ -94,6 +94,31 @@ print(generated_data)
 }
 ```
 
+### Example Enum
+```python
+color = {
+  "type": "object",
+  "properties": {
+    "color": {
+      "type": "enum",
+      "values": [
+        "black",
+        "red",
+        "white",
+        "green",
+        "blue"
+      ]
+    }
+  }
+}
+```
+
+```python
+{
+  color: "blue"
+}
+```
+
 ## Features
 
 - Bulletproof JSON generation: Jsonformer ensures that the generated JSON is always syntactically correct and conforms to the specified schema.
diff --git a/jsonformer/logits_processors.py b/jsonformer/logits_processors.py
index db288d3..518e785 100644
--- a/jsonformer/logits_processors.py
+++ b/jsonformer/logits_processors.py
@@ -60,6 +60,30 @@ def __call__(
             return True
 
         return False
+
+
+class EnumStoppingCriteria(StoppingCriteria):
+    def __init__(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        prompt_length: int,
+        enums
+    ):
+        self.tokenizer = tokenizer
+        self.prompt_length = prompt_length
+        self.enums = enums
+
+    def __call__(
+        self,
+        input_ids: torch.LongTensor,
+        scores: torch.FloatTensor,
+    ) -> bool:
+        decoded = self.tokenizer.decode(
+            input_ids[0][self.prompt_length :], skip_special_tokens=True
+        )
+
+        return decoded in self.enums
+
 
 class OutputNumbersTokens(LogitsWarper):
     def __init__(self, tokenizer: PreTrainedTokenizer, prompt: str):
@@ -82,3 +106,48 @@ def __call__(self, _, scores):
         scores[~mask] = -float("inf")
 
         return scores
+
+
+class OutputEnumTokens(LogitsWarper):
+    # Constrains sampling to token sequences that spell one of the enum values,
+    # walking a prefix tree of the tokenized values one level per generated token.
+    def __init__(self, tokenizer: PreTrainedTokenizer, enums):
+        self.tokenizer = tokenizer
+        self.tree = self.build_tree(enums)
+        self.is_first_call = True
+        self.vocab_size = len(tokenizer)
+
+    def create_mask(self, allowed_tokens):
+        # Boolean vocab-sized mask, True only at the currently allowed token ids.
+        allowed_mask = torch.zeros(self.vocab_size, dtype=torch.bool)
+        for token_id in allowed_tokens:
+            allowed_mask[token_id] = True
+        return allowed_mask
+
+    def build_tree(self, enums):
+        # Prefix tree: each root-to-leaf path spells the token ids of one enum value.
+        tree = {}
+        for enum in enums:
+            encoded_enum = self.tokenizer.encode(enum, add_special_tokens=False)
+            curr_obj = tree
+            for code in encoded_enum:
+                curr_obj = curr_obj.setdefault(code, {})
+        return tree
+
+    def __call__(self, input_ids, scores):
+        # Descend one level per previously generated token; the first call scores the root.
+        if not self.is_first_call:
+            self.tree = self.tree[int(input_ids[0][-1])]
+        else:
+            self.is_first_call = False
+
+        allowed_tokens = self.tree.keys()
+
+        if not len(allowed_tokens):
+            raise Exception("Shouldn't happen")
+
+        allowed_mask = self.create_mask(allowed_tokens)
+        mask = allowed_mask.expand_as(scores)
+        scores[~mask] = -float("inf")
+        return scores
+
diff --git a/jsonformer/main.py b/jsonformer/main.py
index 9c13471..4785c15 100644
--- a/jsonformer/main.py
+++ b/jsonformer/main.py
@@ -4,6 +4,8 @@
     NumberStoppingCriteria,
     OutputNumbersTokens,
     StringStoppingCriteria,
+    EnumStoppingCriteria,
+    OutputEnumTokens,
 )
 from termcolor import cprint
 from transformers import PreTrainedModel, PreTrainedTokenizer
@@ -138,6 +140,49 @@ def generate_string(self) -> str:
             return response
 
         return response.split('"')[0].strip()
+
+    def generate_enum(self, values) -> str:
+        prompt = self.get_prompt()
+        self.debug("[generate_enum]", prompt, is_prompt=True)
+        input_tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(
+            self.model.device
+        )
+        # JSON-encode the candidates: strings keep their quotes, numbers do not.
+        values = [f'"{value}"' if isinstance(value, str) else str(value) for value in values]
+
+        response = self.model.generate(
+            input_tokens,
+            max_new_tokens=max(len(self.tokenizer.encode(value, add_special_tokens=False)) for value in values),
+            num_return_sequences=1,
+            temperature=self.temperature,
+            logits_processor=[OutputEnumTokens(self.tokenizer, values)],
+            stopping_criteria=[
+                EnumStoppingCriteria(self.tokenizer, len(input_tokens[0]), values)
+            ],
+            pad_token_id=self.tokenizer.eos_token_id,
+        )
+
+        # Some models output the prompt as part of the response
+        # This removes the prompt from the response if it is present
+        if (
+            len(response[0]) >= len(input_tokens[0])
+            and (response[0][: len(input_tokens[0])] == input_tokens).all()
+        ):
+            response = response[0][len(input_tokens[0]) :]
+        if response.shape[0] == 1:
+            response = response[0]
+
+        response = self.tokenizer.decode(response, skip_special_tokens=True)
+
+        self.debug("[generate_enum]", "|" + response + "|")
+
+        if response[0] == response[-1] == '"':
+            return response[1:-1]
+
+        # Unquoted responses are numeric enum values (NOTE(review): bools would raise here).
+        if '.' in response:
+            return float(response)
+        return int(response)
 
     def generate_object(
         self, properties: Dict[str, Any], obj: Dict[str, Any]
@@ -146,6 +191,7 @@ def generate_object(
             self.debug("[generate_object] generating value for", key)
             obj[key] = self.generate_value(schema, obj, key)
         return obj
+
 
     def generate_value(
         self,
@@ -183,6 +229,12 @@ def generate_value(
             else:
                 obj.append(new_obj)
             return self.generate_object(schema["properties"], new_obj)
+        elif schema_type == "enum":
+            if key:
+                obj[key] = self.generation_marker
+            else:
+                obj.append(self.generation_marker)
+            return self.generate_enum(schema["values"])
         else:
             raise ValueError(f"Unsupported schema type: {schema_type}")
 