From c6f53b6a973dc4749b2dd1dbd5d24731b2dd2a79 Mon Sep 17 00:00:00 2001
From: Pramodith Ballapuram <16939722+pramodith@users.noreply.github.com>
Date: Fri, 26 Sep 2025 16:47:57 +0000
Subject: [PATCH 1/4] Add num_items_in_batch computation to prediction_step.

---
 src/transformers/trainer.py   | 42 +++++++++++++++++++++++------------
 tests/trainer/test_trainer.py | 34 ++++++++++++++++++++++++++++
 2 files changed, 62 insertions(+), 14 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 76d36327b308..92175d2be03a 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -4897,7 +4897,10 @@ def prediction_step(
         else:
             if has_labels or loss_without_labels:
                 with self.compute_loss_context_manager():
-                    loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
+                    num_items_in_batch = self._get_num_items_in_batch([inputs], inputs["input_ids"].device)
+                    loss, outputs = self.compute_loss(
+                        model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
+                    )
                 loss = loss.detach().mean()
 
                 if isinstance(outputs, dict):
@@ -5586,21 +5589,16 @@ def _fsdp_qlora_plugin_updates(self):
                 self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
             )
 
-    def get_batch_samples(
-        self, epoch_iterator: Iterator, num_batches: int, device: torch.device
-    ) -> tuple[list, Optional[Union[torch.Tensor, int]]]:
+    def _get_num_items_in_batch(self, batch_samples: list, device: torch.device) -> int | None:
         """
-        Collects a specified number of batches from the epoch iterator and optionally counts the number of items in the batches to properly scale the loss.
+        Counts the number of items in the batches to properly scale the loss.
+        Args:
+            batch_samples (`list`): List of batches.
+            device (`torch.device`): The device on which the number of items in the batch should be placed.
+        Returns:
+            `None` if the number of items in the batch doesn't need to be computed, otherwise the number of items in the batch.
         """
-        batch_samples = []
         num_items_in_batch = None
-
-        for _ in range(num_batches):
-            try:
-                batch_samples.append(next(epoch_iterator))
-            except StopIteration:
-                break
-
         count_num_items_in_batch = (
             len(batch_samples) > 0
             and "labels" in batch_samples[0]
@@ -5615,7 +5613,6 @@ def get_batch_samples(
                 # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/trainer.py#L3790
             )
         )
-
         if count_num_items_in_batch:
             # For now we don't support object detection
             try:
@@ -5641,6 +5638,23 @@ def get_batch_samples(
             if pc := getattr(self.accelerator, "parallelism_config", None):
                 num_items_in_batch = num_items_in_batch // pc.non_data_parallel_size
 
+        return num_items_in_batch
+
+    def get_batch_samples(
+        self, epoch_iterator: Iterator, num_batches: int, device: torch.device
+    ) -> tuple[list, Optional[Union[torch.Tensor, int]]]:
+        """
+        Collects a specified number of batches from the epoch iterator and optionally counts the number of items in the batches to properly scale the loss.
+        """
+        batch_samples = []
+
+        for _ in range(num_batches):
+            try:
+                batch_samples.append(next(epoch_iterator))
+            except StopIteration:
+                break
+
+        num_items_in_batch = self._get_num_items_in_batch(batch_samples, device)
         return batch_samples, num_items_in_batch
 
     def set_initial_training_values(
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 5cce980a6a00..7b2a47f7f2b4 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -2957,6 +2957,40 @@ def test_predict(self):
         self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
         self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
 
+    def test_train_and_predict_loss_parity(self):
+        """
+        Tests that the loss computed during a training_step is the same as the one computed during prediction_step
+        for the same inputs.
+        """
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+        # Create a dummy batch of inputs
+        inputs = {}
+        inputs["input_ids"] = []
+        for row_ind in range(4):
+            seq_len = torch.randint(32, 64, (1,)).item()
+            x = torch.randint(1, 100, (seq_len,))
+            inputs["input_ids"].append(x)
+        inputs["input_ids"] = torch.nn.utils.rnn.pad_sequence(inputs["input_ids"], batch_first=True, padding_value=0)
+        inputs["labels"] = inputs["input_ids"].clone()
+        inputs["labels"][inputs["input_ids"] == 0] = -100
+        num_items_in_batch = inputs["labels"].ne(-100).sum().item()
+
+        def custom_loss_func(outputs, labels, num_items_in_batch=None):
+            logits = outputs["logits"]
+            loss_fct = torch.nn.CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
+            if num_items_in_batch is not None:
+                return loss / num_items_in_batch  # normalize by the number of items in the batch
+            return loss
+
+        trainer = Trainer(model, train_dataset=None, compute_loss_func=custom_loss_func)
+
+        # Run a training step and a prediction step on the same batch and compare the losses
+        train_loss = trainer.training_step(model, inputs, num_items_in_batch)
+        predict_loss = trainer.prediction_step(model, inputs, prediction_loss_only=True)[0]
+
+        torch.testing.assert_close(train_loss, predict_loss, atol=1e-6, rtol=0)
+
     def test_predict_with_batch_eval_metrics(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             trainer = get_regression_trainer(

From fd28064a655de3bc1f7f2f832f71f46cd55b4e1a Mon Sep 17 00:00:00 2001
From: Pramodith Ballapuram <16939722+pramodith@users.noreply.github.com>
Date: Mon, 29 Sep 2025 13:25:17 +0000
Subject: [PATCH 2/4] address comments.

---
 src/transformers/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 92175d2be03a..b1ade83e3e02 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -4897,7 +4897,7 @@ def prediction_step(
         else:
             if has_labels or loss_without_labels:
                 with self.compute_loss_context_manager():
-                    num_items_in_batch = self._get_num_items_in_batch([inputs], inputs["input_ids"].device)
+                    num_items_in_batch = self._get_num_items_in_batch([inputs], self.args.device)
                     loss, outputs = self.compute_loss(
                         model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
                     )

From aa238246a08c4b2a0b42a0a48f9d13a79aa0c0b4 Mon Sep 17 00:00:00 2001
From: Pramodith Ballapuram <16939722+pramodith@users.noreply.github.com>
Date: Mon, 29 Sep 2025 19:55:26 +0000
Subject: [PATCH 3/4] Fix test cases.

---
 tests/trainer/test_trainer.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 7b2a47f7f2b4..2d0f8eb6f074 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -2882,6 +2882,9 @@ def test_evaluate_with_jit(self):
             trainer = get_regression_trainer(
                 a=1.5, b=2.5, compute_metrics=AlmostAccuracy(), jit_mode_eval=True, output_dir=tmp_dir
             )
+            # Make sure the trainer doesn't pass num_items_in_batch to the model's forward method,
+            # since it's not in the model forward's signature when using JIT
+            trainer.model_accepts_loss_kwargs = False
             results = trainer.evaluate()
 
             x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
@@ -2895,6 +2898,7 @@ def test_evaluate_with_jit(self):
             trainer = get_regression_trainer(
                 a=1.5, b=2.5, eval_len=66, compute_metrics=AlmostAccuracy(), jit_mode_eval=True, output_dir=tmp_dir
             )
+            trainer.model_accepts_loss_kwargs = False
             results = trainer.evaluate()
 
             x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
@@ -2913,6 +2917,7 @@ def test_evaluate_with_jit(self):
                 jit_mode_eval=True,
                 output_dir=tmp_dir,
             )
+            trainer.model_accepts_loss_kwargs = False
             results = trainer.evaluate()
 
             x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
@@ -3058,18 +3063,23 @@ def test_predict_with_batch_eval_metrics(self):
     def test_predict_with_jit(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             trainer = get_regression_trainer(a=1.5, b=2.5, jit_mode_eval=True, output_dir=tmp_dir)
+            # Make sure the trainer doesn't pass num_items_in_batch to the model's forward method,
+            # since it's not in the model forward's signature when using JIT
+            trainer.model_accepts_loss_kwargs = False
             preds = trainer.predict(trainer.eval_dataset).predictions
             x = trainer.eval_dataset.x
             self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
 
             # With a number of elements not a round multiple of the batch size
             trainer = get_regression_trainer(a=1.5, b=2.5, eval_len=66, jit_mode_eval=True, output_dir=tmp_dir)
+            trainer.model_accepts_loss_kwargs = False
             preds = trainer.predict(trainer.eval_dataset).predictions
             x = trainer.eval_dataset.x
             self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
 
             # With more than one output of the model
             trainer = get_regression_trainer(a=1.5, b=2.5, double_output=True, jit_mode_eval=True, output_dir=tmp_dir)
+            trainer.model_accepts_loss_kwargs = False
             preds = trainer.predict(trainer.eval_dataset).predictions
             x = trainer.eval_dataset.x
             self.assertEqual(len(preds), 2)
@@ -3085,6 +3095,7 @@ def test_predict_with_jit(self):
                 jit_mode_eval=True,
                 output_dir=tmp_dir,
             )
+            trainer.model_accepts_loss_kwargs = False
             outputs = trainer.predict(trainer.eval_dataset)
             preds = outputs.predictions
             labels = outputs.label_ids

From 753d8d7b8bbccbe20c0cba79c4ada19ab2100f3f Mon Sep 17 00:00:00 2001
From: Pramodith Ballapuram <16939722+pramodith@users.noreply.github.com>
Date: Mon, 29 Sep 2025 19:58:57 +0000
Subject: [PATCH 4/4] fixup

---
 tests/trainer/test_trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 2d0f8eb6f074..afef6f23c987 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -2882,7 +2882,7 @@ def test_evaluate_with_jit(self):
             trainer = get_regression_trainer(
                 a=1.5, b=2.5, compute_metrics=AlmostAccuracy(), jit_mode_eval=True, output_dir=tmp_dir
             )
-            # Make sure the trainer doesn't pass num_items_in_batch to the model's forward method,
+            # Make sure the trainer doesn't pass num_items_in_batch to the model's forward method,
             # since it's not in the model forward's signature when using JIT
             trainer.model_accepts_loss_kwargs = False
             results = trainer.evaluate()
@@ -3063,7 +3063,7 @@ def test_predict_with_batch_eval_metrics(self):
     def test_predict_with_jit(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             trainer = get_regression_trainer(a=1.5, b=2.5, jit_mode_eval=True, output_dir=tmp_dir)
-            # Make sure the trainer doesn't pass num_items_in_batch to the model's forward method,
+            # Make sure the trainer doesn't pass num_items_in_batch to the model's forward method,
             preds = trainer.predict(trainer.eval_dataset).predictions
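
A minimal, self-contained sketch of the normalization idea these patches rely on: count the label tokens that are not the -100 ignore index ("num_items_in_batch") and divide a sum-reduced loss by that count, so a training step and a prediction step over the same inputs report the same loss. The helper names below (count_items_in_batch, normalized_loss) are illustrative only, not Trainer APIs; only the -100 ignore index and the labels.ne(-100).sum() counting mirror the patches.

    # Sketch only, not Trainer code.
    import torch

    def count_items_in_batch(labels: torch.Tensor, ignore_index: int = -100) -> int:
        # Number of label tokens that actually contribute to the loss (padding is -100).
        return int(labels.ne(ignore_index).sum().item())

    def normalized_loss(logits: torch.Tensor, labels: torch.Tensor, num_items_in_batch: int) -> torch.Tensor:
        # Sum-reduce first, then divide by a count computed over the whole batch; the
        # per-micro-batch values then add up to the full-batch value, so the scaling is
        # consistent between a single prediction step and accumulated training steps.
        loss_fct = torch.nn.CrossEntropyLoss(reduction="sum", ignore_index=-100)
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return loss / num_items_in_batch

    # Toy usage: two padded sequences over a 100-token vocabulary.
    labels = torch.tensor([[5, 7, -100], [3, -100, -100]])
    logits = torch.randn(2, 3, 100)
    n = count_items_in_batch(labels)             # 3 real tokens
    print(normalized_loss(logits, labels, n))    # per-token cross-entropy

The division in the test's custom_loss_func plays the same normalizing role: if prediction_step stopped forwarding num_items_in_batch, the two code paths would scale the loss differently and the parity assertion would fail.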