42 changes: 28 additions & 14 deletions src/transformers/trainer.py
@@ -4897,7 +4897,10 @@ def prediction_step(
else:
if has_labels or loss_without_labels:
with self.compute_loss_context_manager():
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
num_items_in_batch = self._get_num_items_in_batch([inputs], self.args.device)
loss, outputs = self.compute_loss(
model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
)
loss = loss.detach().mean()

if isinstance(outputs, dict):
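This hunk makes prediction_step compute a token count for the single incoming batch and forward it to compute_loss, so a loss function that rescales by num_items_in_batch sees the same scaling at evaluation time as during training. Below is a minimal sketch of the scaling mismatch this removes, assuming a compute_loss_func-style callable; the names, shapes, and padding convention are illustrative, not Trainer internals:

```python
import torch

# A token-count-aware loss in the style of Trainer's compute_loss_func hook
# (an illustrative assumption, not the exact hook used by the Trainer).
def scaled_loss(logits, labels, num_items_in_batch=None):
    ce = torch.nn.CrossEntropyLoss(reduction="sum", ignore_index=-100)
    loss = ce(logits.view(-1, logits.size(-1)), labels.view(-1))
    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch  # train-style normalization
    return loss

logits = torch.randn(2, 6, 16)
labels = torch.randint(0, 16, (2, 6))
labels[:, -1] = -100  # padded positions are ignored

num_items = labels.ne(-100).sum()
with_count = scaled_loss(logits, labels, num_items_in_batch=num_items)
without_count = scaled_loss(logits, labels)  # old prediction_step behaviour: no count passed
print(with_count.item(), without_count.item())  # the un-normalized value is 10x larger (10 label tokens)
```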
@@ -5586,21 +5589,16 @@ def _fsdp_qlora_plugin_updates(self):
self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
)

def get_batch_samples(
self, epoch_iterator: Iterator, num_batches: int, device: torch.device
) -> tuple[list, Optional[Union[torch.Tensor, int]]]:
def _get_num_items_in_batch(self, batch_samples: list, device: torch.device) -> int | None:
"""
Collects a specified number of batches from the epoch iterator and optionally counts the number of items in the batches to properly scale the loss.
Counts the number of items in the batches to properly scale the loss.
Args:
batch_samples (`list`): List of batches
device (`torch.device`): The device on which the number of items in the batch should be placed.
Returns:
The number of items in the batch if it needs to be computed, otherwise `None`.
"""
batch_samples = []
num_items_in_batch = None

for _ in range(num_batches):
try:
batch_samples.append(next(epoch_iterator))
except StopIteration:
break

count_num_items_in_batch = (
len(batch_samples) > 0
and "labels" in batch_samples[0]
@@ -5615,7 +5613,6 @@ def get_batch_samples(
# https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/trainer.py#L3790
)
)

if count_num_items_in_batch:
# For now we don't support object detection
try:
@@ -5641,6 +5638,23 @@ def get_batch_samples(
if pc := getattr(self.accelerator, "parallelism_config", None):
num_items_in_batch = num_items_in_batch // pc.non_data_parallel_size

return num_items_in_batch

def get_batch_samples(
self, epoch_iterator: Iterator, num_batches: int, device: torch.device
) -> tuple[list, Optional[Union[torch.Tensor, int]]]:
"""
Collects a specified number of batches from the epoch iterator and optionally counts the number of items in the batches to properly scale the loss.
"""
batch_samples = []

for _ in range(num_batches):
try:
batch_samples.append(next(epoch_iterator))
except StopIteration:
break

num_items_in_batch = self._get_num_items_in_batch(batch_samples, device)
return batch_samples, num_items_in_batch

def set_initial_training_values(
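The counting logic that get_batch_samples used to do inline now lives in _get_num_items_in_batch, so prediction_step can reuse it by wrapping a single batch in a list (the `[inputs]` call above). A simplified, self-contained sketch of that counting step, assuming -100 as the ignore index; the real helper additionally handles object-detection inputs, Accelerate's average_tokens_across_devices, and parallelism_config scaling:

```python
import torch

def count_label_tokens(batch_samples, device, ignore_index=-100):
    # Nothing to count without labels; the caller then keeps its default scaling.
    if not batch_samples or "labels" not in batch_samples[0]:
        return None
    # Sum the non-ignored label positions across all collected batches.
    num_items = sum(batch["labels"].ne(ignore_index).sum() for batch in batch_samples)
    return num_items.to(device)

batch = {
    "input_ids": torch.randint(1, 100, (2, 8)),
    "labels": torch.tensor([[5, 7, -100, -100, 1, 2, 3, 4]] * 2),
}
print(count_label_tokens([batch], torch.device("cpu")))  # tensor(12)
```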
45 changes: 45 additions & 0 deletions tests/trainer/test_trainer.py
@@ -2882,6 +2882,9 @@ def test_evaluate_with_jit(self):
trainer = get_regression_trainer(
a=1.5, b=2.5, compute_metrics=AlmostAccuracy(), jit_mode_eval=True, output_dir=tmp_dir
)
# Make sure the trainer doesn't pass num_items_in_batch to the model's forward method,
# since it's not in the model forward's signature when using JIT
trainer.model_accepts_loss_kwargs = False
results = trainer.evaluate()

x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
@@ -2895,6 +2898,7 @@ def test_evaluate_with_jit(self):
trainer = get_regression_trainer(
a=1.5, b=2.5, eval_len=66, compute_metrics=AlmostAccuracy(), jit_mode_eval=True, output_dir=tmp_dir
)
trainer.model_accepts_loss_kwargs = False
results = trainer.evaluate()

x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
@@ -2913,6 +2917,7 @@ def test_evaluate_with_jit(self):
jit_mode_eval=True,
output_dir=tmp_dir,
)
trainer.model_accepts_loss_kwargs = False
results = trainer.evaluate()

x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
@@ -2957,6 +2962,40 @@ def test_predict(self):
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))

def test_train_and_predict_loss_parity(self):
"""
Tests that the loss computed during a training_step is the same as the one computed during a prediction_step for the same inputs.
"""
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
# Create a dummy batch of inputs
inputs = {}
inputs["input_ids"] = []
for row_ind in range(4):
seq_len = torch.randint(32, 64, (1,)).item()
x = torch.randint(1, 100, (seq_len,))
inputs["input_ids"].append(x)
inputs["input_ids"] = torch.nn.utils.rnn.pad_sequence(inputs["input_ids"], batch_first=True, padding_value=0)
inputs["labels"] = inputs["input_ids"].clone()
inputs["labels"][inputs["input_ids"] == 0] = -100
num_items_in_batch = inputs["labels"].ne(-100).sum().item()

def custom_loss_func(outputs, labels, num_items_in_batch=None):
logits = outputs["logits"]
loss_fct = torch.nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
if num_items_in_batch is not None:
return loss / num_items_in_batch  # normalize by the total number of label tokens in the batch
return loss

trainer = Trainer(model, train_dataset=None, compute_loss_func=custom_loss_func)

# Run both code paths on the same inputs; only the returned losses are compared
train_loss = trainer.training_step(model, inputs, num_items_in_batch)
predict_loss = trainer.prediction_step(model, inputs, prediction_loss_only=True)[0]

torch.testing.assert_close(train_loss, predict_loss, atol=1e-6, rtol=0)

def test_predict_with_batch_eval_metrics(self):
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = get_regression_trainer(
@@ -3024,18 +3063,23 @@ def test_predict_with_jit(self):
def test_predict_with_jit(self):
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = get_regression_trainer(a=1.5, b=2.5, jit_mode_eval=True, output_dir=tmp_dir)
# Make sure the trainer doesn't pass num_items_in_batch to the model's forward method,
# since it's not in the model forward's signature when using JIT
trainer.model_accepts_loss_kwargs = False
preds = trainer.predict(trainer.eval_dataset).predictions
x = trainer.eval_dataset.x
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))

# With a number of elements not a round multiple of the batch size
trainer = get_regression_trainer(a=1.5, b=2.5, eval_len=66, jit_mode_eval=True, output_dir=tmp_dir)
trainer.model_accepts_loss_kwargs = False
preds = trainer.predict(trainer.eval_dataset).predictions
x = trainer.eval_dataset.x
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))

# With more than one output of the model
trainer = get_regression_trainer(a=1.5, b=2.5, double_output=True, jit_mode_eval=True, output_dir=tmp_dir)
trainer.model_accepts_loss_kwargs = False
preds = trainer.predict(trainer.eval_dataset).predictions
x = trainer.eval_dataset.x
self.assertEqual(len(preds), 2)
@@ -3051,6 +3095,7 @@ def test_predict_with_jit(self):
jit_mode_eval=True,
output_dir=tmp_dir,
)
trainer.model_accepts_loss_kwargs = False
outputs = trainer.predict(trainer.eval_dataset)
preds = outputs.predictions
labels = outputs.label_ids