Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion tinker_cookbook/recipes/system_prompt_learning_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,10 @@ def pass_at_k_for_range(n: int, c: int):


def get_operations_from_json(llm_response):
# ADVICE: This JSON extraction is intentionally simple but brittle.
# - If the model does not emit a ```json fenced block (or uses ```JSON), split("```json")[-1] falls back to the entire response, and the following split("```")[0] may grab the wrong span.
# - If multiple fenced blocks exist, this will always pick the last ```json block.
# Consider using a case-insensitive regex to extract the fenced JSON, and/or fallback heuristics.
response_json = llm_response.split("```json")[-1].split("```")[0]
try:
operations = json.loads(fix_json_backslashes(response_json))
Expand Down Expand Up @@ -572,6 +576,9 @@ def _single_rollout_summary(self, rollouts, save_dir):
for rollouts in problems_to_rollouts.values():
scores = [each["reward"] for each in rollouts]
avg_score = sum(scores) / max(len(scores), 1)
# ADVICE: This filter only keeps problems with mixed outcomes (some correct, some wrong).
# That can easily starve the updater when the model is very weak (all 0) or very strong (all 1).
# If you want learning signals from failures, consider also processing all-wrong cases.
if 0 < avg_score < 1:
all_rollouts_to_process.extend(rollouts)

Expand Down Expand Up @@ -609,10 +616,14 @@ def _single_query_critique(self, problem_to_summarized_rollouts, principles, sav
for rollouts in problem_to_summarized_rollouts.values():
scores = [each["reward"] for each in rollouts]
avg_score = sum(scores) / len(scores)
# ADVICE: Same starvation risk as above: only mixed-success problems are critiqued.
# Consider handling avg_score==0 (all wrong) to propose new principles, and avg_score==1 (all right)
# to optionally merge/deduplicate principles.
if 0 < avg_score < 1:
all_rollouts.append(rollouts)

sample_futures_list = []
rollouts_per_problem_list = []
for rollouts_per_problem in all_rollouts:
problem = rollouts_per_problem[0]["problem"]
answer = rollouts_per_problem[0]["groundtruth"]
Expand All @@ -636,9 +647,12 @@ def _single_query_critique(self, problem_to_summarized_rollouts, principles, sav
sample_futures_list.append(
get_sample_future(conversation, self.renderer, self.sampling_client, self.sampling_params)
)
# IMPORTANT: keep a 1:1 mapping between each Future and its corresponding rollouts.
# Otherwise, referencing `rollouts_per_problem` later would accidentally reuse the last loop value.
rollouts_per_problem_list.append(rollouts_per_problem)

results = []
for sample_future in sample_futures_list:
for rollouts_per_problem, sample_future in zip(rollouts_per_problem_list, sample_futures_list):
raw_response = get_text_from_sampled_future(sample_future, self.renderer)
# parse json
operations = get_operations_from_json(raw_response)
Expand Down Expand Up @@ -667,6 +681,10 @@ def _batch_update(self, principles, critiques, save_dir, max_retries=3):
candidate_principles = copy.deepcopy(principles)
to_modify = []
max_ID = 0
# ADVICE: ID-collision risk.
# `max_ID` starts at 0 and new principles are keyed as C0, C1, ...
# If `candidate_principles` already contains C* keys (from prior runs), this may overwrite existing ones.
# Safer: initialize max_ID from existing keys, e.g. max((int(k[1:]) for k in candidate_principles if k.startswith('C')), default=-1) + 1 — the `default=-1` avoids a ValueError when no C* keys exist yet.
for operation in all_operations:
if operation["operation"] == "modify":
if operation["modified_from"] in candidate_principles:
Expand Down Expand Up @@ -1077,6 +1095,10 @@ def load_state_dict(self, state):

def main(config: Config):
# Setup logging
# ADVICE: Reproducibility.
# This script uses `random` (evolution sampling, shuffling) and `numpy`/`torch` elsewhere.
# If you want runs to be more reproducible, consider adding a `seed` field to Config and setting
# random.seed(seed), np.random.seed(seed), and torch.manual_seed(seed) here.
ml_logger = ml_log.setup_logging(
log_dir=config.log_path,
wandb_project="system_prompt_learning",
Expand Down Expand Up @@ -1448,6 +1470,14 @@ def main(config: Config):

all_tokens = prompt_tokens + sampled_tokens
group_tokens.append(all_tokens)
# ADVICE: Logprob/offset alignment.
# We set ob_len = len(prompt_tokens) - 1 because (in standard next-token LM training)
# the first prompt token has no target/logprob, so prompt contributes (prompt_len-1)
# positions in the shifted (input, target) pair.
# This assumes `sampled_logprobs` has length == len(sampled_tokens) and corresponds
# to the generated tokens only (no extra dummy element). If your sampler returns
# logprobs with a leading dummy (like compute_logprobs() often does), this will
# silently misalign training. Consider asserting the expected lengths here.
group_ob_lens.append(len(prompt_tokens) - 1)
group_logprobs.append(sampled_logprobs)

Expand Down Expand Up @@ -1498,6 +1528,10 @@ def main(config: Config):
):
# check if all advantages are zero
if all(advantage == 0.0 for advantage in advantages):
# ADVICE: This can easily drop most data early in training.
# With binary rewards (0/1) and small group_size, it's common to get all-0 or all-1 groups,
# which makes centered advantages all zero and therefore skips the entire question.
# Consider logging the skip ratio and/or using a different baseline/advantage scheme if learning stalls.
# Skip question because all advantages are the same
continue

Expand Down