Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions course/en/chapter13/grpo_format.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "marimo",
# "plotly",
# "numpy",
# "pandas",
# ]
# ///

import marimo

__generated_with = "0.10.6"
Expand Down Expand Up @@ -121,11 +131,10 @@ def format_reward(completions, format_type="think-answer", **kwargs):
}
)

# Create a table view
mo.md(f"### Results for {format_buttons.value} format")
mo.ui.table(results)

# Create a bar chart comparing rewards by completion
display_rows = [
{**row, "Reward": f"{row['Reward']:.2f}"}
for row in results
]
fig = px.bar(
results,
x="Completion",
Expand All @@ -134,7 +143,14 @@ def format_reward(completions, format_type="think-answer", **kwargs):
title=f"Format Rewards by Completion ({format_buttons.value})",
hover_data=["Detail"],
)
mo.ui.plotly(fig)

mo.vstack(
[
mo.md(f"### Results for {format_buttons.value} format"),
mo.ui.table(display_rows, selection=None),
mo.ui.plotly(fig),
]
)


if __name__ == "__main__":
Expand Down
23 changes: 22 additions & 1 deletion course/en/chapter13/grpo_length.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "marimo",
# "plotly",
# "numpy",
# "pandas",
# ]
# ///

import marimo

__generated_with = "0.10.6"
Expand Down Expand Up @@ -70,8 +80,19 @@ def length_reward(completions, ideal_length):
{"Completion": completion, "Length": len(completion), "Reward": reward}
)

display_rows = [
{**row, "Reward": f"{row['Reward']:.2f}"}
for row in results
]
fig = px.bar(results, x="Completion", y="Reward", color="Length")
mo.ui.plotly(fig)

mo.vstack(
[
mo.md("### Reward comparison"),
mo.ui.table(display_rows, selection=None),
mo.ui.plotly(fig),
]
)


if __name__ == "__main__":
Expand Down
28 changes: 22 additions & 6 deletions course/en/chapter13/grpo_math.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "marimo",
# "plotly",
# "numpy",
# "pandas",
# ]
# ///

import marimo

__generated_with = "0.10.6"
Expand Down Expand Up @@ -122,11 +132,10 @@ def problem_reward(completions, answers, tolerance=0):
}
)

# Create a table view
mo.md("### Results")
mo.ui.table(results)

# Create a bar chart
display_rows = [
{**row, "Reward": f"{row['Reward']:.2f}"}
for row in results
]
fig = px.bar(
results,
x="Problem",
Expand All @@ -135,7 +144,14 @@ def problem_reward(completions, answers, tolerance=0):
hover_data=["Correct Answer", "Model Answer"],
title="Rewards by Problem",
)
mo.ui.plotly(fig)

mo.vstack(
[
mo.md("### Reward comparison"),
mo.ui.table(display_rows, selection=None),
mo.ui.plotly(fig),
]
)


if __name__ == "__main__":
Expand Down