syncode/language_model.py (8 changes: 5 additions & 3 deletions)
@@ -3,7 +3,7 @@
 import torch
 import syncode.common as common
 from syncode.grammar_mask.logits_processor import SyncodeLogitsProcessor
-from transformers import LogitsProcessorList, StoppingCriteriaList, StoppingCriteria
+from transformers import LogitsProcessorList, StoppingCriteriaList, StoppingCriteria, PreTrainedModel
 from syncode.parsers.grammars import Grammar
 from syncode.utils.generation import filter_code, fix_indents
 from typing import Callable, Iterable, Union
@@ -48,7 +48,7 @@ def __init__(
         super().__init__()
 
         self.prompt_template = prompt_template
-        self.model = model
+        self.model: PreTrainedModel = model
         self.tokenizer = tokenizer
         self.device = device
         self.best_of = best_of
@@ -193,7 +193,9 @@ def _generate(
 
         # This does not include grammar decoder
         self.model._prepare_special_tokens(gen_config, False, device=self.device)
-        logits_processor = self.model._get_logits_processor(gen_config, token_ids.size(1), token_ids, prefix_allowed_tokens_fn=None, logits_processor=[])
+
+        # Add logits processor for generation parameters such as top_k, top_p, temperature, etc.
+        logits_processor = self.model._get_logits_warper(gen_config, self.device)
 
         max_tokens = self.gen_args['max_new_tokens']+token_ids.size(1)
         self.model.config.pad_token_id = pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
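Note on the change: swapping `_get_logits_processor` for `_get_logits_warper` means the returned `LogitsProcessorList` applies the sampling parameters from the generation config (temperature, top_k, top_p) to the next-token logits before sampling, separately from the grammar decoder. Both helpers are private `transformers` methods whose names and signatures vary across library versions, so the exact call in the diff is version-dependent. The sketch below is not SynCode's code; it illustrates what such a warper list does in one manual decoding step using only the public warper classes, and the helper name `build_sampling_warpers` plus the concrete parameter values are illustrative assumptions.

# Minimal sketch of a sampling warper list, assuming the public transformers warper classes.
# The helper name and parameter values are hypothetical, not taken from SynCode.
import torch
from transformers import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

def build_sampling_warpers(temperature=0.8, top_k=50, top_p=0.95) -> LogitsProcessorList:
    # Mirrors the sampling parameters that would normally be read from a GenerationConfig.
    return LogitsProcessorList([
        TemperatureLogitsWarper(temperature),
        TopKLogitsWarper(top_k),
        TopPLogitsWarper(top_p),
    ])

# One manual decoding step: warp the raw next-token logits, then sample from the result.
warpers = build_sampling_warpers()
input_ids = torch.tensor([[1, 2, 3]])        # [batch, seq] token ids generated so far
next_token_logits = torch.randn(1, 32000)    # [batch, vocab] raw logits from the model
warped_logits = warpers(input_ids, next_token_logits)
probs = torch.softmax(warped_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)

In SynCode's loop the warped logits would additionally pass through the grammar-mask logits processor, which the in-diff comment notes is intentionally not included here.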