diff --git a/syncode/language_model.py b/syncode/language_model.py
index 8b687b8f..4b1607c2 100644
--- a/syncode/language_model.py
+++ b/syncode/language_model.py
@@ -3,7 +3,7 @@
 import torch
 import syncode.common as common
 from syncode.grammar_mask.logits_processor import SyncodeLogitsProcessor
-from transformers import LogitsProcessorList, StoppingCriteriaList, StoppingCriteria
+from transformers import LogitsProcessorList, StoppingCriteriaList, StoppingCriteria, PreTrainedModel
 from syncode.parsers.grammars import Grammar
 from syncode.utils.generation import filter_code, fix_indents
 from typing import Callable, Iterable, Union
@@ -48,7 +48,7 @@ def __init__(
 
         super().__init__()
         self.prompt_template = prompt_template
-        self.model = model
+        self.model: PreTrainedModel = model
         self.tokenizer = tokenizer
         self.device = device
         self.best_of = best_of
@@ -193,7 +193,9 @@ def _generate(
 
         # This does not include grammar decoder
         self.model._prepare_special_tokens(gen_config, False, device=self.device)
-        logits_processor = self.model._get_logits_processor(gen_config, token_ids.size(1), token_ids, prefix_allowed_tokens_fn=None, logits_processor=[])
+
+        # Add logits processor for generation parameters such as top_k, top_p, temperature, etc.
+        logits_processor = self.model._get_logits_warper(gen_config, self.device)
         max_tokens = self.gen_args['max_new_tokens']+token_ids.size(1)
         self.model.config.pad_token_id = pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
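
For context, the `LogitsProcessorList` returned by `_get_logits_warper` contains only the sampling warpers derived from `gen_config`; the grammar decoder is applied separately, as the comment in the patch notes. Below is a minimal sketch of what such a warper list amounts to, built from the public `transformers` classes. The temperature/top-k/top-p values and the vocab size are illustrative assumptions, not values from this patch:

```python
# Illustrative sketch only (not part of the patch): what the warper list returned by
# `_get_logits_warper` roughly amounts to, built from public transformers classes.
import torch
from transformers import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

# Hypothetical sampling parameters; in the patched code these come from `gen_config`.
warpers = LogitsProcessorList([
    TemperatureLogitsWarper(0.7),
    TopKLogitsWarper(top_k=50),
    TopPLogitsWarper(top_p=0.95),
])

input_ids = torch.tensor([[1, 2, 3]])  # running token ids (batch of 1)
logits = torch.randn(1, 32000)         # next-token logits (vocab size is illustrative)
warped = warpers(input_ids, logits)    # temperature, then top-k, then top-p
probs = torch.softmax(warped, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
```

One likely reason for the simpler call site: `_get_logits_processor` needs the prompt and its length to build constraint processors, whereas the warper list depends only on the generation config, so the new call passes just `gen_config` and the device.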