From cb67a7a172649b49c3166628d8e623c8d6e6f870 Mon Sep 17 00:00:00 2001
From: botka1998
Date: Sat, 27 Jan 2024 19:07:57 +0100
Subject: [PATCH] fix array generation stopping criteria

---
 jsonformer/logits_processors.py | 22 +++++++++++++++++
 jsonformer/main.py              | 42 ++++++++++++++++-----------------
 2 files changed, 42 insertions(+), 22 deletions(-)

diff --git a/jsonformer/logits_processors.py b/jsonformer/logits_processors.py
index db288d3..d1ff0bb 100644
--- a/jsonformer/logits_processors.py
+++ b/jsonformer/logits_processors.py
@@ -2,6 +2,7 @@
 from transformers import PreTrainedTokenizer, LogitsWarper, StoppingCriteria
 import torch
 
+
 class StringStoppingCriteria(StoppingCriteria):
     def __init__(self, tokenizer: PreTrainedTokenizer, prompt_length: int):
         self.tokenizer = tokenizer
@@ -61,6 +62,7 @@ def __call__(
 
         return False
 
+
 class OutputNumbersTokens(LogitsWarper):
     def __init__(self, tokenizer: PreTrainedTokenizer, prompt: str):
         self.tokenizer = tokenizer
@@ -82,3 +84,23 @@ def __call__(self, _, scores):
         scores[~mask] = -float("inf")
 
         return scores
+
+
+class OutputCommaAndBracketTokens(LogitsWarper):
+    def __init__(self, tokenizer: PreTrainedTokenizer, prompt: str):
+        self.tokenizer = tokenizer
+        self.tokenized_prompt = tokenizer(prompt, return_tensors="pt")
+        vocab_size = len(tokenizer)
+        self.allowed_mask = torch.zeros(vocab_size, dtype=torch.bool)
+
+        for _, token_id in tokenizer.get_vocab().items():
+            token_str = tokenizer.decode(token_id).strip()
+
+            if token_str in [",", "]"]:
+                self.allowed_mask[token_id] = True
+
+    def __call__(self, _, scores):
+        mask = self.allowed_mask.expand_as(scores)
+        scores[~mask] = -float("inf")
+
+        return scores
diff --git a/jsonformer/main.py b/jsonformer/main.py
index 9c13471..97d3d02 100644
--- a/jsonformer/main.py
+++ b/jsonformer/main.py
@@ -4,6 +4,7 @@
     NumberStoppingCriteria,
     OutputNumbersTokens,
     StringStoppingCriteria,
+    OutputCommaAndBracketTokens,
 )
 from termcolor import cprint
 from transformers import PreTrainedModel, PreTrainedTokenizer
@@ -34,6 +35,9 @@ def __init__(
         self.prompt = prompt
 
         self.number_logit_processor = OutputNumbersTokens(self.tokenizer, self.prompt)
+        self.array_end_logit_processor = OutputCommaAndBracketTokens(
+            self.tokenizer, self.prompt
+        )
 
         self.generation_marker = "|GENERATION|"
         self.debug_on = debug
@@ -80,7 +84,9 @@ def generate_number(self, temperature: Union[float, None] = None, iterations=0):
         if iterations > 3:
             raise ValueError("Failed to generate a valid number")
 
-        return self.generate_number(temperature=self.temperature * 1.3, iterations=iterations+1)
+        return self.generate_number(
+            temperature=self.temperature * 1.3, iterations=iterations + 1
+        )
 
     def generate_boolean(self) -> bool:
         prompt = self.get_prompt()
@@ -195,27 +201,19 @@ def generate_array(self, item_schema: Dict[str, Any], obj: Dict[str, Any]) -> li
             obj.append(self.generation_marker)
             input_prompt = self.get_prompt()
             obj.pop()
-            input_tensor = self.tokenizer.encode(input_prompt, return_tensors="pt")
-            output = self.model.forward(input_tensor.to(self.model.device))
-            logits = output.logits[0, -1]
-
-
-            top_indices = logits.topk(30).indices
-            sorted_token_ids = top_indices[logits[top_indices].argsort(descending=True)]
-
-            found_comma = False
-            found_close_bracket = False
-
-            for token_id in sorted_token_ids:
-                decoded_token = self.tokenizer.decode(token_id)
-                if ',' in decoded_token:
-                    found_comma = True
-                    break
-                if ']' in decoded_token:
-                    found_close_bracket = True
-                    break
-
-            if found_close_bracket or not found_comma:
+            input_tokens = self.tokenizer.encode(input_prompt, return_tensors="pt").to(
+                self.model.device
+            )
+            response = self.model.generate(
+                input_tokens,
+                max_new_tokens=1,
+                num_return_sequences=1,
+                logits_processor=[self.array_end_logit_processor],
+                pad_token_id=self.tokenizer.eos_token_id,
+            )
+            last_token = self.tokenizer.decode(response[0][-1])
+
+            if "]" in last_token:
                 break
 
         return obj
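
A minimal sketch of the new behavior, for reviewers who want to exercise the
processor in isolation. The removed heuristic scanned only the top 30
unconstrained logits and broke out of the loop whenever neither "," nor "]"
surfaced there, so arrays could terminate early; masking the vocabulary down
to exactly those two tokens guarantees a continue-or-close decision on every
iteration. This snippet is not part of the patch, and the "gpt2" checkpoint
and example prompt are placeholders; any causal LM from transformers should
behave the same way:

from transformers import AutoModelForCausalLM, AutoTokenizer
from jsonformer.logits_processors import OutputCommaAndBracketTokens

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Prompt ends mid-array, exactly where generate_array() asks the model
# whether to continue (",") or close ("]") the JSON array.
prompt = '{"pets": ["cat"'
processor = OutputCommaAndBracketTokens(tokenizer, prompt)

input_ids = tokenizer.encode(prompt, return_tensors="pt")
response = model.generate(
    input_ids,
    max_new_tokens=1,
    logits_processor=[processor],
    pad_token_id=tokenizer.eos_token_id,
)

# Every logit except tokens decoding to "," or "]" was set to -inf, so the
# single new token is guaranteed to be one of the two.
print(tokenizer.decode(response[0][-1]))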