From 4270eac5f87c100fc81849577c44a8187f3cf38a Mon Sep 17 00:00:00 2001 From: jalateras Date: Sat, 13 Sep 2025 10:43:15 +1000 Subject: [PATCH] feat: add notebook explaining ChatCompletion to SFT transformation Added comprehensive notebook demonstrating how OpenAI's fine-tuning framework internally transforms ChatCompletion-style training data into model-ready format for Supervised Fine-Tuning. Covers message concatenation, tokenization, loss masking, and training process visualization. --- ...ng_ChatCompletion_SFT_transformation.ipynb | 530 ++++++++++++++++++ registry.yaml | 10 + 2 files changed, 540 insertions(+) create mode 100644 examples/Understanding_ChatCompletion_SFT_transformation.ipynb diff --git a/examples/Understanding_ChatCompletion_SFT_transformation.ipynb b/examples/Understanding_ChatCompletion_SFT_transformation.ipynb new file mode 100644 index 0000000000..e528d6cb11 --- /dev/null +++ b/examples/Understanding_ChatCompletion_SFT_transformation.ipynb @@ -0,0 +1,530 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Understanding ChatCompletion to Model-Ready Format Transformation for SFT\n", + "\n", + "This notebook explains how OpenAI's fine-tuning framework internally transforms ChatCompletion-style training data into model-ready format for Supervised Fine-Tuning (SFT), and how the loss is computed during training.\n", + "\n", + "This addresses a common question about what happens \"under the hood\" when you provide training data in the ChatCompletion format for fine-tuning." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Overview of the Transformation Process\n", + "\n", + "When you provide training data in ChatCompletion format, the framework performs several transformation steps:\n", + "\n", + "1. **Message Concatenation**: Converts the structured conversation into a continuous text sequence\n", + "2. **Special Token Insertion**: Adds role markers and message boundaries\n", + "3. 
**Tokenization**: Converts text to token IDs that the model can process\n", + "4. **Loss Mask Creation**: Determines which tokens contribute to the training loss\n", + "5. **Sequence Padding**: Ensures uniform batch sizes for efficient training" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 1: ChatCompletion Format Input\n", + "\n", + "Your training data starts in this familiar format:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Example training conversation\n", + "training_example = {\n", + " \"messages\": [\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": \"You are a helpful assistant specialized in explaining technical concepts.\"\n", + " },\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"What is gradient descent?\"\n", + " },\n", + " {\n", + " \"role\": \"assistant\",\n", + " \"content\": \"Gradient descent is an optimization algorithm used to minimize a function by iteratively moving in the direction of steepest descent.\"\n", + " }\n", + " ]\n", + "}\n", + "\n", + "import json\n", + "print(json.dumps(training_example, indent=2))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 2: Internal Format Transformation\n", + "\n", + "The framework transforms this structured conversation into a linear sequence with special tokens that indicate role boundaries and message structure:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Simplified representation of the internal transformation\n", + "# Note: Actual tokens and format may vary by model\n", + "\n", + "def transform_to_model_format(messages):\n", + " \"\"\"\n", + " Demonstrates how ChatCompletion messages are transformed into\n", + " a continuous sequence for the model.\n", + " \"\"\"\n", + " # Special tokens (simplified representation)\n", + " SYSTEM_PREFIX = 
\"<|im_start|>system\\n\"\n", + " SYSTEM_SUFFIX = \"<|im_end|>\\n\"\n", + " USER_PREFIX = \"<|im_start|>user\\n\"\n", + " USER_SUFFIX = \"<|im_end|>\\n\"\n", + " ASSISTANT_PREFIX = \"<|im_start|>assistant\\n\"\n", + " ASSISTANT_SUFFIX = \"<|im_end|>\\n\"\n", + " \n", + " model_input = \"\"\n", + " \n", + " for message in messages:\n", + " role = message[\"role\"]\n", + " content = message[\"content\"]\n", + " \n", + " if role == \"system\":\n", + " model_input += SYSTEM_PREFIX + content + SYSTEM_SUFFIX\n", + " elif role == \"user\":\n", + " model_input += USER_PREFIX + content + USER_SUFFIX\n", + " elif role == \"assistant\":\n", + " model_input += ASSISTANT_PREFIX + content + ASSISTANT_SUFFIX\n", + " \n", + " return model_input\n", + "\n", + "# Transform our example\n", + "model_ready_format = transform_to_model_format(training_example[\"messages\"])\n", + "print(\"Model-ready format:\")\n", + "print(model_ready_format)\n", + "print(\"\\n\" + \"=\"*50)\n", + "print(\"This continuous sequence is what the model actually sees during training.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 3: Tokenization Process\n", + "\n", + "The text sequence is then converted to numerical tokens that the model can process:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import tiktoken\n", + "\n", + "# Initialize tokenizer (using cl100k_base as example)\n", + "encoding = tiktoken.get_encoding(\"cl100k_base\")\n", + "\n", + "def demonstrate_tokenization(text):\n", + " \"\"\"\n", + " Shows how text is converted to tokens.\n", + " \"\"\"\n", + " tokens = encoding.encode(text)\n", + " \n", + " print(f\"Original text length: {len(text)} characters\")\n", + " print(f\"Number of tokens: {len(tokens)}\")\n", + " print(f\"\\nFirst 20 tokens: {tokens[:20]}\")\n", + " \n", + " # Decode back to show token boundaries\n", + " print(\"\\nToken boundaries (first 100 chars):\")\n", + " for 
i, token_id in enumerate(tokens[:10]):\n", + " token_text = encoding.decode([token_id])\n", + " print(f\"Token {i}: [{token_id}] = '{token_text}'\")\n", + " \n", + " return tokens\n", + "\n", + "# Tokenize the model-ready format\n", + "tokens = demonstrate_tokenization(model_ready_format)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 4: Loss Computation Strategy\n", + "\n", + "A critical aspect of SFT is determining which tokens contribute to the training loss. The framework implements intelligent loss masking:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def create_loss_mask(messages, tokens):\n", + " \"\"\"\n", + " Demonstrates how the loss mask is created.\n", + " In SFT, typically only the assistant's responses contribute to the loss.\n", + " \"\"\"\n", + " # This is a simplified representation\n", + " # In practice, the implementation tracks token positions more precisely\n", + " \n", + " loss_mask = []\n", + " current_role = None\n", + " \n", + " # For demonstration, we'll create a simple mask\n", + " # 1 = contribute to loss, 0 = don't contribute\n", + " for message in messages:\n", + " role = message[\"role\"]\n", + " content_tokens = encoding.encode(message[\"content\"])\n", + " \n", + " if role == \"assistant\":\n", + " # Assistant tokens contribute to loss\n", + " loss_mask.extend([1] * len(content_tokens))\n", + " else:\n", + " # System and user tokens don't contribute to loss\n", + " loss_mask.extend([0] * len(content_tokens))\n", + " \n", + " return loss_mask\n", + "\n", + "# Demonstrate loss masking\n", + "print(\"Loss Masking Strategy:\")\n", + "print(\"=\"*50)\n", + "print(\"✓ Assistant responses: Contribute to loss (mask=1)\")\n", + "print(\"✗ System messages: Don't contribute to loss (mask=0)\")\n", + "print(\"✗ User messages: Don't contribute to loss (mask=0)\")\n", + "print(\"\\nThis ensures the model learns to generate appropriate 
assistant responses\")\n", + "print(\"given the context of system instructions and user queries.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 5: Training Process Visualization\n", + "\n", + "Here's how the transformed data flows through the training process:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "\n", + "def visualize_training_flow():\n", + " \"\"\"\n", + " Creates a visual representation of the training data flow.\n", + " \"\"\"\n", + " fig, axes = plt.subplots(3, 1, figsize=(12, 10))\n", + " \n", + " # Simplified token sequence\n", + " sequence_length = 50\n", + " token_types = ['System'] * 10 + ['User'] * 15 + ['Assistant'] * 25\n", + " \n", + " # Color mapping\n", + " colors = {'System': 'blue', 'User': 'green', 'Assistant': 'red'}\n", + " color_sequence = [colors[t] for t in token_types]\n", + " \n", + " # Plot 1: Token sequence\n", + " ax1 = axes[0]\n", + " positions = np.arange(sequence_length)\n", + " ax1.bar(positions, [1]*sequence_length, color=color_sequence, width=1.0)\n", + " ax1.set_title('Token Sequence by Role', fontsize=14, fontweight='bold')\n", + " ax1.set_ylabel('Token Presence')\n", + " ax1.set_ylim(0, 1.5)\n", + " ax1.legend([plt.Rectangle((0,0),1,1, fc=c) for c in colors.values()], \n", + " colors.keys(), loc='upper right')\n", + " \n", + " # Plot 2: Loss mask\n", + " ax2 = axes[1]\n", + " loss_mask = [0] * 10 + [0] * 15 + [1] * 25 # Only assistant tokens have loss\n", + " ax2.bar(positions, loss_mask, color=['gray' if m == 0 else 'orange' for m in loss_mask], width=1.0)\n", + " ax2.set_title('Loss Mask (Which Tokens Contribute to Training)', fontsize=14, fontweight='bold')\n", + " ax2.set_ylabel('Loss Weight')\n", + " ax2.set_ylim(0, 1.5)\n", + " ax2.legend(['No Loss', 'Has Loss'], loc='upper right')\n", + " \n", + " # Plot 3: Gradient flow\n", + " ax3 = 
axes[2]\n", + " gradient_magnitude = np.array(loss_mask) * np.random.uniform(0.5, 1.0, sequence_length)\n", + " ax3.plot(positions, gradient_magnitude, 'r-', linewidth=2)\n", + " ax3.fill_between(positions, 0, gradient_magnitude, alpha=0.3, color='red')\n", + " ax3.set_title('Gradient Magnitude During Backpropagation', fontsize=14, fontweight='bold')\n", + " ax3.set_ylabel('Gradient Magnitude')\n", + " ax3.set_xlabel('Token Position')\n", + " ax3.set_ylim(0, 1.5)\n", + " \n", + " plt.tight_layout()\n", + " plt.savefig('sft_transformation_flow.png', dpi=150, bbox_inches='tight')\n", + " plt.show()\n", + " \n", + " print(\"\\nKey Insights:\")\n", + " print(\"1. The entire conversation is processed as a continuous sequence\")\n", + " print(\"2. Only assistant tokens contribute to the loss calculation\")\n", + " print(\"3. Gradients flow back primarily through assistant response tokens\")\n", + " print(\"4. The model learns to generate appropriate responses in context\")\n", + "\n", + "visualize_training_flow()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 6: Loss Calculation Details\n", + "\n", + "The actual loss computation uses cross-entropy loss on the masked tokens:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def explain_loss_calculation():\n", + " \"\"\"\n", + " Explains how the loss is calculated during SFT.\n", + " \"\"\"\n", + " print(\"Loss Calculation in SFT\")\n", + " print(\"=\"*50)\n", + " print()\n", + " print(\"1. Forward Pass:\")\n", + " print(\" - Input: Full conversation sequence (system + user + assistant)\")\n", + " print(\" - Model generates: Probability distribution over vocabulary for each token\")\n", + " print()\n", + " print(\"2. 
Loss Computation:\")\n", + " print(\" - For each token position:\")\n", + " print(\" • If mask = 0 (system/user): Skip this token\")\n", + " print(\" • If mask = 1 (assistant): Calculate cross-entropy loss\")\n", + " print(\" - Formula: L = -Σ(mask_i * log(P(token_i|context)))\")\n", + " print()\n", + " print(\"3. Gradient Calculation:\")\n", + " print(\" - Only tokens with mask = 1 add loss terms, so only they produce gradients\")\n", + " print(\" - This focuses learning on generating good assistant responses\")\n", + " print()\n", + " print(\"4. Parameter Update:\")\n", + " print(\" - Model weights are updated to minimize the loss\")\n", + " print(\" - Over many examples, the model learns the assistant's behavior pattern\")\n", + " \n", + " # Simulate a simple loss calculation\n", + " print(\"\\n\" + \"=\"*50)\n", + " print(\"Example Loss Calculation:\")\n", + " print(\"=\"*50)\n", + " \n", + " # Simulated probabilities for assistant tokens\n", + " assistant_tokens = [\"Gradient\", \"descent\", \"is\", \"an\", \"optimization\", \"algorithm\"]\n", + " predicted_probs = [0.95, 0.88, 0.99, 0.97, 0.82, 0.91]\n", + " \n", + " losses = [-np.log(p) for p in predicted_probs]\n", + " \n", + " for i, (token, prob, loss) in enumerate(zip(assistant_tokens, predicted_probs, losses)):\n", + " print(f\"Token {i+1}: '{token}'\")\n", + " print(f\" Predicted probability: {prob:.3f}\")\n", + " print(f\" Cross-entropy loss: {loss:.3f}\")\n", + " \n", + " avg_loss = np.mean(losses)\n", + " print(f\"\\nAverage loss for this example: {avg_loss:.3f}\")\n", + " print(\"\\nLower loss = model assigns higher probability to the correct tokens\")\n", + "\n", + "explain_loss_calculation()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Practical Implications for Fine-Tuning\n", + "\n", + "Understanding this transformation process helps explain several important aspects of fine-tuning:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], +
"source": [ + "def summarize_implications():\n", + " \"\"\"\n", + " Summarizes the practical implications of the transformation process.\n", + " \"\"\"\n", + " implications = [\n", + " {\n", + " \"aspect\": \"Why Assistant Messages Matter Most\",\n", + " \"explanation\": \"Only assistant tokens contribute to loss, so the quality and accuracy of assistant responses in your training data directly impacts model performance.\"\n", + " },\n", + " {\n", + " \"aspect\": \"Token Limits\",\n", + " \"explanation\": \"The 4096 token limit applies to the entire conversation after transformation, including special tokens for role markers.\"\n", + " },\n", + " {\n", + " \"aspect\": \"Context Learning\",\n", + " \"explanation\": \"The model learns to generate responses conditioned on the full context (system + user messages), even though only assistant tokens contribute to loss.\"\n", + " },\n", + " {\n", + " \"aspect\": \"Format Consistency\",\n", + " \"explanation\": \"Maintaining consistent formatting in your training data helps the model learn the expected structure and improves generation quality.\"\n", + " },\n", + " {\n", + " \"aspect\": \"Multi-turn Conversations\",\n", + " \"explanation\": \"Including multi-turn examples helps the model learn to maintain context across multiple exchanges.\"\n", + " }\n", + " ]\n", + " \n", + " print(\"Practical Implications for Your Fine-Tuning\")\n", + " print(\"=\"*60)\n", + " \n", + " for i, item in enumerate(implications, 1):\n", + " print(f\"\\n{i}. 
{item['aspect']}\")\n", + " print(f\" {item['explanation']}\")\n", + " \n", + " print(\"\\n\" + \"=\"*60)\n", + " print(\"\\nBest Practices Based on This Understanding:\")\n", + " print(\"• Ensure high-quality assistant responses in your training data\")\n", + " print(\"• Include diverse examples that cover your use case\")\n", + " print(\"• Keep conversations within token limits to avoid truncation\")\n", + " print(\"• Use consistent formatting across all training examples\")\n", + " print(\"• Test with examples similar to your training format\")\n", + "\n", + "summarize_implications()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Complete Example: From ChatCompletion to Training\n", + "\n", + "Let's walk through a complete example showing the entire transformation pipeline:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def complete_transformation_example():\n", + " \"\"\"\n", + " Demonstrates the complete transformation pipeline from ChatCompletion\n", + " format to model training.\n", + " \"\"\"\n", + " # Step 1: Original ChatCompletion format\n", + " original_data = {\n", + " \"messages\": [\n", + " {\"role\": \"system\", \"content\": \"You are a Python expert.\"},\n", + " {\"role\": \"user\", \"content\": \"How do I read a file in Python?\"},\n", + " {\"role\": \"assistant\", \"content\": \"You can read a file using the open() function with a context manager:\\n\\nwith open('file.txt', 'r') as f:\\n content = f.read()\"}\n", + " ]\n", + " }\n", + " \n", + " print(\"STEP 1: Original ChatCompletion Format\")\n", + " print(\"=\"*50)\n", + " print(json.dumps(original_data, indent=2))\n", + " \n", + " # Step 2: Transform to linear sequence\n", + " print(\"\\n\\nSTEP 2: Transformed to Linear Sequence\")\n", + " print(\"=\"*50)\n", + " linear_sequence = transform_to_model_format(original_data[\"messages\"])\n", + " print(linear_sequence[:200] + \"...\" if 
len(linear_sequence) > 200 else linear_sequence)\n", + " \n", + " # Step 3: Tokenization\n", + " print(\"\\n\\nSTEP 3: Tokenization\")\n", + " print(\"=\"*50)\n", + " tokens = encoding.encode(linear_sequence)\n", + " print(f\"Total tokens: {len(tokens)}\")\n", + " print(f\"First 30 tokens: {tokens[:30]}\")\n", + " \n", + " # Step 4: Loss mask creation\n", + " print(\"\\n\\nSTEP 4: Loss Mask Creation\")\n", + " print(\"=\"*50)\n", + " \n", + " # Simplified: Find where assistant content starts\n", + " assistant_start_marker = \"<|im_start|>assistant\"\n", + " assistant_start_pos = linear_sequence.find(assistant_start_marker)\n", + " \n", + " if assistant_start_pos != -1:\n", + " # Create a simple mask\n", + " pre_assistant_tokens = encoding.encode(linear_sequence[:assistant_start_pos])\n", + " mask = [0] * len(pre_assistant_tokens) + [1] * (len(tokens) - len(pre_assistant_tokens))\n", + " \n", + " print(f\"Tokens before assistant response: {len(pre_assistant_tokens)} (mask=0)\")\n", + " print(f\"Assistant response tokens: {len(tokens) - len(pre_assistant_tokens)} (mask=1)\")\n", + " print(f\"\\nOnly the {len(tokens) - len(pre_assistant_tokens)} assistant tokens will contribute to the loss.\")\n", + " \n", + " # Step 5: Training implications\n", + " print(\"\\n\\nSTEP 5: Training Process\")\n", + " print(\"=\"*50)\n", + " print(\"During training:\")\n", + " print(\"1. The model processes the entire sequence\")\n", + " print(\"2. It predicts the next token at each position\")\n", + " print(\"3. Loss is calculated only for assistant tokens\")\n", + " print(\"4. The model learns to generate appropriate Python help responses\")\n", + " print(\"5. 
After many examples, it generalizes to new Python questions\")\n", + "\n", + "complete_transformation_example()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusion\n", + "\n", + "This notebook has demonstrated how OpenAI's fine-tuning framework transforms ChatCompletion format training data into model-ready format for Supervised Fine-Tuning (SFT). Key takeaways:\n", + "\n", + "1. **Structured to Sequential**: The conversation structure is converted to a linear sequence with special tokens\n", + "2. **Selective Loss Computation**: Only assistant responses contribute to the training loss\n", + "3. **Context-Aware Learning**: The model learns to generate responses based on the full conversation context\n", + "4. **Efficient Training**: The loss masking ensures the model focuses on learning the desired behavior\n", + "\n", + "Understanding this process helps you:\n", + "- Design better training data\n", + "- Understand why certain formatting matters\n", + "- Debug fine-tuning issues more effectively\n", + "- Optimize your training examples for better results\n", + "\n", + "For more information on fine-tuning best practices, refer to the [OpenAI Fine-tuning Guide](https://platform.openai.com/docs/guides/fine-tuning)." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/registry.yaml b/registry.yaml index dde9894426..fb5d8a8bde 100644 --- a/registry.yaml +++ b/registry.yaml @@ -660,6 +660,16 @@ - tiktoken - completions +- title: Understanding ChatCompletion to Model-Ready Format Transformation for SFT + path: examples/Understanding_ChatCompletion_SFT_transformation.ipynb + date: 2025-09-13 + authors: + - openai + tags: + - completions + - fine-tuning + - chatcompletion + - title: How to fine-tune chat models path: examples/How_to_finetune_chat_models.ipynb date: 2024-07-23