
Commit 326eb75

Abdennacer-Badaoui authored and remi-or committed
Fix T5 tests: use generation_config for generation parameters (#42419)
* pass the generation parameters to generate()
* fix use_task_specific_params to separate model.config and model.generation_config params
* fix style
* some fixes
* remove redundant check
* update expectation for llama_7b_bf16 on rocm
* Update tests/models/llama/test_modeling_llama.py

Co-authored-by: Rémi Ouazan <83456801+remi-or@users.noreply.github.com>
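At its core, the fix reflects that transformers treats generation settings (max_length, num_beams, do_sample, ...) as belonging on model.generation_config rather than model.config, so tests should set them there or pass them straight to generate(). A minimal sketch of the key-splitting idea, assuming only that GenerationConfig().to_dict() enumerates the recognized generation keys; the params dict is illustrative, not from the repo:

from transformers import GenerationConfig

# Every key that GenerationConfig knows about is a generation parameter;
# anything else belongs on the model config.
generation_keys = set(GenerationConfig().to_dict().keys())

# Hypothetical task-specific params, for illustration only.
params = {"max_length": 8, "num_beams": 1, "prefix": "summarize: "}

gen_params = {k: v for k, v in params.items() if k in generation_keys}
model_params = {k: v for k, v in params.items() if k not in generation_keys}
# gen_params   -> {"max_length": 8, "num_beams": 1}
# model_params -> {"prefix": "summarize: "}

The patched use_task_specific_params helper in the T5 diff below applies exactly this split via setattr.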
1 parent ec3f555 commit 326eb75

File tree

2 files changed: +14 −6 lines changed


tests/models/llama/test_modeling_llama.py

Lines changed: 1 addition & 1 deletion
@@ -117,7 +117,7 @@ def test_model_7b_logits_bf16(self):
             ("xpu", 3): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]]),
             ("cuda", 7): torch.tensor([[-6.5061, -4.1147, -4.9669, -3.2038, 0.8069, -2.9694, 1.2864, -3.3786]]),
             ("cuda", 8): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]]),
-            ("rocm", (9, 4)): torch.tensor([[-6.5094, -4.1329, -4.9754, -3.5042, 0.8082, -2.9443, 1.2830, -3.3539]]),
+            ("rocm", (9, 4)): torch.tensor([[-6.5067, -4.1154, -4.9819, -3.1408, 0.8117, -2.9435, 1.2883, -3.3221]]),
         })
 
         expected_mean = expected_means.get_expectation().to(torch_device)

tests/models/t5/test_modeling_t5.py

Lines changed: 13 additions & 5 deletions
@@ -47,6 +47,7 @@
 from transformers import (
     AutoTokenizer,
     ByT5Tokenizer,
+    GenerationConfig,
     T5EncoderModel,
     T5ForConditionalGeneration,
     T5ForQuestionAnswering,
@@ -932,7 +933,17 @@ def is_pipeline_test_to_skip(
 
 
 def use_task_specific_params(model, task):
-    model.config.update(model.config.task_specific_params[task])
+    task_params = model.config.task_specific_params[task]
+
+    # Get all valid GenerationConfig attributes
+    temp_config = GenerationConfig()
+    generation_config_attrs = set(temp_config.to_dict().keys())
+
+    for key, value in task_params.items():
+        if key in generation_config_attrs:
+            setattr(model.generation_config, key, value)
+        else:
+            setattr(model.config, key, value)
 
 
 @require_torch
@@ -1032,14 +1043,11 @@ def test_torch_quant(self):
     @slow
     def test_small_generation(self):
         model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small").to(torch_device)
-        model.config.max_length = 8
-        model.config.num_beams = 1
-        model.config.do_sample = False
         tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
 
         input_ids = tokenizer("summarize: Hello there", return_tensors="pt").input_ids.to(torch_device)
 
-        sequences = model.generate(input_ids)
+        sequences = model.generate(input_ids, max_length=8, num_beams=1, do_sample=False)
 
         output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
         self.assertTrue(output_str == "Hello there!")
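The same overrides could also be bundled into an explicit GenerationConfig, which generate() accepts and merges with any remaining keyword arguments; a sketch of that equivalent pattern:

from transformers import GenerationConfig

# Equivalent to generate(input_ids, max_length=8, num_beams=1, do_sample=False):
gen_cfg = GenerationConfig(max_length=8, num_beams=1, do_sample=False)
sequences = model.generate(input_ids, generation_config=gen_cfg)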
