diff --git a/course/en/chapter13/grpo_format.py b/course/en/chapter13/grpo_format.py index 02658c74e..442a4871a 100644 --- a/course/en/chapter13/grpo_format.py +++ b/course/en/chapter13/grpo_format.py @@ -1,3 +1,13 @@ +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "marimo", +# "plotly", +# "numpy", +# "pandas", +# ] +# /// + import marimo __generated_with = "0.10.6" @@ -121,11 +131,10 @@ def format_reward(completions, format_type="think-answer", **kwargs): } ) - # Create a table view - mo.md(f"### Results for {format_buttons.value} format") - mo.ui.table(results) - - # Create a bar chart comparing rewards by completion + display_rows = [ + {**row, "Reward": f"{row['Reward']:.2f}"} + for row in results + ] fig = px.bar( results, x="Completion", @@ -134,7 +143,14 @@ def format_reward(completions, format_type="think-answer", **kwargs): title=f"Format Rewards by Completion ({format_buttons.value})", hover_data=["Detail"], ) - mo.ui.plotly(fig) + + mo.vstack( + [ + mo.md(f"### Results for {format_buttons.value} format"), + mo.ui.table(display_rows, selection=None), + mo.ui.plotly(fig), + ] + ) if __name__ == "__main__": diff --git a/course/en/chapter13/grpo_length.py b/course/en/chapter13/grpo_length.py index ce5917a8d..bdc302c58 100644 --- a/course/en/chapter13/grpo_length.py +++ b/course/en/chapter13/grpo_length.py @@ -1,3 +1,13 @@ +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "marimo", +# "plotly", +# "numpy", +# "pandas", +# ] +# /// + import marimo __generated_with = "0.10.6" @@ -70,8 +80,19 @@ def length_reward(completions, ideal_length): {"Completion": completion, "Length": len(completion), "Reward": reward} ) + display_rows = [ + {**row, "Reward": f"{row['Reward']:.2f}"} + for row in results + ] fig = px.bar(results, x="Completion", y="Reward", color="Length") - mo.ui.plotly(fig) + + mo.vstack( + [ + mo.md("### Reward comparison"), + mo.ui.table(display_rows, selection=None), + mo.ui.plotly(fig), + ] + ) if __name__ == "__main__": diff --git a/course/en/chapter13/grpo_math.py b/course/en/chapter13/grpo_math.py index 7c554805c..0f9d95aa7 100644 --- a/course/en/chapter13/grpo_math.py +++ b/course/en/chapter13/grpo_math.py @@ -1,3 +1,13 @@ +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "marimo", +# "plotly", +# "numpy", +# "pandas", +# ] +# /// + import marimo __generated_with = "0.10.6" @@ -122,11 +132,10 @@ def problem_reward(completions, answers, tolerance=0): } ) - # Create a table view - mo.md("### Results") - mo.ui.table(results) - - # Create a bar chart + display_rows = [ + {**row, "Reward": f"{row['Reward']:.2f}"} + for row in results + ] fig = px.bar( results, x="Problem", @@ -135,7 +144,14 @@ def problem_reward(completions, answers, tolerance=0): hover_data=["Correct Answer", "Model Answer"], title="Rewards by Problem", ) - mo.ui.plotly(fig) + + mo.vstack( + [ + mo.md("### Reward comparison"), + mo.ui.table(display_rows, selection=None), + mo.ui.plotly(fig), + ] + ) if __name__ == "__main__":