diff --git a/.github/workflows/lint-and-test.yaml b/.github/workflows/lint-and-test.yaml index 8aadbc0..8365bda 100644 --- a/.github/workflows/lint-and-test.yaml +++ b/.github/workflows/lint-and-test.yaml @@ -9,6 +9,7 @@ on: branches: - main - dev + - '!f/pypi_release' jobs: test-integration: diff --git a/.github/workflows/publish_testpypi.yaml b/.github/workflows/publish_testpypi.yaml index b3da109..a090efa 100644 --- a/.github/workflows/publish_testpypi.yaml +++ b/.github/workflows/publish_testpypi.yaml @@ -57,6 +57,7 @@ jobs: with: sparse-checkout: | tests/* + examples/data_factory.ipynb sparse-checkout-cone-mode: false - name: Update system packages run: | @@ -88,13 +89,13 @@ jobs: --ExecutePreprocessor.timeout=120 \ --no-prompt --no-input \ --stdout \ - tests/data_factory/factory/test_resume_index_1.ipynb; then + examples/data_factory.ipynb; then echo "::error::Notebook execution failed" fi echo "Notebook executed successfully. Summary:" && \ jupyter nbconvert --to markdown --stdout \ - tests/data_factory/factory/test_resume_index_1.ipynb | \ + examples/data_factory.ipynb | \ grep -E '^#|^##' || true # Add tag deletion step diff --git a/.gitignore b/.gitignore index a28f571..c0477d4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,6 @@ # Adhoc stuff +input.json +output.json .serena/ docs/ /vibe_coding/response.md diff --git a/internal b/internal index 0cb9662..9a7ccce 160000 --- a/internal +++ b/internal @@ -1 +1 @@ -Subproject commit 0cb9662b18a4da964ebde579e6b95c53ebfc700c +Subproject commit 9a7ccce145ab67429334c8f2fa24e444df149cd5 diff --git a/prebuilt_template/README.md b/prebuilt_template/README.md index 0c46a1e..97b8694 100644 --- a/prebuilt_template/README.md +++ b/prebuilt_template/README.md @@ -14,6 +14,32 @@ Data generation templates are **prebuilt** that encapsulate sophisticated data g 4. **Generate Data**: Run the template to produce high-quality synthetic data 5. **Export & Use**: Data comes ready for training, testing, or evaluation +## Use the data-template CLI like this: +``` +# List all templates +data-template list-templates + +# List with details +data-template list-templates --detail + +# Get template details +data-template get-template my_template + +# Print schema +data-template print-schema my_template + +# Print example +data-template print-example my_template + +# Run template with interactive input +data-template run-template my_template + +# Run template with input file +data-template run-template my_template --input-file input.json + +# Run template and save output +data-template run-template my_template --input-file input.json --output-file output.json +``` ## Source Code Location The actual implementation of these templates can be found in: @@ -21,6 +47,8 @@ The actual implementation of these templates can be found in: src/starfish/data_gen_template/templates/ ``` + + ## Community & Contributions 🀝 Like what you see? We'd love your help in expanding our template collection! Here's how you can get involved: diff --git a/prebuilt_template/function_calling/sample_run.ipynb b/prebuilt_template/function_calling/sample_run.ipynb index c4934d2..1195856 100644 --- a/prebuilt_template/function_calling/sample_run.ipynb +++ b/prebuilt_template/function_calling/sample_run.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -38,6 +38,206 @@ "loaded = data_gen_template.get(\"starfish/generate_func_call_dataset\")\n" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "get the template input_data schema and example" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32m2025-05-23 11:08:41\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1mPlease run the template with this input schema\u001b[0m\n", + "\u001b[32m2025-05-23 11:08:41\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m{\n", + " \"$defs\": {\n", + " \"APIContract\": {\n", + " \"description\": \"Pydantic model representing an API contract structure.\",\n", + " \"properties\": {\n", + " \"name\": {\n", + " \"title\": \"Name\",\n", + " \"type\": \"string\"\n", + " },\n", + " \"description\": {\n", + " \"title\": \"Description\",\n", + " \"type\": \"string\"\n", + " },\n", + " \"parameters\": {\n", + " \"additionalProperties\": {\n", + " \"$ref\": \"#/$defs/ParameterDefinition\"\n", + " },\n", + " \"title\": \"Parameters\",\n", + " \"type\": \"object\"\n", + " }\n", + " },\n", + " \"required\": [\n", + " \"name\",\n", + " \"description\",\n", + " \"parameters\"\n", + " ],\n", + " \"title\": \"APIContract\",\n", + " \"type\": \"object\"\n", + " },\n", + " \"ParameterDefinition\": {\n", + " \"description\": \"Pydantic model representing parameter definition in an API contract.\",\n", + " \"properties\": {\n", + " \"type\": {\n", + " \"title\": \"Type\",\n", + " \"type\": \"string\"\n", + " },\n", + " \"description\": {\n", + " \"title\": \"Description\",\n", + " \"type\": \"string\"\n", + " },\n", + " \"required\": {\n", + " \"default\": true,\n", + " \"title\": \"Required\",\n", + " \"type\": \"boolean\"\n", + " }\n", + " },\n", + " \"required\": [\n", + " \"type\",\n", + " \"description\"\n", + " ],\n", + " \"title\": \"ParameterDefinition\",\n", + " \"type\": \"object\"\n", + " }\n", + " },\n", + " \"description\": \"Input schema for the generate_by_topic template.\\n\\nIMPORTANT: This Pydantic model is the single source of truth for default values.\\nThe validation and default values are controlled by this model, not the function signature.\",\n", + " \"properties\": {\n", + " \"num_records\": {\n", + " \"anyOf\": [\n", + " {\n", + " \"type\": \"integer\"\n", + " },\n", + " {\n", + " \"type\": \"null\"\n", + " }\n", + " ],\n", + " \"default\": 10,\n", + " \"title\": \"Num Records\"\n", + " },\n", + " \"api_contract\": {\n", + " \"$ref\": \"#/$defs/APIContract\"\n", + " },\n", + " \"topic_model_name\": {\n", + " \"default\": \"openai/gpt-4o-mini\",\n", + " \"title\": \"Topic Model Name\",\n", + " \"type\": \"string\"\n", + " },\n", + " \"topic_model_kwargs\": {\n", + " \"anyOf\": [\n", + " {\n", + " \"additionalProperties\": true,\n", + " \"type\": \"object\"\n", + " },\n", + " {\n", + " \"type\": \"null\"\n", + " }\n", + " ],\n", + " \"default\": null,\n", + " \"title\": \"Topic Model Kwargs\"\n", + " },\n", + " \"generation_model_name\": {\n", + " \"default\": \"openai/gpt-4o-mini\",\n", + " \"title\": \"Generation Model Name\",\n", + " \"type\": \"string\"\n", + " },\n", + " \"generation_model_kwargs\": {\n", + " \"anyOf\": [\n", + " {\n", + " \"additionalProperties\": true,\n", + " \"type\": \"object\"\n", + " },\n", + " {\n", + " \"type\": \"null\"\n", + " }\n", + " ],\n", + " \"default\": null,\n", + " \"title\": \"Generation Model Kwargs\"\n", + " },\n", + " \"data_factory_config\": {\n", + " \"anyOf\": [\n", + " {\n", + " \"additionalProperties\": true,\n", + " \"type\": \"object\"\n", + " },\n", + " {\n", + " \"type\": \"null\"\n", + " }\n", + " ],\n", + " \"default\": {},\n", + " \"title\": \"Data Factory Config\"\n", + " }\n", + " },\n", + " \"required\": [\n", + " \"api_contract\"\n", + " ],\n", + " \"title\": \"GenerateFuncCallDataSet\",\n", + " \"type\": \"object\"\n", + "}\u001b[0m\n" + ] + } + ], + "source": [ + "loaded.print_schema()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32m2025-05-23 11:09:02\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1mHere is an example with api_contract.name as weather_api.get_current_weather\u001b[0m\n", + "\u001b[32m2025-05-23 11:09:02\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m{\n", + " \"num_records\": 4,\n", + " \"api_contract\": {\n", + " \"name\": \"weather_api.get_current_weather\",\n", + " \"description\": \"Retrieves the current weather conditions for a specified location .\",\n", + " \"parameters\": {\n", + " \"location\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The name of the city or geographic location .\",\n", + " \"required\": true\n", + " },\n", + " \"units\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The units for temperature measurement( e.g., 'Celsius', 'Fahrenheit') .\",\n", + " \"required\": false\n", + " }\n", + " }\n", + " },\n", + " \"topic_model_name\": \"openai/gpt-4\",\n", + " \"topic_model_kwargs\": {\n", + " \"temperature\": 0.7\n", + " },\n", + " \"generation_model_name\": \"openai/gpt-4o-mini\",\n", + " \"generation_model_kwargs\": {\n", + " \"temperature\": 0.8,\n", + " \"max_tokens\": 200\n", + " },\n", + " \"data_factory_config\": {\n", + " \"max_concurrency\": 24,\n", + " \"task_runner_timeout\": 120\n", + " }\n", + "}\u001b[0m\n" + ] + } + ], + "source": [ + "loaded.print_example()" + ] + }, { "cell_type": "code", "execution_count": 5, @@ -203,7 +403,7 @@ ], "metadata": { "kernelspec": { - "display_name": "starfish-core-T7IInzTH-py3.11", + "display_name": ".venv", "language": "python", "name": "python3" }, @@ -217,7 +417,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.11.4" } }, "nbformat": 4, diff --git a/prebuilt_template/generate_by_topic/README.md b/prebuilt_template/generate_by_topic/README.md new file mode 100644 index 0000000..9d5dc10 --- /dev/null +++ b/prebuilt_template/generate_by_topic/README.md @@ -0,0 +1,102 @@ + +## Overview +The `generate_by_topic` template is designed to create diverse synthetic data across multiple topics based on user instructions. It can automatically generate relevant topics if not provided and handles deduplication across generated content. + +## Key Features +- Automatic topic generation based on user instructions +- Customizable number of records and records per topic +- Built-in deduplication mechanism +- Flexible output schema configuration +- Parallel data generation with configurable concurrency + +## Input Schema +```python +class GenerateByTopicInput(BaseModel): + user_instruction: Optional[str] = None + num_records: Optional[int] = 10 + records_per_topic: int = 10 + topics: Optional[List[Union[str, Dict[str, int]]]] = None + topic_model_name: str = "openai/gpt-4o-mini" + topic_model_kwargs: Optional[Dict[str, Any]] = None + generation_model_name: str = "openai/gpt-4o-mini" + generation_model_kwargs: Optional[Dict[str, Any]] = None + output_schema: Optional[Union[List[Dict[str, Any]], Dict[str, Any], type]] = [ + {"name": "question", "type": "str"}, + {"name": "answer", "type": "str"} + ] + data_factory_config: Optional[Dict[str, Any]] = {} +``` + +## Parameters +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `user_instruction` | str | Instruction for data generation | None | +| `num_records` | int | Total number of records to generate | 10 | +| `records_per_topic` | int | Number of records per topic | 10 | +| `topics` | List[Union[str, Dict[str, int]]] | List of topics or topic with specific record count | None | +| `topic_model_name` | str | Model name for topic generation | "openai/gpt-4o-mini" | +| `topic_model_kwargs` | Dict[str, Any] | Additional parameters for topic model | None | +| `generation_model_name` | str | Model name for data generation | "openai/gpt-4o-mini" | +| `generation_model_kwargs` | Dict[str, Any] | Additional parameters for generation model | None | +| `output_schema` | Union[List[Dict[str, Any]], Dict[str, Any], type] | Schema for generated data | [{"name": "question", "type": "str"}, {"name": "answer", "type": "str"}] | +| `data_factory_config` | Dict[str, Any] | Configuration for data generation process | {} | + +## Example Usage +```python +{ + "user_instruction": "Generate Q&A pairs about machine learning concepts", + "num_records": 100, + "records_per_topic": 5, + "topics": [ + "supervised learning", + "unsupervised learning", + {"reinforcement learning": 3}, + "neural networks", + ], + "topic_model_name": "openai/gpt-4", + "topic_model_kwargs": {"temperature": 0.7}, + "generation_model_name": "openai/gpt-4", + "generation_model_kwargs": {"temperature": 0.8, "max_tokens": 200}, + "output_schema": [ + {"name": "question", "type": "str"}, + {"name": "answer", "type": "str"}, + {"name": "difficulty", "type": "str"}, + ], + "data_factory_config": {"max_concurrency": 4, "task_runner_timeout": 60 * 2}, +} +``` + +## Workflow +1. Topic Preparation: + - If topics are not provided, generates relevant topics based on user instruction + - Shuffles topics for better distribution and deduplication + +2. Data Generation: + - Generates data for each topic using the specified model + - Implements deduplication by tracking previously generated examples + - Adds topic information to each generated record + +## Output +The generated data will include: +- Fields specified in the output schema +- An additional `topic` field indicating the topic of each record + +## Dependencies +- `starfish` framework +- `pydantic` for input validation + + +## Sample Run + +Check out [`sample_run.ipynb`](./sample_run.ipynb) for a complete example you can run right away. + +## Source Implementation + +The actual template code is located at: +``` +src/starfish/data_gen_template/templates/starfish/generate_by_topic/ +``` + +--- + +**Try it out!** If you have any questions, let us know - we'd be happy to help. If you like this template, consider starring the repo and building your own! We welcome community contributions and are always happy to chat about new ideas. ⭐ \ No newline at end of file diff --git a/prebuilt_template/generate_by_topic/sample_run.ipynb b/prebuilt_template/generate_by_topic/sample_run.ipynb new file mode 100644 index 0000000..a55a46e --- /dev/null +++ b/prebuilt_template/generate_by_topic/sample_run.ipynb @@ -0,0 +1,438 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from starfish import data_gen_template" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['starfish/generate_func_call_dataset', 'starfish/generate_by_topic']" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data_gen_template.list()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "loaded = data_gen_template.get(\"starfish/generate_by_topic\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "get the template input_data schema and example" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32m2025-05-23 11:23:57\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1mPlease run the template with this input schema\u001b[0m\n", + "\u001b[32m2025-05-23 11:23:57\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m{\n", + " \"description\": \"Input schema for the generate_by_topic template.\\n\\nIMPORTANT: This Pydantic model is the single source of truth for default values.\\nThe validation and default values are controlled by this model, not the function signature.\",\n", + " \"properties\": {\n", + " \"user_instruction\": {\n", + " \"anyOf\": [\n", + " {\n", + " \"type\": \"string\"\n", + " },\n", + " {\n", + " \"type\": \"null\"\n", + " }\n", + " ],\n", + " \"default\": null,\n", + " \"title\": \"User Instruction\"\n", + " },\n", + " \"num_records\": {\n", + " \"anyOf\": [\n", + " {\n", + " \"type\": \"integer\"\n", + " },\n", + " {\n", + " \"type\": \"null\"\n", + " }\n", + " ],\n", + " \"default\": 10,\n", + " \"title\": \"Num Records\"\n", + " },\n", + " \"records_per_topic\": {\n", + " \"default\": 10,\n", + " \"title\": \"Records Per Topic\",\n", + " \"type\": \"integer\"\n", + " },\n", + " \"topics\": {\n", + " \"anyOf\": [\n", + " {\n", + " \"items\": {\n", + " \"anyOf\": [\n", + " {\n", + " \"type\": \"string\"\n", + " },\n", + " {\n", + " \"additionalProperties\": {\n", + " \"type\": \"integer\"\n", + " },\n", + " \"type\": \"object\"\n", + " }\n", + " ]\n", + " },\n", + " \"type\": \"array\"\n", + " },\n", + " {\n", + " \"type\": \"null\"\n", + " }\n", + " ],\n", + " \"default\": null,\n", + " \"title\": \"Topics\"\n", + " },\n", + " \"topic_model_name\": {\n", + " \"default\": \"openai/gpt-4o-mini\",\n", + " \"title\": \"Topic Model Name\",\n", + " \"type\": \"string\"\n", + " },\n", + " \"topic_model_kwargs\": {\n", + " \"anyOf\": [\n", + " {\n", + " \"additionalProperties\": true,\n", + " \"type\": \"object\"\n", + " },\n", + " {\n", + " \"type\": \"null\"\n", + " }\n", + " ],\n", + " \"default\": null,\n", + " \"title\": \"Topic Model Kwargs\"\n", + " },\n", + " \"generation_model_name\": {\n", + " \"default\": \"openai/gpt-4o-mini\",\n", + " \"title\": \"Generation Model Name\",\n", + " \"type\": \"string\"\n", + " },\n", + " \"generation_model_kwargs\": {\n", + " \"anyOf\": [\n", + " {\n", + " \"additionalProperties\": true,\n", + " \"type\": \"object\"\n", + " },\n", + " {\n", + " \"type\": \"null\"\n", + " }\n", + " ],\n", + " \"default\": null,\n", + " \"title\": \"Generation Model Kwargs\"\n", + " },\n", + " \"output_schema\": {\n", + " \"anyOf\": [\n", + " {\n", + " \"items\": {\n", + " \"additionalProperties\": true,\n", + " \"type\": \"object\"\n", + " },\n", + " \"type\": \"array\"\n", + " },\n", + " {\n", + " \"additionalProperties\": true,\n", + " \"type\": \"object\"\n", + " },\n", + " {\n", + " \"type\": \"null\"\n", + " }\n", + " ],\n", + " \"default\": [\n", + " {\n", + " \"name\": \"question\",\n", + " \"type\": \"str\"\n", + " },\n", + " {\n", + " \"name\": \"answer\",\n", + " \"type\": \"str\"\n", + " }\n", + " ],\n", + " \"title\": \"Output Schema\"\n", + " },\n", + " \"data_factory_config\": {\n", + " \"anyOf\": [\n", + " {\n", + " \"additionalProperties\": true,\n", + " \"type\": \"object\"\n", + " },\n", + " {\n", + " \"type\": \"null\"\n", + " }\n", + " ],\n", + " \"default\": {},\n", + " \"title\": \"Data Factory Config\"\n", + " }\n", + " },\n", + " \"title\": \"GenerateByTopicInput\",\n", + " \"type\": \"object\"\n", + "}\u001b[0m\n" + ] + } + ], + "source": [ + "loaded.print_schema()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32m2025-05-23 11:24:01\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1mHere is an example with api_contract.name as weather_api.get_current_weather\u001b[0m\n", + "\u001b[32m2025-05-23 11:24:01\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m{\n", + " \"user_instruction\": \"Generate Q&A pairs about machine learning concepts\",\n", + " \"num_records\": 100,\n", + " \"records_per_topic\": 5,\n", + " \"topics\": [\n", + " \"supervised learning\",\n", + " \"unsupervised learning\",\n", + " {\"reinforcement learning\": 3}, # This means generate 3 records for this topic\n", + " \"neural networks\",\n", + " ],\n", + " \"topic_model_name\": \"openai/gpt-4\",\n", + " \"topic_model_kwargs\": {\"temperature\": 0.7},\n", + " \"generation_model_name\": \"openai/gpt-4\",\n", + " \"generation_model_kwargs\": {\"temperature\": 0.8, \"max_tokens\": 200},\n", + " \"output_schema\": [\n", + " {\"name\": \"question\", \"type\": \"str\"},\n", + " {\"name\": \"answer\", \"type\": \"str\"},\n", + " {\"name\": \"difficulty\", \"type\": \"str\"}, # Added an additional field\n", + " ],\n", + " \"data_factory_config\": {\"max_concurrency\": 4, \"task_runner_timeout\": 60 * 2},\n", + " }\u001b[0m\n" + ] + } + ], + "source": [ + "loaded.print_example()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "🌟 Function Calling Dataset Generation Pipeline\n", + "============================================================\n", + "πŸ“‹ Process Overview:\n", + " 1. Calculate optimal data distribution\n", + " 2. Generate diverse topics\n", + " 3. Create subtopics for each topic\n", + " 4. Generate query-answer pairs\n", + " 5. Verify and validate generated data\n", + " 6. Regenerate failed cases\n", + "============================================================\n", + "πŸ“Š Data Distribution Plan:\n", + " β€’ Requested: 10 records\n", + " β€’ Distribution: 1 topics Γ— 1 subtopics Γ— 10 records\n", + " β€’ Total generation: 10 records\n", + " β€’ API calls needed: 3\n", + "\n", + "🎯 Step 1: Generating diverse topics...\n", + " βœ… Generated 1 topics\n", + "\n", + "🌿 Step 2: Creating subtopics for each topic...\n", + "\u001b[32m2025-05-23 00:27:04\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m\u001b[1m[JOB START]\u001b[0m \u001b[36mMaster Job ID: e6763e50-6438-4df5-81a9-5a68ce3f8468\u001b[0m | \u001b[33mLogging progress every 3 seconds\u001b[0m\u001b[0m\n", + "\u001b[32m2025-05-23 00:27:04\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m[JOB PROGRESS] \u001b[32mCompleted: 0/1\u001b[0m | \u001b[33mRunning: 1\u001b[0m | \u001b[36mAttempted: 0\u001b[0m (\u001b[32mCompleted: 0\u001b[0m, \u001b[31mFailed: 0\u001b[0m, \u001b[35mFiltered: 0\u001b[0m, \u001b[34mDuplicate: 0\u001b[0m, \u001b[1;31mInDeadQueue: 0\u001b[0m)\u001b[0m\n", + "\u001b[32m2025-05-23 00:27:06\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m[JOB FINISHED] \u001b[1mFinal Status:\u001b[0m \u001b[32mCompleted: 1/1\u001b[0m | \u001b[33mAttempted: 1\u001b[0m (Failed: 0, Filtered: 0, Duplicate: 0, InDeadQueue: 0)\u001b[0m\n", + " βœ… Generated 1 subtopics total\n", + "\n", + "πŸ’¬ Step 3: Generating query-answer pairs...\n", + "\u001b[32m2025-05-23 00:27:06\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m\u001b[1m[JOB START]\u001b[0m \u001b[36mMaster Job ID: 1931c5c8-c1f3-4268-98b7-1a5295b8abf2\u001b[0m | \u001b[33mLogging progress every 3 seconds\u001b[0m\u001b[0m\n", + "\u001b[32m2025-05-23 00:27:06\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m[JOB PROGRESS] \u001b[32mCompleted: 0/1\u001b[0m | \u001b[33mRunning: 1\u001b[0m | \u001b[36mAttempted: 0\u001b[0m (\u001b[32mCompleted: 0\u001b[0m, \u001b[31mFailed: 0\u001b[0m, \u001b[35mFiltered: 0\u001b[0m, \u001b[34mDuplicate: 0\u001b[0m, \u001b[1;31mInDeadQueue: 0\u001b[0m)\u001b[0m\n", + "\u001b[32m2025-05-23 00:27:09\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m[JOB PROGRESS] \u001b[32mCompleted: 0/1\u001b[0m | \u001b[33mRunning: 1\u001b[0m | \u001b[36mAttempted: 0\u001b[0m (\u001b[32mCompleted: 0\u001b[0m, \u001b[31mFailed: 0\u001b[0m, \u001b[35mFiltered: 0\u001b[0m, \u001b[34mDuplicate: 0\u001b[0m, \u001b[1;31mInDeadQueue: 0\u001b[0m)\u001b[0m\n", + "\u001b[32m2025-05-23 00:27:12\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m[JOB PROGRESS] \u001b[32mCompleted: 0/1\u001b[0m | \u001b[33mRunning: 1\u001b[0m | \u001b[36mAttempted: 0\u001b[0m (\u001b[32mCompleted: 0\u001b[0m, \u001b[31mFailed: 0\u001b[0m, \u001b[35mFiltered: 0\u001b[0m, \u001b[34mDuplicate: 0\u001b[0m, \u001b[1;31mInDeadQueue: 0\u001b[0m)\u001b[0m\n", + "\u001b[32m2025-05-23 00:27:15\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m[JOB PROGRESS] \u001b[32mCompleted: 0/1\u001b[0m | \u001b[33mRunning: 1\u001b[0m | \u001b[36mAttempted: 0\u001b[0m (\u001b[32mCompleted: 0\u001b[0m, \u001b[31mFailed: 0\u001b[0m, \u001b[35mFiltered: 0\u001b[0m, \u001b[34mDuplicate: 0\u001b[0m, \u001b[1;31mInDeadQueue: 0\u001b[0m)\u001b[0m\n", + "\u001b[32m2025-05-23 00:27:18\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m[JOB PROGRESS] \u001b[32mCompleted: 0/1\u001b[0m | \u001b[33mRunning: 1\u001b[0m | \u001b[36mAttempted: 0\u001b[0m (\u001b[32mCompleted: 0\u001b[0m, \u001b[31mFailed: 0\u001b[0m, \u001b[35mFiltered: 0\u001b[0m, \u001b[34mDuplicate: 0\u001b[0m, \u001b[1;31mInDeadQueue: 0\u001b[0m)\u001b[0m\n", + "\u001b[32m2025-05-23 00:27:21\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m[JOB PROGRESS] \u001b[32mCompleted: 0/1\u001b[0m | \u001b[33mRunning: 1\u001b[0m | \u001b[36mAttempted: 0\u001b[0m (\u001b[32mCompleted: 0\u001b[0m, \u001b[31mFailed: 0\u001b[0m, \u001b[35mFiltered: 0\u001b[0m, \u001b[34mDuplicate: 0\u001b[0m, \u001b[1;31mInDeadQueue: 0\u001b[0m)\u001b[0m\n", + "\u001b[32m2025-05-23 00:27:24\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m[JOB PROGRESS] \u001b[32mCompleted: 0/1\u001b[0m | \u001b[33mRunning: 1\u001b[0m | \u001b[36mAttempted: 0\u001b[0m (\u001b[32mCompleted: 0\u001b[0m, \u001b[31mFailed: 0\u001b[0m, \u001b[35mFiltered: 0\u001b[0m, \u001b[34mDuplicate: 0\u001b[0m, \u001b[1;31mInDeadQueue: 0\u001b[0m)\u001b[0m\n", + "\u001b[32m2025-05-23 00:27:27\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m[JOB PROGRESS] \u001b[32mCompleted: 0/1\u001b[0m | \u001b[33mRunning: 1\u001b[0m | \u001b[36mAttempted: 0\u001b[0m (\u001b[32mCompleted: 0\u001b[0m, \u001b[31mFailed: 0\u001b[0m, \u001b[35mFiltered: 0\u001b[0m, \u001b[34mDuplicate: 0\u001b[0m, \u001b[1;31mInDeadQueue: 0\u001b[0m)\u001b[0m\n", + "\u001b[32m2025-05-23 00:27:28\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m[JOB FINISHED] \u001b[1mFinal Status:\u001b[0m \u001b[32mCompleted: 1/1\u001b[0m | \u001b[33mAttempted: 1\u001b[0m (Failed: 0, Filtered: 0, Duplicate: 0, InDeadQueue: 0)\u001b[0m\n", + " βœ… Generated 10 initial query-answer pairs\n", + "\n", + "πŸ” Step 4: Verifying data quality...\n", + "\u001b[32m2025-05-23 00:27:28\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m\u001b[1m[JOB START]\u001b[0m \u001b[36mMaster Job ID: f036c07c-1cd2-4690-be92-bac359e45544\u001b[0m | \u001b[33mLogging progress every 3 seconds\u001b[0m\u001b[0m\n", + "\u001b[32m2025-05-23 00:27:28\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m[JOB PROGRESS] \u001b[32mCompleted: 0/10\u001b[0m | \u001b[33mRunning: 10\u001b[0m | \u001b[36mAttempted: 0\u001b[0m (\u001b[32mCompleted: 0\u001b[0m, \u001b[31mFailed: 0\u001b[0m, \u001b[35mFiltered: 0\u001b[0m, \u001b[34mDuplicate: 0\u001b[0m, \u001b[1;31mInDeadQueue: 0\u001b[0m)\u001b[0m\n", + "\u001b[32m2025-05-23 00:27:31\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m[JOB PROGRESS] \u001b[32mCompleted: 0/10\u001b[0m | \u001b[33mRunning: 10\u001b[0m | \u001b[36mAttempted: 0\u001b[0m (\u001b[32mCompleted: 0\u001b[0m, \u001b[31mFailed: 0\u001b[0m, \u001b[35mFiltered: 0\u001b[0m, \u001b[34mDuplicate: 0\u001b[0m, \u001b[1;31mInDeadQueue: 0\u001b[0m)\u001b[0m\n", + "\u001b[32m2025-05-23 00:27:34\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m[JOB PROGRESS] \u001b[32mCompleted: 9/10\u001b[0m | \u001b[33mRunning: 1\u001b[0m | \u001b[36mAttempted: 9\u001b[0m (\u001b[32mCompleted: 9\u001b[0m, \u001b[31mFailed: 0\u001b[0m, \u001b[35mFiltered: 0\u001b[0m, \u001b[34mDuplicate: 0\u001b[0m, \u001b[1;31mInDeadQueue: 0\u001b[0m)\u001b[0m\n", + "\u001b[32m2025-05-23 00:27:35\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m[JOB FINISHED] \u001b[1mFinal Status:\u001b[0m \u001b[32mCompleted: 10/10\u001b[0m | \u001b[33mAttempted: 10\u001b[0m (Failed: 0, Filtered: 0, Duplicate: 0, InDeadQueue: 0)\u001b[0m\n", + "\u001b[32m2025-05-23 00:27:35\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[33m\u001b[1mCannot serialize function for resume due to unsupported type: cannot pickle '_hashlib.HMAC' object\u001b[0m\n", + " βœ… Quality check complete: 9 passed, 1 failed\n", + "\n", + "πŸ”„ Step 5: Regenerating failed cases...\n", + "\u001b[32m2025-05-23 00:27:35\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m\u001b[1m[JOB START]\u001b[0m \u001b[36mMaster Job ID: 3d6183a2-e465-4807-9e18-cbb84dc0d28f\u001b[0m | \u001b[33mLogging progress every 3 seconds\u001b[0m\u001b[0m\n", + "\u001b[32m2025-05-23 00:27:35\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m[JOB PROGRESS] \u001b[32mCompleted: 0/1\u001b[0m | \u001b[33mRunning: 1\u001b[0m | \u001b[36mAttempted: 0\u001b[0m (\u001b[32mCompleted: 0\u001b[0m, \u001b[31mFailed: 0\u001b[0m, \u001b[35mFiltered: 0\u001b[0m, \u001b[34mDuplicate: 0\u001b[0m, \u001b[1;31mInDeadQueue: 0\u001b[0m)\u001b[0m\n", + "\u001b[32m2025-05-23 00:27:37\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m[JOB FINISHED] \u001b[1mFinal Status:\u001b[0m \u001b[32mCompleted: 1/1\u001b[0m | \u001b[33mAttempted: 1\u001b[0m (Failed: 0, Filtered: 0, Duplicate: 0, InDeadQueue: 0)\u001b[0m\n", + "\u001b[32m2025-05-23 00:27:37\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m\u001b[1m[JOB START]\u001b[0m \u001b[36mMaster Job ID: 8754bec6-25e3-40bd-9743-f2763fc1091f\u001b[0m | \u001b[33mLogging progress every 3 seconds\u001b[0m\u001b[0m\n", + "\u001b[32m2025-05-23 00:27:37\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m[JOB PROGRESS] \u001b[32mCompleted: 0/1\u001b[0m | \u001b[33mRunning: 1\u001b[0m | \u001b[36mAttempted: 0\u001b[0m (\u001b[32mCompleted: 0\u001b[0m, \u001b[31mFailed: 0\u001b[0m, \u001b[35mFiltered: 0\u001b[0m, \u001b[34mDuplicate: 0\u001b[0m, \u001b[1;31mInDeadQueue: 0\u001b[0m)\u001b[0m\n", + "\u001b[32m2025-05-23 00:27:40\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m[JOB PROGRESS] \u001b[32mCompleted: 0/1\u001b[0m | \u001b[33mRunning: 1\u001b[0m | \u001b[36mAttempted: 0\u001b[0m (\u001b[32mCompleted: 0\u001b[0m, \u001b[31mFailed: 0\u001b[0m, \u001b[35mFiltered: 0\u001b[0m, \u001b[34mDuplicate: 0\u001b[0m, \u001b[1;31mInDeadQueue: 0\u001b[0m)\u001b[0m\n", + "\u001b[32m2025-05-23 00:27:41\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[1m[JOB FINISHED] \u001b[1mFinal Status:\u001b[0m \u001b[32mCompleted: 1/1\u001b[0m | \u001b[33mAttempted: 1\u001b[0m (Failed: 0, Filtered: 0, Duplicate: 0, InDeadQueue: 0)\u001b[0m\n", + "\u001b[32m2025-05-23 00:27:41\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[33m\u001b[1mCannot serialize function for resume due to unsupported type: cannot pickle '_hashlib.HMAC' object\u001b[0m\n", + " βœ… Regenerated 1 pairs, 1 still failing\n", + "\u001b[32m2025-05-23 00:27:41\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[33m\u001b[1mSome data still failing after regeneration - prompts may need improvement\u001b[0m\n", + "🎯 Perfect! Generated exactly 10 records as requested\n", + "\n", + "πŸŽ‰ Generation Complete!\n", + "============================================================\n", + "πŸ“ˆ Final Results:\n", + " β€’ Records generated: 10\n", + " β€’ Success rate: 10/10 (100.0%)\n", + " β€’ Distribution used: 1T Γ— 1S Γ— 10R\n", + "\n", + "⭐ If you found this helpful, please consider starring our repo!\n", + " Your support means the world to us! 🌟\n", + "============================================================\n" + ] + } + ], + "source": [ + "input_data = {\n", + " \"user_instruction\": \"Generate Q&A pairs about machine learning concepts\",\n", + " \"num_records\": 100,\n", + " \"records_per_topic\": 5,\n", + " \"topics\": [\n", + " \"supervised learning\",\n", + " \"unsupervised learning\",\n", + " {\"reinforcement learning\": 3}, # This means generate 3 records for this topic\n", + " \"neural networks\",\n", + " ],\n", + " \"topic_model_name\": \"openai/gpt-4\",\n", + " \"topic_model_kwargs\": {\"temperature\": 0.7},\n", + " \"generation_model_name\": \"openai/gpt-4\",\n", + " \"generation_model_kwargs\": {\"temperature\": 0.8, \"max_tokens\": 200},\n", + " \"output_schema\": [\n", + " {\"name\": \"question\", \"type\": \"str\"},\n", + " {\"name\": \"answer\", \"type\": \"str\"},\n", + " {\"name\": \"difficulty\", \"type\": \"str\"}, # Added an additional field\n", + " ],\n", + " \"data_factory_config\": {\"max_concurrency\": 4, \"task_runner_timeout\": 60 * 2},\n", + " }\n", + "data = await loaded.run(input_data=input_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'query': 'Can you check the current weather in Toronto and Rome? Use Fahrenheit for both locations.',\n", + " 'answer': [{'name': 'weather_api.get_current_weather',\n", + " 'arguments': {'location': 'Toronto', 'units': 'Fahrenheit'}},\n", + " {'name': 'weather_api.get_current_weather',\n", + " 'arguments': {'location': 'Rome', 'units': 'Fahrenheit'}}]},\n", + " {'query': 'Get me the current weather in Mumbai and also in Johannesburg, please use Fahrenheit for both.',\n", + " 'answer': [{'name': 'weather_api.get_current_weather',\n", + " 'arguments': {'location': 'Mumbai', 'units': 'Fahrenheit'}},\n", + " {'name': 'weather_api.get_current_weather',\n", + " 'arguments': {'location': 'Johannesburg', 'units': 'Fahrenheit'}}]},\n", + " {'query': 'I need the current weather for Sydney and London. What are the temperatures in Celsius?',\n", + " 'answer': [{'name': 'weather_api.get_current_weather',\n", + " 'arguments': {'location': 'Sydney', 'units': 'Celsius'}},\n", + " {'name': 'weather_api.get_current_weather',\n", + " 'arguments': {'location': 'London', 'units': 'Celsius'}}]},\n", + " {'query': 'Please find the current weather in Buenos Aires and Cape Town, using Celsius for Buenos Aires.',\n", + " 'answer': [{'name': 'weather_api.get_current_weather',\n", + " 'arguments': {'location': 'Buenos Aires', 'units': 'Celsius'}},\n", + " {'name': 'weather_api.get_current_weather',\n", + " 'arguments': {'location': 'Cape Town'}}]},\n", + " {'query': 'What’s the weather like in Moscow? Also, can you get the current conditions in Beijing?',\n", + " 'answer': [{'name': 'weather_api.get_current_weather',\n", + " 'arguments': {'location': 'Moscow'}},\n", + " {'name': 'weather_api.get_current_weather',\n", + " 'arguments': {'location': 'Beijing'}}]},\n", + " {'query': 'Can you tell me the current weather in Tokyo and in Los Angeles? Please provide both in Fahrenheit.',\n", + " 'answer': [{'name': 'weather_api.get_current_weather',\n", + " 'arguments': {'location': 'Tokyo', 'units': 'Fahrenheit'}},\n", + " {'name': 'weather_api.get_current_weather',\n", + " 'arguments': {'location': 'Los Angeles', 'units': 'Fahrenheit'}}]},\n", + " {'query': 'Please provide the current weather for Berlin and Cairo, using Celsius for Berlin and no specific unit for Cairo.',\n", + " 'answer': [{'name': 'weather_api.get_current_weather',\n", + " 'arguments': {'location': 'Berlin', 'units': 'Celsius'}},\n", + " {'name': 'weather_api.get_current_weather',\n", + " 'arguments': {'location': 'Cairo'}}]},\n", + " {'query': 'I need the current weather in Seattle and in Santiago. Use Fahrenheit for Seattle and Celsius for Santiago.',\n", + " 'answer': [{'name': 'weather_api.get_current_weather',\n", + " 'arguments': {'location': 'Seattle', 'units': 'Fahrenheit'}},\n", + " {'name': 'weather_api.get_current_weather',\n", + " 'arguments': {'location': 'Santiago', 'units': 'Celsius'}}]},\n", + " {'query': \"What's the current temperature in San Francisco? Can you also check the weather in Paris?\",\n", + " 'answer': [{'name': 'weather_api.get_current_weather',\n", + " 'arguments': {'location': 'San Francisco'}},\n", + " {'name': 'weather_api.get_current_weather',\n", + " 'arguments': {'location': 'Paris'}}]},\n", + " {'query': 'What is the current weather in New York City? And can you also provide the temperature in Celsius?',\n", + " 'answer': [{'name': 'weather_api.get_current_weather',\n", + " 'arguments': {'location': 'New York City', 'units': 'Celsius'}}]}]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml index 908b389..da0c709 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,7 @@ nbval = "^0.11.0" [tool.poetry.scripts] starfish = "starfish.api.cli:main" +data-template = "src.starfish.data_gen_template.cli:main" [tool.ruff] diff --git a/pytest.ini b/pytest.ini index bf3fe09..09da31e 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,4 +1,7 @@ [pytest] asyncio_mode = auto timeout = 300 -timeout_method = thread \ No newline at end of file +timeout_method = thread +norecursedirs = .ipynb_checkpoints +python_files = test_*.py +ignore = tests/data_factory/factory/data_factory.ipynb \ No newline at end of file diff --git a/src/starfish/data_gen_template/cli.py b/src/starfish/data_gen_template/cli.py new file mode 100644 index 0000000..25a8015 --- /dev/null +++ b/src/starfish/data_gen_template/cli.py @@ -0,0 +1,117 @@ +import typer +from pathlib import Path +from typing import Optional +from starfish.data_gen_template.core import data_gen_template + +app = typer.Typer(help="Data Template CLI") + + +@app.command() +def list_templates(detail: bool = False): + """List all available templates""" + templates = data_gen_template.list(is_detail=detail) + if detail: + for template in templates: + typer.echo(f"Template: {template['name']}") + typer.echo(f" Description: {template['description']}") + typer.echo(f" Author: {template['author']}") + typer.echo(f" Version: {template['starfish_version']}") + typer.echo(f" Dependencies: {', '.join(template.get('dependencies', []))}") + typer.echo() + else: + for template in templates: + typer.echo(template) + + +@app.command() +def get_template(name: str): + """Get details about a specific template""" + try: + data_gen_template.list() + template = data_gen_template.get(name) + typer.echo(f"Template: {template.name}") + typer.echo(f"Description: {template.description}") + typer.echo(f"Author: {template.author}") + typer.echo(f"Version: {template.starfish_version}") + typer.echo(f"Dependencies: {', '.join(template.dependencies)}") + except Exception as e: + typer.echo(f"Error: {str(e)}", err=True) + + +# @app.command() +# def export_template(name: str, output_path: str): +# """Export a template to a specific path""" +# try: +# template = data_gen_template.get(name) +# exported_path = template.export(output_path) +# typer.echo(f"Template exported to: {exported_path}") +# except Exception as e: +# typer.echo(f"Error: {str(e)}", err=True) + + +@app.command() +def run_template( + name: str, + input_file: Optional[Path] = typer.Option(None, help="Path to JSON file with input data"), + output_file: Optional[Path] = typer.Option(None, help="Path to save output to"), +): + """Run a template with the provided input data""" + try: + data_gen_template.list() + template = data_gen_template.get(name) + + # Load input data + if input_file: + import json + + with open(input_file) as f: + input_data = json.load(f) + else: + typer.echo("Please enter the input data (JSON format):") + input_data = json.loads(typer.prompt("Input data")) + + # Run the template + import asyncio + + result = asyncio.run(template.run(input_data=input_data)) + + # Handle output + if output_file: + with open(output_file, "w") as f: + json.dump(result, f, indent=2) + typer.echo(f"Output saved to {output_file}") + else: + typer.echo(json.dumps(result, indent=2)) + + except Exception as e: + typer.echo(f"Error: {str(e)}", err=True) + + +@app.command() +def print_schema(name: str): + """Print the input schema for a template""" + try: + data_gen_template.list() + template = data_gen_template.get(name) + template.print_schema() + except Exception as e: + typer.echo(f"Error: {str(e)}", err=True) + + +@app.command() +def print_example(name: str): + """Print an example input for a template""" + try: + data_gen_template.list() + template = data_gen_template.get(name) + template.print_example() + except Exception as e: + typer.echo(f"Error: {str(e)}", err=True) + + +def main(): + app() + + +if __name__ == "__main__": + main() diff --git a/src/starfish/data_gen_template/core.py b/src/starfish/data_gen_template/core.py index 5b1417c..deddde1 100644 --- a/src/starfish/data_gen_template/core.py +++ b/src/starfish/data_gen_template/core.py @@ -4,10 +4,13 @@ import pydantic import importlib.util import ast -from typing import Any, Union, List, Dict +from typing import Any, Union, List, Dict, get_type_hints import inspect - +import json from starfish.data_gen_template.utils.errors import DataTemplateValueError, ImportModuleError, ImportPackageError +from starfish.common.logger import get_logger + +logger = get_logger(__name__) def _check_dependencies(dependencies: list[str]) -> None: @@ -40,11 +43,21 @@ class Template: """Class representing a single template instance.""" def __init__( - self, name: str, func: callable, input_schema: type, output_schema: type, description: str, author: str, starfish_version: str, dependencies: list[str] + self, + name: str, + func: callable, + input_schema: type, + input_example: str, + output_schema: type, + description: str, + author: str, + starfish_version: str, + dependencies: list[str], ): self.name = name self.func = func self.input_schema = input_schema + self.input_example = input_example self.output_schema = output_schema self.description = description self.author = author @@ -113,6 +126,17 @@ async def run(self, *args, **kwargs) -> Any: return result + def print_schema(self): + type_hints = get_type_hints(self.func) + input_schema = type_hints.get("input_data").schema() + # Pretty print the schema + logger.info("Please run the template with this input schema") + logger.info(json.dumps(input_schema, indent=4)) + + def print_example(self): + logger.info("Here is an example with api_contract.name as weather_api.get_current_weather") + logger.info(self.input_example) # Pretty print with 4-space indentation + def _get_validated_model(self, args, kwargs): """Convert input arguments into a validated Pydantic model instance.""" # Case 1: User passed a model instance directly @@ -221,14 +245,16 @@ def get(template_name: str) -> Template: return data_gen_template._template_instance_registry[template_name] @staticmethod - def register(name: str, input_schema: type, output_schema: type, description: str, author: str, starfish_version: str, dependencies: list): + def register( + name: str, input_schema: type, input_example: str, output_schema: type, description: str, author: str, starfish_version: str, dependencies: list + ): """Decorator factory for registering data templates.""" def decorator(func: callable): # Check if this is an import call (function already has _is_template flag) if name not in data_gen_template._template_instance_registry: data_gen_template._template_instance_registry[name] = Template( - name, func, input_schema, output_schema, description, author, starfish_version, dependencies + name, func, input_schema, input_example, output_schema, description, author, starfish_version, dependencies ) return func diff --git a/src/starfish/data_gen_template/templates/starfish/function_calling/generator.py b/src/starfish/data_gen_template/templates/starfish/function_calling/generator.py index 837dd86..325169d 100644 --- a/src/starfish/data_gen_template/templates/starfish/function_calling/generator.py +++ b/src/starfish/data_gen_template/templates/starfish/function_calling/generator.py @@ -67,6 +67,22 @@ class GenerateFuncCallDataSet(BaseModel): author="Wendao Liu", starfish_version="0.1.3", dependencies=[], + input_example="""{ + "num_records": 4, + "api_contract": { + "name": "weather_api.get_current_weather", + "description": "Retrieves the current weather conditions for a specified location .", + "parameters": { + "location": {"type": "string", "description": "The name of the city or geographic location .", "required": True}, + "units": {"type": "string", "description": "The units for temperature measurement( e.g., 'Celsius', 'Fahrenheit') .", "required": False}, + }, + }, + "topic_model_name": "openai/gpt-4", + "topic_model_kwargs": {"temperature": 0.7}, + "generation_model_name": "openai/gpt-4o-mini", + "generation_model_kwargs": {"temperature": 0.8, "max_tokens": 200}, + "data_factory_config": {"max_concurrency": 24, "task_runner_timeout": 60 * 2}, + }""", ) async def api_contract_workflow(input_data: GenerateFuncCallDataSet): api_contract = input_data.api_contract.model_dump() diff --git a/src/starfish/data_gen_template/templates/starfish/generate_by_topic/generator.py b/src/starfish/data_gen_template/templates/starfish/generate_by_topic/generator.py index 4f97c51..a13d059 100644 --- a/src/starfish/data_gen_template/templates/starfish/generate_by_topic/generator.py +++ b/src/starfish/data_gen_template/templates/starfish/generate_by_topic/generator.py @@ -41,6 +41,27 @@ class GenerateByTopicInput(BaseModel): author="Wendao Liu", starfish_version="0.1.3", dependencies=[], + input_example="""{ + "user_instruction": "Generate Q&A pairs about machine learning concepts", + "num_records": 100, + "records_per_topic": 5, + "topics": [ + "supervised learning", + "unsupervised learning", + {"reinforcement learning": 3}, # This means generate 3 records for this topic + "neural networks", + ], + "topic_model_name": "openai/gpt-4", + "topic_model_kwargs": {"temperature": 0.7}, + "generation_model_name": "openai/gpt-4", + "generation_model_kwargs": {"temperature": 0.8, "max_tokens": 200}, + "output_schema": [ + {"name": "question", "type": "str"}, + {"name": "answer", "type": "str"}, + {"name": "difficulty", "type": "str"}, # Added an additional field + ], + "data_factory_config": {"max_concurrency": 4, "task_runner_timeout": 60 * 2}, + }""", ) async def generate_by_topic(input_data: GenerateByTopicInput): """ diff --git a/tests/data_factory/factory/test_resume_index_1.ipynb b/tests/data_factory/factory/test_resume_index_1.ipynb deleted file mode 100644 index 5691362..0000000 --- a/tests/data_factory/factory/test_resume_index_1.ipynb +++ /dev/null @@ -1,570 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%pip install --index-url https://test.pypi.org/simple/ \\\n", - " --extra-index-url https://pypi.org/simple \\\n", - " starfish-core==0.1.3" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import asyncio\n", - "import time\n", - "import signal\n", - "import threading\n", - "from typing import List, Dict, Any, Optional, Set, Tuple\n", - "\n", - "from starfish import data_factory\n", - "from starfish.common.env_loader import load_env_file\n", - "from starfish.data_factory.utils.mock import mock_llm_call\n", - "from starfish.common.logger import get_logger\n", - "logger = get_logger(__name__)\n", - "# Load environment variables\n", - "load_env_file()\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "# Apply nest_asyncio for use in Jupyter notebooks\n", - "try:\n", - " import nest_asyncio\n", - " nest_asyncio.apply()\n", - "except ImportError:\n", - " print(\"nest_asyncio not found, skipping. This may cause issues if run in a notebook.\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "# Define the data factory function at module level\n", - "@data_factory(max_concurrency=10)\n", - "async def mock_llm_processor(city_name: str, num_records_per_city: int):\n", - " \"\"\"Mock LLM processor that simulates processing with a delay\"\"\"\n", - " # Added sleep to make the process take longer for demonstration\n", - " await asyncio.sleep(0.5)\n", - " return await mock_llm_call(city_name=city_name, num_records_per_city=num_records_per_city, fail_rate=0)\n", - "\n", - "class TestRunner:\n", - " def __init__(self, total_time_limit=15, checkpoint_interval=3, max_checkpoints=10):\n", - " \"\"\"\n", - " Initialize the test runner\n", - " \n", - " Args:\n", - " total_time_limit: Maximum time allowed for the whole test in seconds\n", - " checkpoint_interval: Time between checkpoints (stop/resume) in seconds\n", - " max_checkpoints: Maximum number of checkpoints before forced termination\n", - " \"\"\"\n", - " self.total_time_limit = total_time_limit\n", - " self.checkpoint_interval = checkpoint_interval\n", - " self.max_checkpoints = max_checkpoints\n", - " self.errors = [] # Each item will be a tuple of (step, error_message)\n", - " self.results = None\n", - " self.stop_events = []\n", - " self.job = None\n", - " self.timeout_triggered = False\n", - " self.all_checkpoint_errors = {} # Dictionary to track errors per checkpoint\n", - " \n", - " def add_error(self, step: str, error_message: str):\n", - " \"\"\"Add an error with the associated step information\"\"\"\n", - " self.errors.append((step, error_message))\n", - " \n", - " # Also track errors by checkpoint\n", - " if step not in self.all_checkpoint_errors:\n", - " self.all_checkpoint_errors[step] = []\n", - " self.all_checkpoint_errors[step].append(error_message)\n", - " \n", - " def validate_completion_indices(self, indices, step_name=\"Validation\", is_final=False):\n", - " \"\"\"\n", - " Validate that the completion indices are correct\n", - " \n", - " Args:\n", - " indices: The indices to validate\n", - " step_name: The name of the step for error reporting\n", - " is_final: Whether this is the final validation (expecting all indices to be complete)\n", - " \n", - " Returns:\n", - " List of errors found\n", - " \"\"\"\n", - " errors = []\n", - " \n", - " # Safety check for None\n", - " if indices is None:\n", - " error = \"Indices are None\"\n", - " self.add_error(step_name, error)\n", - " return [error]\n", - " \n", - " # Get the completed count\n", - " completed_values = [idx for idx in indices if idx is not None]\n", - " completed_count = len(completed_values)\n", - " \n", - " # For final validation, check if all indices are completed\n", - " if is_final:\n", - " # Check length\n", - " if len(indices) != 100:\n", - " error = f\"Expected 100 indices total, but found {len(indices)}\"\n", - " self.add_error(step_name, error)\n", - " errors.append(error)\n", - " \n", - " # Check that all are completed (no None values)\n", - " if completed_count != 100:\n", - " error = f\"Expected 100 completed indices, but found {completed_count}\"\n", - " self.add_error(step_name, error)\n", - " errors.append(error)\n", - " \n", - " # Check for uniqueness among completed indices (always important)\n", - " unique_indices = set(completed_values)\n", - " if len(unique_indices) != len(completed_values):\n", - " duplicates = [idx for idx in unique_indices if indices.count(idx) > 1]\n", - " error = f\"Found duplicate values: {duplicates}\"\n", - " self.add_error(step_name, error)\n", - " errors.append(error)\n", - " \n", - " # Check range of indices (0-99)\n", - " expected_range = set(range(100))\n", - " extra = unique_indices - expected_range\n", - " \n", - " if extra:\n", - " error = f\"Unexpected indices: {sorted(extra)}\"\n", - " self.add_error(step_name, error)\n", - " errors.append(error)\n", - " \n", - " # For final validation, check if any indices are missing\n", - " if is_final:\n", - " missing = expected_range - unique_indices\n", - " if missing:\n", - " error = f\"Missing indices: {sorted(missing)}\"\n", - " self.add_error(step_name, error)\n", - " errors.append(error)\n", - " \n", - " return errors\n", - "\n", - " def interrupt_execution(self):\n", - " \"\"\"Schedule an interruption after the checkpoint interval\"\"\"\n", - " print(f\"⏱️ Scheduling interruption in {self.checkpoint_interval} seconds\")\n", - " timer = threading.Timer(self.checkpoint_interval, self.raise_interrupt)\n", - " self.stop_events.append(timer)\n", - " timer.start()\n", - "\n", - " def raise_interrupt(self):\n", - " \"\"\"Raise a KeyboardInterrupt to stop the execution\"\"\"\n", - " print(\"πŸ›‘ Raising interruption signal\")\n", - " signal.raise_signal(signal.SIGINT)\n", - "\n", - " def setup_timeout(self):\n", - " \"\"\"Set up the overall timeout for the test\"\"\"\n", - " print(f\"⏱️ Setting up timeout limit of {self.total_time_limit} seconds\")\n", - " timeout_timer = threading.Timer(self.total_time_limit, self.handle_timeout)\n", - " self.stop_events.append(timeout_timer)\n", - " timeout_timer.start()\n", - "\n", - " def handle_timeout(self):\n", - " \"\"\"Handle the timeout by setting a flag instead of forcefully exiting\"\"\"\n", - " print(\"⏰ Timeout reached! Stopping the job gracefully.\")\n", - " self.add_error(\"Timeout\", f\"Test exceeded maximum time limit of {self.total_time_limit} seconds\")\n", - " # Set a flag instead of hard exiting - this is more Jupyter-friendly\n", - " self.timeout_triggered = True\n", - " # Signal the main thread to stop\n", - " signal.raise_signal(signal.SIGINT)\n", - "\n", - " def cleanup_timers(self):\n", - " \"\"\"Clean up all running timers\"\"\"\n", - " for timer in self.stop_events:\n", - " if timer.is_alive():\n", - " timer.cancel()\n", - " self.stop_events = []\n", - " \n", - " def check_progress_and_validate(self, checkpoint_name):\n", - " \"\"\"\n", - " Check the current progress and validate indices for the current checkpoint\n", - " \n", - " Returns:\n", - " Tuple of (progress_info, completed)\n", - " \"\"\"\n", - " progress_info = \"Unknown\"\n", - " completed = False\n", - " \n", - " try:\n", - " # Safely get job status - avoid calling methods directly on potentially None objects\n", - " if hasattr(self.job, 'get_index_completed') and callable(getattr(self.job, 'get_index_completed')):\n", - " indices = self.job.get_index_completed()\n", - " \n", - " # Safety check\n", - " if indices is not None:\n", - " # Determine if this is final validation based on completion status\n", - " completed_count = len([i for i in indices if i is not None])\n", - " is_final = completed_count == 100\n", - " \n", - " # Perform validation for this checkpoint\n", - " validation_errors = self.validate_completion_indices(\n", - " indices, \n", - " checkpoint_name + \" Validation\",\n", - " is_final=is_final\n", - " )\n", - " if validation_errors:\n", - " print(f\"❌ {checkpoint_name} validation failed:\")\n", - " for err in validation_errors:\n", - " print(f\" - {err}\")\n", - " elif is_final:\n", - " print(f\"βœ… {checkpoint_name} validation passed: All indices are correct\")\n", - " else:\n", - " print(f\"βœ… {checkpoint_name} partial validation passed: {completed_count} indices processed\")\n", - " \n", - " progress_info = f\"{completed_count}/100\"\n", - " \n", - " # Check if all tasks are completed\n", - " if completed_count == 100:\n", - " completed = True\n", - " else:\n", - " self.add_error(checkpoint_name, \"Failed to get indices: indices is None\")\n", - " print(f\"⚠️ {checkpoint_name}: Failed to get indices: indices is None\")\n", - " else:\n", - " self.add_error(checkpoint_name, \"Job does not have get_index_completed method\")\n", - " print(f\"⚠️ {checkpoint_name}: Job does not have get_index_completed method\")\n", - " \n", - " except Exception as e:\n", - " self.add_error(checkpoint_name, f\"Error getting indices: {str(e)}\")\n", - " print(f\"❌ {checkpoint_name}: Error getting indices: {str(e)}\")\n", - " \n", - " return progress_info, completed\n", - "\n", - " def _finish_test(self, start_time):\n", - " \"\"\"Finish the test by cleaning up and returning results\"\"\"\n", - " # Clean up timers\n", - " self.cleanup_timers()\n", - " \n", - " # Final validation if we have a job\n", - " if self.job and hasattr(self.job, 'get_index_completed'):\n", - " try:\n", - " final_indices = self.job.get_index_completed()\n", - " # Always perform full validation in the final step\n", - " validation_errors = self.validate_completion_indices(final_indices, \"Final Validation\", is_final=True)\n", - " if validation_errors:\n", - " print(\"❌ Final validation failed:\")\n", - " for err in validation_errors:\n", - " print(f\" - {err}\")\n", - " else:\n", - " print(\"βœ… Final validation passed: All indices are correct\")\n", - " except Exception as e:\n", - " self.add_error(\"Final Validation\", f\"Error getting final indices: {str(e)}\")\n", - " print(f\"❌ Error in final validation: {str(e)}\")\n", - "\n", - " def run_test(self):\n", - " \"\"\"Run the complete test with interruptions and resumptions\"\"\"\n", - " # Create input data\n", - " cities = [\"New York\", \"London\", \"Tokyo\", \"Paris\", \"Sydney\"] * 20 # 100 cities\n", - " \n", - " print(\"=== Starting Initial Run ===\")\n", - " start_time = time.time()\n", - " \n", - " try:\n", - " # Setup timers\n", - " self.setup_timeout()\n", - " self.interrupt_execution()\n", - " logger.info(\"start to setup the job\")\n", - " # Start initial run - use the module level decorated function\n", - " self.job = mock_llm_processor # Use the module-level function\n", - " logger.info(\"finish to setup the job\")\n", - " logger.info(self.job.get_index_completed)\n", - " try:\n", - " self.results = self.job.run(city_name=cities, num_records_per_city=1)\n", - " print(\"βœ… Initial run completed without interruption\")\n", - " \n", - " # Check progress and validate after initial run\n", - " progress_info, completed = self.check_progress_and_validate(\"Initial Run\")\n", - " if completed:\n", - " print(\"βœ… All tasks completed in initial run\")\n", - " return self._finish_test(start_time)\n", - " \n", - " except Exception as e:\n", - " self.add_error(\"Initial Run\", f\"Error: {str(e)}\")\n", - " print(f\"❌ Error in Initial Run: {str(e)}\")\n", - " # Don't return here, continue with checkpoint attempts\n", - " except KeyboardInterrupt:\n", - " print(\"⚠️ Initial run interrupted\")\n", - " \n", - " # Check progress and validate after interruption\n", - " progress_info, completed = self.check_progress_and_validate(\"Initial Run (Interrupted)\")\n", - " if completed:\n", - " print(\"βœ… All tasks completed after initial interruption\")\n", - " return self._finish_test(start_time)\n", - " \n", - " except Exception as e:\n", - " self.add_error(\"Initial Run Setup\", f\"Error: {str(e)}\")\n", - " print(f\"❌ Error in Initial Run setup: {str(e)}\")\n", - " # Don't return here, continue with checkpoint attempts\n", - " \n", - " # Resume until complete\n", - " checkpoint_count = 1\n", - " \n", - " # Add a safety counter to prevent infinite loops\n", - " while checkpoint_count <= self.max_checkpoints:\n", - " checkpoint_name = f\"Checkpoint {checkpoint_count}\"\n", - " \n", - " # Check if timeout was triggered\n", - " if self.timeout_triggered:\n", - " print(\"⏰ Test timed out - stopping testing loop\")\n", - " break\n", - " \n", - " # Check if we have reached the total time limit\n", - " if time.time() - start_time >= self.total_time_limit:\n", - " self.add_error(checkpoint_name, f\"Test exceeded maximum time limit of {self.total_time_limit} seconds\")\n", - " print(f\"⏰ Test timed out after {self.total_time_limit} seconds\")\n", - " break\n", - " \n", - " # Check if we've hit the max checkpoint count\n", - " if checkpoint_count == self.max_checkpoints:\n", - " self.add_error(checkpoint_name, f\"Test reached maximum checkpoint count of {self.max_checkpoints}\")\n", - " print(f\"⚠️ Test reached maximum checkpoint count of {self.max_checkpoints}\")\n", - " break\n", - " \n", - " # Check if we have a job to resume\n", - " if self.job is None:\n", - " self.add_error(checkpoint_name, \"Cannot continue: job is None\")\n", - " print(\"❌ Cannot continue: job is None\")\n", - " break\n", - " \n", - " # Check progress before resuming\n", - " progress_info, completed = self.check_progress_and_validate(f\"Before {checkpoint_name}\")\n", - " if completed:\n", - " print(f\"βœ… All tasks completed before {checkpoint_name}\")\n", - " break\n", - " \n", - " print(f\"=== Starting {checkpoint_name} ({progress_info}) ===\")\n", - " \n", - " # Resume the job\n", - " try:\n", - " # Setup interruption for the next checkpoint\n", - " self.interrupt_execution()\n", - " \n", - " # Try to resume if the method exists\n", - " if hasattr(self.job, 'resume') and callable(getattr(self.job, 'resume')):\n", - " try:\n", - " self.results = self.job.resume()\n", - " print(f\"βœ… {checkpoint_name} completed without interruption\")\n", - " \n", - " # Check progress after resumption\n", - " progress_info, completed = self.check_progress_and_validate(f\"After {checkpoint_name}\")\n", - " if completed:\n", - " print(f\"βœ… All tasks completed after {checkpoint_name}\")\n", - " break\n", - " \n", - " except Exception as e:\n", - " self.add_error(checkpoint_name, f\"Error: {str(e)}\")\n", - " print(f\"❌ Error in {checkpoint_name}: {str(e)}\")\n", - " # Continue to the next checkpoint\n", - " else:\n", - " self.add_error(checkpoint_name, \"Job does not have resume method\")\n", - " print(\"⚠️ Job does not have resume method\")\n", - " break # Can't continue without resume method\n", - " \n", - " except KeyboardInterrupt:\n", - " print(f\"⚠️ {checkpoint_name} interrupted\")\n", - " \n", - " # Check progress after interruption\n", - " progress_info, completed = self.check_progress_and_validate(f\"After {checkpoint_name} (Interrupted)\")\n", - " if completed:\n", - " print(f\"βœ… All tasks completed after {checkpoint_name} interruption\")\n", - " break\n", - " \n", - " checkpoint_count += 1\n", - "\n", - " # Finish the test\n", - " return self._finish_test(start_time)\n", - " \n", - " def _finish_test(self, start_time):\n", - " \"\"\"Finish the test by cleaning up and returning results\"\"\"\n", - " # Clean up timers\n", - " self.cleanup_timers()\n", - " \n", - " # Final validation if we have a job\n", - " if self.job and hasattr(self.job, 'get_index_completed'):\n", - " try:\n", - " final_indices = self.job.get_index_completed()\n", - " validation_errors = self.validate_completion_indices(final_indices, \"Final Validation\")\n", - " if validation_errors:\n", - " print(\"❌ Final validation failed:\")\n", - " for err in validation_errors:\n", - " print(f\" - {err}\")\n", - " else:\n", - " print(\"βœ… Final validation passed: All indices are correct\")\n", - " except Exception as e:\n", - " self.add_error(\"Final Validation\", f\"Error getting final indices: {str(e)}\")\n", - " print(f\"❌ Error in final validation: {str(e)}\")\n", - " \n", - " # Report final status\n", - " total_time = time.time() - start_time\n", - " print(f\"\\n=== Test Summary ===\")\n", - " print(f\"Total execution time: {total_time:.2f} seconds\")\n", - " \n", - " # Group errors by phase type for summary\n", - " validation_phases = [p for p in self.all_checkpoint_errors.keys() if \"Validation\" in p]\n", - " checkpoint_phases = [p for p in self.all_checkpoint_errors.keys() if \"Checkpoint\" in p and \"Validation\" not in p]\n", - " timeout_phases = [p for p in self.all_checkpoint_errors.keys() if \"Timeout\" in p]\n", - " other_phases = [p for p in self.all_checkpoint_errors.keys() \n", - " if p not in validation_phases and p not in checkpoint_phases and p not in timeout_phases]\n", - " \n", - " # Report errors by category\n", - " print(\"\\n=== Errors by Phase ===\")\n", - " \n", - " # Show timeout errors first\n", - " if timeout_phases:\n", - " print(\"\\nTimeout Errors:\")\n", - " for phase in timeout_phases:\n", - " for err in self.all_checkpoint_errors[phase]:\n", - " print(f\" - {err}\")\n", - " \n", - " # Show checkpoint execution errors\n", - " if checkpoint_phases:\n", - " print(\"\\nCheckpoint Execution Errors:\")\n", - " for phase in sorted(checkpoint_phases):\n", - " if phase in self.all_checkpoint_errors and self.all_checkpoint_errors[phase]:\n", - " print(f\" {phase}:\")\n", - " for err in self.all_checkpoint_errors[phase]:\n", - " print(f\" - {err}\")\n", - " \n", - " # Show validation errors for each checkpoint\n", - " if validation_phases:\n", - " print(\"\\nValidation Errors:\")\n", - " for phase in sorted(validation_phases):\n", - " if phase in self.all_checkpoint_errors and self.all_checkpoint_errors[phase]:\n", - " print(f\" {phase}:\")\n", - " for err in self.all_checkpoint_errors[phase]:\n", - " print(f\" - {err}\")\n", - " \n", - " # Show other errors\n", - " if other_phases:\n", - " print(\"\\nOther Errors:\")\n", - " for phase in sorted(other_phases):\n", - " if phase in self.all_checkpoint_errors and self.all_checkpoint_errors[phase]:\n", - " print(f\" {phase}:\")\n", - " for err in self.all_checkpoint_errors[phase]:\n", - " print(f\" - {err}\")\n", - " \n", - " if not self.errors:\n", - " print(\"\\nβœ… Test completed successfully with no errors\")\n", - " else:\n", - " validation_error_count = sum(len(self.all_checkpoint_errors[p]) for p in validation_phases if p in self.all_checkpoint_errors)\n", - " checkpoint_error_count = sum(len(self.all_checkpoint_errors[p]) for p in checkpoint_phases if p in self.all_checkpoint_errors)\n", - " timeout_error_count = sum(len(self.all_checkpoint_errors[p]) for p in timeout_phases if p in self.all_checkpoint_errors)\n", - " other_error_count = sum(len(self.all_checkpoint_errors[p]) for p in other_phases if p in self.all_checkpoint_errors)\n", - " \n", - " print(f\"\\n❌ Test completed with {len(self.errors)} errors:\")\n", - " print(f\" - {timeout_error_count} timeout errors\")\n", - " print(f\" - {checkpoint_error_count} checkpoint execution errors\")\n", - " print(f\" - {validation_error_count} validation errors\")\n", - " print(f\" - {other_error_count} other errors\")\n", - " \n", - " return {\n", - " \"success\": len(self.errors) == 0,\n", - " \"errors\": self.errors,\n", - " \"errors_by_checkpoint\": self.all_checkpoint_errors,\n", - " \"total_time\": total_time,\n", - " \"results\": self.results\n", - " }" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "# Run the test\n", - "runner = TestRunner(total_time_limit=20, checkpoint_interval=3, max_checkpoints=10)\n", - "result = runner.run_test()\n", - "if not result[\"success\"]:\n", - " # Format error message to include all errors organized by category\n", - " error_parts = []\n", - " \n", - " # Categorize phases\n", - " validation_phases = [p for p in result[\"errors_by_checkpoint\"].keys() if \"Validation\" in p]\n", - " checkpoint_phases = [p for p in result[\"errors_by_checkpoint\"].keys() if \"Checkpoint\" in p and \"Validation\" not in p]\n", - " timeout_phases = [p for p in result[\"errors_by_checkpoint\"].keys() if \"Timeout\" in p]\n", - " other_phases = [p for p in result[\"errors_by_checkpoint\"].keys() \n", - " if p not in validation_phases and p not in checkpoint_phases and p not in timeout_phases]\n", - " \n", - " # Add timeout errors first\n", - " if timeout_phases:\n", - " error_parts.append(\"\\n=== TIMEOUT ERRORS ===\")\n", - " for phase in timeout_phases:\n", - " for err in result[\"errors_by_checkpoint\"][phase]:\n", - " error_parts.append(f\"- {err}\")\n", - " \n", - " # Add checkpoint execution errors\n", - " if checkpoint_phases:\n", - " error_parts.append(\"\\n=== CHECKPOINT EXECUTION ERRORS ===\")\n", - " for phase in sorted(checkpoint_phases):\n", - " if phase in result[\"errors_by_checkpoint\"] and result[\"errors_by_checkpoint\"][phase]:\n", - " error_parts.append(f\"\\n-- {phase} --\")\n", - " for err in result[\"errors_by_checkpoint\"][phase]:\n", - " error_parts.append(f\"- {err}\")\n", - " \n", - " # Add validation errors for each checkpoint\n", - " if validation_phases:\n", - " error_parts.append(\"\\n=== VALIDATION ERRORS ===\")\n", - " for phase in sorted(validation_phases):\n", - " if phase in result[\"errors_by_checkpoint\"] and result[\"errors_by_checkpoint\"][phase]:\n", - " error_parts.append(f\"\\n-- {phase} --\")\n", - " for err in result[\"errors_by_checkpoint\"][phase]:\n", - " error_parts.append(f\"- {err}\")\n", - " \n", - " # Add other errors\n", - " if other_phases:\n", - " error_parts.append(\"\\n=== OTHER ERRORS ===\")\n", - " for phase in sorted(other_phases):\n", - " if phase in result[\"errors_by_checkpoint\"] and result[\"errors_by_checkpoint\"][phase]:\n", - " error_parts.append(f\"\\n-- {phase} --\")\n", - " for err in result[\"errors_by_checkpoint\"][phase]:\n", - " error_parts.append(f\"- {err}\")\n", - " \n", - " error_message = \"\\n\".join(error_parts)\n", - " validation_error_count = sum(len(result[\"errors_by_checkpoint\"][p]) for p in validation_phases if p in result[\"errors_by_checkpoint\"])\n", - " checkpoint_error_count = sum(len(result[\"errors_by_checkpoint\"][p]) for p in checkpoint_phases if p in result[\"errors_by_checkpoint\"])\n", - " timeout_error_count = sum(len(result[\"errors_by_checkpoint\"][p]) for p in timeout_phases if p in result[\"errors_by_checkpoint\"])\n", - " other_error_count = sum(len(result[\"errors_by_checkpoint\"][p]) for p in other_phases if p in result[\"errors_by_checkpoint\"])\n", - " \n", - " raise RuntimeError(f\"Test failed with {len(result['errors'])} total errors ({timeout_error_count} timeout, {checkpoint_error_count} execution, {validation_error_count} validation, {other_error_count} other):{error_message}\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "starfish-core-T7IInzTH-py3.11", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.7" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/tests/data_template/test_data_template.py b/tests/data_template/test_data_template.py index 855f35a..fd94f29 100644 --- a/tests/data_template/test_data_template.py +++ b/tests/data_template/test_data_template.py @@ -2,7 +2,7 @@ import pytest import os from starfish.common.env_loader import load_env_file -from starfish.data_template.template_gen import data_gen_template +from starfish import data_gen_template nest_asyncio.apply() load_env_file() diff --git a/tests/test_notebooks.py b/tests/test_notebooks.py index f3d9ce3..f34830a 100644 --- a/tests/test_notebooks.py +++ b/tests/test_notebooks.py @@ -13,7 +13,7 @@ def get_notebooks(base_dir=None): """Find all test notebooks in the project directory.""" if base_dir is None: - base_dir = Path(__file__).parent.parent + base_dir = Path(__file__).parent else: base_dir = Path(base_dir) @@ -22,6 +22,9 @@ def get_notebooks(base_dir=None): # Skip checkpoints if ".ipynb_checkpoints" in str(nb_path): continue + # Skip specific notebook + if "data_factory.ipynb" in str(nb_path): + continue # Only include notebooks that follow test naming convention if nb_path.name.startswith("test_"): notebooks.append(str(nb_path)) @@ -35,6 +38,9 @@ def test_notebook_execution(notebook_file): """Run the notebook through pytest to verify it executes without errors.""" pytest.importorskip("nbval") + if "data_factory.ipynb" in notebook_file: + pytest.skip("Skipping data_factory.ipynb as it is excluded from testing") + # This test will be collected by pytest # We just need to ensure the file exists assert os.path.exists(notebook_file), f"Notebook file not found: {notebook_file}"