Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions jsonformer/logits_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from transformers import PreTrainedTokenizer, LogitsWarper, StoppingCriteria
import torch


class StringStoppingCriteria(StoppingCriteria):
def __init__(self, tokenizer: PreTrainedTokenizer, prompt_length: int):
self.tokenizer = tokenizer
Expand Down Expand Up @@ -61,6 +62,7 @@ def __call__(

return False


class OutputNumbersTokens(LogitsWarper):
def __init__(self, tokenizer: PreTrainedTokenizer, prompt: str):
self.tokenizer = tokenizer
Expand All @@ -82,3 +84,23 @@ def __call__(self, _, scores):
scores[~mask] = -float("inf")

return scores


class OutputCommaAndBracketTokens(LogitsWarper):
    """Logits warper that only allows "," or "]" to be generated.

    Used when deciding whether a JSON array being generated should
    continue (next token ",") or terminate (next token "]"): every
    other vocabulary entry has its score forced to -inf.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, prompt: str):
        """Precompute a boolean vocabulary mask of the allowed tokens.

        Args:
            tokenizer: tokenizer whose vocabulary is scanned for tokens
                that decode to "," or "]".
            prompt: prompt text; tokenized and stored for parity with
                OutputNumbersTokens (not read by this class itself).
        """
        self.tokenizer = tokenizer
        # NOTE(review): stored but never used in this class; kept so the
        # attribute surface matches OutputNumbersTokens.
        self.tokenized_prompt = tokenizer(prompt, return_tensors="pt")
        vocab_size = len(tokenizer)
        # allowed_mask[token_id] is True iff the token decodes to "," or "]".
        self.allowed_mask = torch.zeros(vocab_size, dtype=torch.bool)

        # Iterate token ids directly (keys of get_vocab() are unused).
        for token_id in tokenizer.get_vocab().values():
            # strip() so whitespace variants such as " ," or "] " qualify.
            token_str = tokenizer.decode(token_id).strip()

            if token_str in (",", "]"):
                self.allowed_mask[token_id] = True

    def __call__(self, _, scores):
        """Mask `scores` so only comma/closing-bracket tokens survive.

        The mask is moved to `scores.device` first: the model may run on
        GPU while the mask was built on CPU, and boolean indexing with a
        tensor on a different device raises a RuntimeError.
        """
        mask = self.allowed_mask.to(scores.device).expand_as(scores)
        scores[~mask] = -float("inf")

        return scores
42 changes: 20 additions & 22 deletions jsonformer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
NumberStoppingCriteria,
OutputNumbersTokens,
StringStoppingCriteria,
OutputCommaAndBracketTokens,
)
from termcolor import cprint
from transformers import PreTrainedModel, PreTrainedTokenizer
Expand Down Expand Up @@ -34,6 +35,9 @@ def __init__(
self.prompt = prompt

self.number_logit_processor = OutputNumbersTokens(self.tokenizer, self.prompt)
self.array_end_logit_processor = OutputCommaAndBracketTokens(
self.tokenizer, self.prompt
)

self.generation_marker = "|GENERATION|"
self.debug_on = debug
Expand Down Expand Up @@ -80,7 +84,9 @@ def generate_number(self, temperature: Union[float, None] = None, iterations=0):
if iterations > 3:
raise ValueError("Failed to generate a valid number")

return self.generate_number(temperature=self.temperature * 1.3, iterations=iterations+1)
return self.generate_number(
temperature=self.temperature * 1.3, iterations=iterations + 1
)

def generate_boolean(self) -> bool:
prompt = self.get_prompt()
Expand Down Expand Up @@ -195,27 +201,19 @@ def generate_array(self, item_schema: Dict[str, Any], obj: Dict[str, Any]) -> li
obj.append(self.generation_marker)
input_prompt = self.get_prompt()
obj.pop()
input_tensor = self.tokenizer.encode(input_prompt, return_tensors="pt")
output = self.model.forward(input_tensor.to(self.model.device))
logits = output.logits[0, -1]


top_indices = logits.topk(30).indices
sorted_token_ids = top_indices[logits[top_indices].argsort(descending=True)]

found_comma = False
found_close_bracket = False

for token_id in sorted_token_ids:
decoded_token = self.tokenizer.decode(token_id)
if ',' in decoded_token:
found_comma = True
break
if ']' in decoded_token:
found_close_bracket = True
break

if found_close_bracket or not found_comma:
input_tokens = self.tokenizer.encode(input_prompt, return_tensors="pt").to(
self.model.device
)
response = self.model.generate(
input_tokens,
max_new_tokens=1,
num_return_sequences=1,
logits_processor=[self.array_end_logit_processor],
pad_token_id=self.tokenizer.eos_token_id,
)
last_token = self.tokenizer.decode(response[0][-1])

if "]" in last_token:
break

return obj
Expand Down