diff --git a/pipelinerl/preprocess.py b/pipelinerl/preprocess.py
index 65e29b4b..382458ca 100644
--- a/pipelinerl/preprocess.py
+++ b/pipelinerl/preprocess.py
@@ -160,7 +160,18 @@ def preprocess_dataset(
         entry["step_index"] = entry["metadata"]["step_index"]
     if not isinstance(tokenizer.eos_token_id, int):
         raise ValueError(f"Tokenizer {tokenizer} does not have an eos_token_id")
-    dataset = populate_rl_data(dataset=dataset, eos_token_id=tokenizer.eos_token_id, config=rl_config)
+    try:
+        dataset = populate_rl_data(dataset=dataset, eos_token_id=tokenizer.eos_token_id, config=rl_config)
+    except Exception as e:
+        logger.error(f"Error in populate_rl_data: {e}")
+        logger.error(f"Data: {data}")
+        logger.error(f"Dataset: {dataset}")
+        logger.error(f"Tokenizer: {tokenizer}")
+        logger.error(f"Tokenizer eos_token_id: {tokenizer.eos_token_id}")
+        logger.error(f"RL config: {rl_config}")
+        logger.error(f"LLM: {llm}")
+        logger.error(f"Seq length: {seq_length}")
+        raise e
     return dataset
 
 
@@ -533,7 +544,7 @@ def run_preprocessing_loop(
     while len(buffer) > 0:
         if len(processed_entries_queue) == processed_entries_queue.maxlen:
             if not pop_old_data:
-                break
+                break
             else:
                 processed_entries_queue_popped_data += 1
                 if processed_entries_queue_popped_data % 100 == 0 and last_time_notice != processed_entries_queue_popped_data // 100:
@@ -573,6 +584,10 @@ def run_preprocessing_loop(
 
         sample_length = len(entry["input_ids"])
         if current_length + sample_length > cfg.finetune.seq_length:
+            if len(current_batch) == 0:
+                raise ValueError(
+                    f"sample_length is {sample_length}, but cfg.finetune.seq_length is {cfg.finetune.seq_length}"
+                )
             time_to_write = True
             break  # Current micro batch is full
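
The first hunk wraps populate_rl_data in a broad `except Exception` purely to dump local context before re-raising. A minimal self-contained sketch of the same pattern, with a hypothetical stub standing in for the real populate_rl_data, shows an equivalent shape using logger.exception and a bare raise:

import logging

logger = logging.getLogger(__name__)

def populate_rl_data(dataset, eos_token_id, config):
    # Hypothetical stub standing in for pipelinerl's real populate_rl_data.
    raise RuntimeError("simulated failure")

def preprocess_with_diagnostics(dataset, eos_token_id, config):
    try:
        return populate_rl_data(dataset=dataset, eos_token_id=eos_token_id, config=config)
    except Exception:
        # logger.exception logs at ERROR level and appends the full traceback,
        # covering much of what the hunk's individual logger.error dumps record.
        logger.exception("populate_rl_data failed (eos_token_id=%s, config=%s)", eos_token_id, config)
        raise  # bare raise is the idiomatic way to re-raise the active exception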
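
The guard added in the last hunk fails fast when a single sample alone exceeds cfg.finetune.seq_length: if current_batch is empty when the length check trips, flushing cannot make room, so the loop would otherwise keep re-trying the same entry. A runnable sketch of that greedy packing logic, where pack_samples is a hypothetical stand-in for the run_preprocessing_loop internals rather than the actual pipelinerl API:

def pack_samples(samples: list[dict], seq_length: int) -> list[list[dict]]:
    """Greedily pack token sequences into micro batches of at most seq_length tokens."""
    batches: list[list[dict]] = []
    current_batch: list[dict] = []
    current_length = 0
    i = 0
    while i < len(samples):
        sample_length = len(samples[i]["input_ids"])
        if current_length + sample_length > seq_length:
            if len(current_batch) == 0:
                # The entry alone exceeds seq_length, so flushing cannot help;
                # without this guard the loop would retry it forever.
                raise ValueError(
                    f"sample_length is {sample_length}, but seq_length is {seq_length}"
                )
            batches.append(current_batch)  # flush: current micro batch is full
            current_batch, current_length = [], 0
            continue  # retry the same entry against an empty batch
        current_batch.append(samples[i])
        current_length += sample_length
        i += 1
    if current_batch:
        batches.append(current_batch)
    return batches

samples = [{"input_ids": list(range(n))} for n in (2, 3)]
print(pack_samples(samples, seq_length=4))  # two micro batches: 2 tokens, then 3

For example, with seq_length=4, samples of 2 and 3 tokens pack into two micro batches, while adding a 6-token sample trips the ValueError on its first retry against an empty batch.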