From 1a0bf48db82d7441505747745b790d59e72d8ff7 Mon Sep 17 00:00:00 2001
From: nameearly <2741313455@qq.com>
Date: Tue, 3 Mar 2026 17:55:17 +0800
Subject: [PATCH 1/2] fix: wrong mapping between rollouts

---
 tinker_cookbook/recipes/system_prompt_learning_rl.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/tinker_cookbook/recipes/system_prompt_learning_rl.py b/tinker_cookbook/recipes/system_prompt_learning_rl.py
index b1dc92e..f1dcff4 100644
--- a/tinker_cookbook/recipes/system_prompt_learning_rl.py
+++ b/tinker_cookbook/recipes/system_prompt_learning_rl.py
@@ -613,6 +613,7 @@ def _single_query_critique(self, problem_to_summarized_rollouts, principles, sav
             all_rollouts.append(rollouts)
 
         sample_futures_list = []
+        rollouts_per_problem_list = []
         for rollouts_per_problem in all_rollouts:
             problem = rollouts_per_problem[0]["problem"]
             answer = rollouts_per_problem[0]["groundtruth"]
@@ -636,9 +637,12 @@
             sample_futures_list.append(
                 get_sample_future(conversation, self.renderer, self.sampling_client, self.sampling_params)
             )
+            # IMPORTANT: keep a 1:1 mapping between each Future and its corresponding rollouts.
+            # Otherwise, referencing `rollouts_per_problem` later would accidentally reuse the last loop value.
+            rollouts_per_problem_list.append(rollouts_per_problem)
 
         results = []
-        for sample_future in sample_futures_list:
+        for rollouts_per_problem, sample_future in zip(rollouts_per_problem_list, sample_futures_list):
             raw_response = get_text_from_sampled_future(sample_future, self.renderer)
             # parse json
             operations = get_operations_from_json(raw_response)

From e93b2225587f1c26f6546564e035bb636423254b Mon Sep 17 00:00:00 2001
From: nameearly <2741313455@qq.com>
Date: Tue, 3 Mar 2026 18:49:59 +0800
Subject: [PATCH 2/2] docs: add advisory comments

---
 .../recipes/system_prompt_learning_rl.py | 30 +++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/tinker_cookbook/recipes/system_prompt_learning_rl.py b/tinker_cookbook/recipes/system_prompt_learning_rl.py
index f1dcff4..427e3a4 100644
--- a/tinker_cookbook/recipes/system_prompt_learning_rl.py
+++ b/tinker_cookbook/recipes/system_prompt_learning_rl.py
@@ -211,6 +211,10 @@ def pass_at_k_for_range(n: int, c: int):
 
 
 def get_operations_from_json(llm_response):
+    # ADVICE: This JSON extraction is intentionally simple but brittle.
+    # - If the model does not emit a ```json fenced block (or uses ```JSON), `split` may grab the wrong span.
+    # - If multiple fenced blocks exist, this will always pick the last ```json block.
+    # Consider using a case-insensitive regex to extract the fenced JSON, and/or fallback heuristics.
     response_json = llm_response.split("```json")[-1].split("```")[0]
     try:
         operations = json.loads(fix_json_backslashes(response_json))
@@ -572,6 +576,9 @@ def _single_rollout_summary(self, rollouts, save_dir):
         for rollouts in problems_to_rollouts.values():
             scores = [each["reward"] for each in rollouts]
             avg_score = sum(scores) / max(len(scores), 1)
+            # ADVICE: This filter only keeps problems with mixed outcomes (some correct, some wrong).
+            # That can easily starve the updater when the model is very weak (all 0) or very strong (all 1).
+            # If you want learning signals from failures, consider also processing all-wrong cases.
             if 0 < avg_score < 1:
                 all_rollouts_to_process.extend(rollouts)
 
@@ -609,6 +616,9 @@ def _single_query_critique(self, problem_to_summarized_rollouts, principles, sav
         for rollouts in problem_to_summarized_rollouts.values():
             scores = [each["reward"] for each in rollouts]
             avg_score = sum(scores) / len(scores)
+            # ADVICE: Same starvation risk as above: only mixed-success problems are critiqued.
+            # Consider handling avg_score==0 (all wrong) to propose new principles, and avg_score==1 (all right)
+            # to optionally merge/deduplicate principles.
             if 0 < avg_score < 1:
                 all_rollouts.append(rollouts)
 
@@ -671,6 +681,10 @@ def _batch_update(self, principles, critiques, save_dir, max_retries=3):
         candidate_principles = copy.deepcopy(principles)
         to_modify = []
         max_ID = 0
+        # ADVICE: ID-collision risk.
+        # `max_ID` starts at 0 and new principles are keyed as C0, C1, ...
+        # If `candidate_principles` already contains C* keys (from prior runs), this may overwrite existing ones.
+        # Safer: initialize max_ID from existing keys, e.g. max(int(k[1:]) for k in candidate_principles if k.startswith('C')) + 1.
         for operation in all_operations:
             if operation["operation"] == "modify":
                 if operation["modified_from"] in candidate_principles:
@@ -1081,6 +1095,10 @@ def load_state_dict(self, state):
 
 def main(config: Config):
     # Setup logging
+    # ADVICE: Reproducibility.
+    # This script uses `random` (evolution sampling, shuffling) and `numpy`/`torch` elsewhere.
+    # If you want runs to be more reproducible, consider adding a `seed` field to Config and setting
+    # random.seed(seed), np.random.seed(seed), and torch.manual_seed(seed) here.
     ml_logger = ml_log.setup_logging(
         log_dir=config.log_path,
         wandb_project="system_prompt_learning",
@@ -1452,6 +1470,14 @@ def main(config: Config):
             all_tokens = prompt_tokens + sampled_tokens
             group_tokens.append(all_tokens)
 
+            # ADVICE: Logprob/offset alignment.
+            # We set ob_len = len(prompt_tokens) - 1 because (in standard next-token LM training)
+            # the first prompt token has no target/logprob, so prompt contributes (prompt_len-1)
+            # positions in the shifted (input, target) pair.
+            # This assumes `sampled_logprobs` has length == len(sampled_tokens) and corresponds
+            # to the generated tokens only (no extra dummy element). If your sampler returns
+            # logprobs with a leading dummy (like compute_logprobs() often does), this will
+            # silently misalign training. Consider asserting the expected lengths here.
             group_ob_lens.append(len(prompt_tokens) - 1)
             group_logprobs.append(sampled_logprobs)
 
@@ -1502,6 +1528,10 @@ def main(config: Config):
         ):
             # check if all advantages are zero
             if all(advantage == 0.0 for advantage in advantages):
+                # ADVICE: This can easily drop most data early in training.
+                # With binary rewards (0/1) and small group_size, it's common to get all-0 or all-1 groups,
+                # which makes centered advantages all zero and therefore skips the entire question.
+                # Consider logging the skip ratio and/or using a different baseline/advantage scheme if learning stalls.
                 # Skip question because all advantages are the same
                 continue
 