Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 191 additions & 0 deletions rlvr_math_tutor.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"!pip install label-studio-sdk\n",
"\n",
"from label_studio_sdk import Client\n",
"from dotenv import load_dotenv\n",
"import os\n",
"import json"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"load_dotenv()\n",
"\n",
"LS_URL = \"http://localhost:8080\" \n",
"API_KEY = os.environ.get(\"API_KEY\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from label_studio_sdk.client import LabelStudio\n",
"\n",
"ls = LabelStudio(\n",
" base_url=LS_URL, \n",
" api_key=API_KEY,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"label_config = \"\"\"\n",
"<View>\n",
" <Text name=\"input\" value=\"$prompt\"/>\n",
" <Text name=\"output\" value=\"$response\"/>\n",
" <Choices name=\"reward\" toName=\"output\">\n",
" <Choice value=\"1\">Good</Choice>\n",
" <Choice value=\"0\">Bad</Choice>\n",
" </Choices>\n",
"</View>\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"verify:True\n",
"Created project 53\n"
]
}
],
"source": [
"proj = ls.projects.create(\n",
" title=\"RLVR Scoring\",\n",
" description=\"Annotate responses for RLVR training.\",\n",
" label_config=label_config, \n",
" color=\"#FF8800\"\n",
")\n",
"pid = proj.id\n",
"print(f\"Created project {pid}\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Uploaded tasks\n"
]
}
],
"source": [
"tasks = [\n",
" {\n",
" \"prompt\": \"Solve for x: 2x + 3 = 7\",\n",
" \"response\": \"Subtract 3 from both sides: 2x = 4, then divide by 2. So, x = 2.\"\n",
" },\n",
" {\n",
" \"prompt\": \"What is the derivative of x^2?\",\n",
" \"response\": \"2x\"\n",
" },\n",
" {\n",
" \"prompt\": \"Calculate the area of a circle with radius 3\",\n",
" \"response\": \"To calculate the area of a circle, use the formula Area = π * r^2. With r = 3, Area = π * 3 * 3 = 9π.\"\n",
" }\n",
"]\n",
"\n",
"ls.projects.import_tasks(id=pid, request=tasks)\n",
"print(\"Uploaded tasks\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"(Human annotation happens in UI)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Exported 3 tasks in RLVR format (prompt, response, reward)\n"
]
}
],
"source": [
"tasks = ls.tasks.list(project=proj.id)\n",
"\n",
"dataset = []\n",
"for task in tasks:\n",
" obj = task.model_dump()\n",
" item = {\n",
" \"prompt\": obj[\"data\"].get(\"prompt\"),\n",
" \"response\": obj[\"data\"].get(\"response\"),\n",
" \"reward\": None \n",
" }\n",
" \n",
" for ann in obj.get(\"annotations\", []):\n",
" for r in ann.get(\"result\", []):\n",
" if \"value\" in r and \"choices\" in r[\"value\"]:\n",
" choice = r[\"value\"][\"choices\"][0]\n",
" if choice == \"Good\":\n",
" item[\"reward\"] = 1\n",
" elif choice == \"Bad\":\n",
" item[\"reward\"] = 0\n",
" dataset.append(item)\n",
"\n",
"\n",
"with open(\"rlvr_data.json\", \"w\") as f:\n",
" json.dump(dataset, f, indent=2)\n",
"\n",
"print(f\"Exported {len(dataset)} tasks in RLVR format (prompt, response, reward)\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}