diff --git a/src/starfish/components/prepare_topic.py b/src/starfish/components/prepare_topic.py index 354057c..42e6814 100644 --- a/src/starfish/components/prepare_topic.py +++ b/src/starfish/components/prepare_topic.py @@ -6,7 +6,7 @@ async def generate_topics( - user_instructions: str, + user_instruction: str, num_topics: int, model_name: str = "openai/gpt-4o-mini", model_kwargs: Optional[Dict[str, Any]] = None, @@ -30,7 +30,7 @@ async def generate_topics( for _ in range(num_batches): topic_generator = StructuredLLM( model_name=model_name, - prompt="""Can you generate a list of topics about {{user_instructions}} + prompt="""Can you generate a list of topics about {{user_instruction}} {% if existing_topics_str %} Please do not generate topics that are already in the list: {{existing_topics_str}} Make sure the topics are unique and vary from each other @@ -41,7 +41,7 @@ async def generate_topics( ) all_existing = existing_topics + generated_topics - input_params = {"user_instructions": user_instructions, "num_records": min(llm_batch_size, num_topics - len(generated_topics))} + input_params = {"user_instruction": user_instruction, "num_records": min(llm_batch_size, num_topics - len(generated_topics))} if all_existing: input_params["existing_topics_str"] = ",".join(all_existing) @@ -60,7 +60,7 @@ async def prepare_topic( topics: Optional[List[Union[str, Dict[str, int]]]] = None, num_records: Optional[int] = None, records_per_topic: int = 20, - user_instructions: Optional[str] = None, + user_instruction: Optional[str] = None, model_name: str = "openai/gpt-4o-mini", model_kwargs: Optional[Dict[str, Any]] = None, ) -> List[Dict[str, str]]: @@ -70,13 +70,13 @@ async def prepare_topic( 1. String list: ['topic1', 'topic2'] - Topics with equal or calculated distribution 2. Dict list: [{'topic1': 20}, {'topic2': 30}] - Topics with specific counts 3. Mixed: ['topic1', {'topic2': 30}] - Combination of both formats - 4. 
None: No topics provided, will generate based on user_instructions + 4. None: No topics provided, will generate based on user_instruction Args: topics: Optional list of topics, either strings or {topic: count} dicts num_records: Total number of records to split (required for dict topics or None topics) records_per_topic: Number of records per topic (default: 20) - user_instructions: Topic generation instructions (required if topics is None) + user_instruction: Topic generation instructions (required if topics is None) model_name: Model name for topic generation model_kwargs: Model kwargs for topic generation @@ -89,11 +89,11 @@ async def prepare_topic( model_kwargs["temperature"] = 1 # --- STEP 1: Input validation and normalization --- if topics is None: - # Must have num_records and user_instructions if no topics provided + # Must have num_records and user_instruction if no topics provided if not num_records or num_records <= 0: raise ValueError("num_records must be positive when topics are not provided") - if not user_instructions: - raise ValueError("user_instructions required when topics are not provided") + if not user_instruction: + raise ValueError("user_instruction required when topics are not provided") topic_assignments = [] else: # Validate topics is a non-empty list @@ -181,11 +181,11 @@ async def prepare_topic( raise ValueError("records_per_topic must be positive when generating topics") # Generate topics with LLM if instructions provided - if user_instructions: + if user_instruction: topics_needed = math.ceil(remaining_records / records_per_topic) generated = await generate_topics( - user_instructions=user_instructions, num_topics=topics_needed, model_name=model_name, model_kwargs=model_kwargs, existing_topics=topic_names + user_instruction=user_instruction, num_topics=topics_needed, model_name=model_name, model_kwargs=model_kwargs, existing_topics=topic_names ) # Assign counts to generated topics @@ -237,7 +237,7 @@ async def prepare_topic( # Example 
1: Dictionary topics with additional generation print("\nExample 1: Dictionary topics + generation") topics1 = [{"topic1": 20}, {"topic2": 30}] - result1 = asyncio.run(prepare_topic(topics=topics1, num_records=100, records_per_topic=25, user_instructions="some context")) + result1 = asyncio.run(prepare_topic(topics=topics1, num_records=100, records_per_topic=25, user_instruction="some context")) print(f"Result: {result1}") print(f"Total: {len(result1)}") @@ -251,7 +251,7 @@ async def prepare_topic( # Example 3: Mixed string and dict topics print("\nExample 3: Mixed string/dict topics") topics3 = ["topicX", {"topicY": 10}] - result3 = asyncio.run(prepare_topic(topics=topics3, num_records=30, user_instructions="mixed topics")) + result3 = asyncio.run(prepare_topic(topics=topics3, num_records=30, user_instruction="mixed topics")) print(f"Result: {result3}") print(f"Total: {len(result3)}") @@ -266,7 +266,7 @@ async def prepare_topic( print("\nExample 5: No topics, generate all") async def run_example5(): - result = await prepare_topic(topics=None, num_records=10, records_per_topic=5, user_instructions="cloud computing") + result = await prepare_topic(topics=None, num_records=10, records_per_topic=5, user_instruction="cloud computing") print(f"Result: {result}") print(f"Total: {len(result)}") diff --git a/src/starfish/data_factory/utils/data_class.py b/src/starfish/data_factory/utils/data_class.py index a2a4513..9d17ee5 100644 --- a/src/starfish/data_factory/utils/data_class.py +++ b/src/starfish/data_factory/utils/data_class.py @@ -122,6 +122,45 @@ def to_json(self): return json.dumps(self.to_dict(), indent=2) + def update(self, data: dict): + """Update the configuration from a dictionary. + + This method updates the configuration fields from a dictionary. It handles + the deserialization of callable functions using cloudpickle. + + Args: + data (dict): Dictionary containing the configuration data. 
Expected keys: + - storage: Storage type string + - master_job_id: Unique job identifier + - project_id: Project identifier + - batch_size: Number of records per batch + - target_count: Total records to process + - max_concurrency: Maximum concurrent tasks + - show_progress: Whether to show progress + - task_runner_timeout: Task timeout in seconds + - on_record_complete: List of callable strings for record completion + - on_record_error: List of callable strings for record errors + - run_mode: Execution mode string + - job_run_stop_threshold: Job retry threshold + + Raises: + ValueError: If invalid fields are provided + cloudpickle.PickleError: If callable deserialization fails + """ + if not isinstance(data, dict): + raise ValueError("Input must be a dictionary") + + # Handle callable deserialization + if "on_record_complete" in data: + self.on_record_complete = [cloudpickle.loads(bytes.fromhex(c)) if c else None for c in data["on_record_complete"]] + if "on_record_error" in data: + self.on_record_error = [cloudpickle.loads(bytes.fromhex(c)) if c else None for c in data["on_record_error"]] + + # Update other fields + for key, value in data.items(): + if key not in ["on_record_complete", "on_record_error"] and hasattr(self, key): + setattr(self, key, value) + @dataclass class FactoryJobConfig: diff --git a/src/starfish/data_factory/utils/state.py b/src/starfish/data_factory/utils/state.py index 771a554..c93faa2 100644 --- a/src/starfish/data_factory/utils/state.py +++ b/src/starfish/data_factory/utils/state.py @@ -22,7 +22,7 @@ def __init__(self, initial_data: Optional[Dict[str, Any]] = None): the state will be initialized with a copy of this dictionary. 
""" super().__init__() - self._lock = threading.Lock() # Changed to threading.Lock + self._lock = threading.RLock() # Changed to threading.Lock if initial_data is not None: self._data = initial_data.copy() diff --git a/src/starfish/data_template/template_gen.py b/src/starfish/data_template/template_gen.py index 2032823..929c5f5 100644 --- a/src/starfish/data_template/template_gen.py +++ b/src/starfish/data_template/template_gen.py @@ -5,6 +5,7 @@ import importlib.util import ast from typing import Any, Union, List, Dict +import inspect from starfish.data_template.utils.error import DataTemplateValueError, ImportModuleError, ImportPackageError @@ -53,41 +54,88 @@ def __init__( # Add run method if not present if not hasattr(self.func, "run"): self.func.run = lambda *args, **kwargs: self.func(*args, **kwargs) - if not self.input_schema: - raise DataTemplateValueError("missing input_schema_validator") + + # Detect if function expects a single input_data parameter + self.expects_model_input = False + try: + self.func_signature = inspect.signature(self.func) + + # Check if the function has a single parameter that matches the input_schema type + if len(self.func_signature.parameters) == 1: + param_name, param = next(iter(self.func_signature.parameters.items())) + # Check if the parameter has an annotation matching input_schema + if param.annotation == self.input_schema: + self.expects_model_input = True + except (TypeError, ValueError): + # Can't get signature (e.g., for data_factory decorated functions) + # Default to False (use individual parameters) which is safer + self.expects_model_input = False # Check dependencies on initialization # if self.dependencies: # _check_dependencies(self.dependencies) - def run(self, *args, **kwargs) -> Any: - """Execute the wrapped function with schema validation.""" - # Pre-run hook: Validate input schema - try: - # Validate input against schema - if self.input_schema: - if args: - self.input_schema.validate(args[0]) - elif kwargs: - 
self.input_schema.validate(kwargs) - except pydantic.ValidationError as e: - raise DataTemplateValueError(f"Input validation failed: {str(e)}") + async def run(self, *args, **kwargs) -> Any: + """Execute the wrapped function with schema validation. + + This method supports multiple calling patterns: + - template.run(param1=val1, param2=val2) # Keyword arguments + - template.run({"param1": val1, "param2": val2}) # Dictionary + - template.run(model_instance) # Pydantic model instance - # Execute the function + The template function can be defined in two ways: + 1. Taking a single Pydantic model parameter: func(input_data: Model) + 2. Taking individual parameters: func(param1, param2, param3) + + In all cases, validation happens through the Pydantic model. + """ + # STEP 1: Get a validated Pydantic model instance from the inputs + validated_model = self._get_validated_model(args, kwargs) + + # STEP 2: Call the function with appropriate arguments based on its signature try: - result = self.func.run(*args, **kwargs) + if self.expects_model_input: + # Pass the whole model to functions expecting a model parameter + result = await self.func.run(validated_model) + else: + # Expand model fields for functions expecting individual parameters + result = await self.func.run(**validated_model.model_dump()) except Exception as e: raise DataTemplateValueError(f"Template execution failed: {str(e)}") - # Post-run hook: Validate output schema + # STEP 3: Validate the output if an output schema is provided if self.output_schema is not None: try: - self.output_schema.validate(result) + # Use model_validate instead of validate (which is deprecated in Pydantic v2) + self.output_schema.model_validate(result) except pydantic.ValidationError as e: raise DataTemplateValueError(f"Output validation failed: {str(e)}") return result + def _get_validated_model(self, args, kwargs): + """Convert input arguments into a validated Pydantic model instance.""" + # Case 1: User passed a model instance 
directly + if len(args) == 1 and isinstance(args[0], self.input_schema): + return args[0] + + # Case 2: User passed a dictionary as the first positional argument + if len(args) == 1 and isinstance(args[0], dict): + # Merge dictionary with any keyword arguments + input_data = {**args[0], **kwargs} + # Case 3: User passed keyword arguments only + elif not args: + input_data = kwargs + # Case 4: Invalid input (multiple positional args or wrong type) + else: + raise DataTemplateValueError("Invalid arguments: Please provide either keyword arguments, " "a single dictionary, or a model instance.") + + # Validate and return a model instance + try: + return self.input_schema.model_validate(input_data) + except pydantic.ValidationError as e: + raise DataTemplateValueError(f"Input validation failed: {str(e)}") + # ==================== # Template Management Class @@ -107,11 +155,11 @@ def list(is_detail: bool = False) -> list[Any]: if len(result) == 0: # Walk through all subdirectories in templates folder for subdir in templates_dir.iterdir(): - if subdir.is_dir(): + for sub_subdir in subdir.iterdir(): # Find all .py files in the subdirectory - for template_file in subdir.glob("*.py"): + for template_file in sub_subdir.glob("*.py"): try: - module_name = f"starfish.data_template.templates.{subdir.name}.{template_file.stem}" + module_name = f"starfish.data_template.templates.{subdir.name}.{sub_subdir.name}.{template_file.stem}" # Parse the file's AST to extract decorator information with open(template_file, "r") as f: tree = ast.parse(f.read()) diff --git a/src/starfish/data_template/templates/community/topic_generator.py b/src/starfish/data_template/templates/community/topic_generator.py deleted file mode 100644 index b47c0b0..0000000 --- a/src/starfish/data_template/templates/community/topic_generator.py +++ /dev/null @@ -1,62 +0,0 @@ -from pydantic import BaseModel -from starfish import data_factory -from starfish.data_template.template_gen import data_gen_template - - -# 
Define input schema -class TopicGeneratorInput(BaseModel): - community_name: str - seed_topics: list[str] - num_topics: int - language: str = "en" - - -# Define output schema -class TopicGeneratorOutput(BaseModel): - generated_topics: list[str] - success: bool - message: str - - -@data_gen_template.register( - name="community/topic_generator", - input_schema=TopicGeneratorInput, - output_schema=TopicGeneratorOutput, - description="Generates relevant topics for community discussions using AI models", - author="Your Name", - starfish_version="0.1.0", - dependencies=["transformers_1>=4.0.0"], -) -# @data_factory(max_concurrency=10) -def topic_generator(input_data: TopicGeneratorInput) -> TopicGeneratorOutput: - try: - # Step 1: Generate initial topics - generated_topics = generate_initial_topics(input_data) - - # Step 2: Process topics in parallel - @data_factory(max_concurrency=10) - async def process_topics(topics: list[str]) -> list[str]: - return [refine_topic(topic) for topic in topics] - - refined_topics = process_topics.run(generated_topics) - - return TopicGeneratorOutput(generated_topics=refined_topics, success=True, message="Topics generated successfully") - except Exception as e: - return TopicGeneratorOutput(generated_topics=[], success=False, message=str(e)) - - -# Helper functions -def generate_initial_topics(input_data: TopicGeneratorInput) -> list[str]: - # Implement your topic generation logic here - # This could use AI models or other algorithms - # ... existing code ... - return ["Topic 1", "Topic 2", "Topic 3"] # Placeholder - - -def refine_topic(topic: str) -> str: - # Implement topic refinement logic here - # ... existing code ... 
- return topic.upper() # Placeholder - - -# print(result) diff --git a/src/starfish/data_template/templates/community/topic_generator_success.py b/src/starfish/data_template/templates/community/topic_generator_success.py deleted file mode 100644 index ef0ba3e..0000000 --- a/src/starfish/data_template/templates/community/topic_generator_success.py +++ /dev/null @@ -1,61 +0,0 @@ -from pydantic import BaseModel -from starfish import data_factory -from starfish.data_template.template_gen import data_gen_template - - -# Define input schema -class TopicGeneratorInput(BaseModel): - community_name: str - seed_topics: list[str] - num_topics: int - language: str = "en" - - -# Define output schema -class TopicGeneratorOutput(BaseModel): - generated_topics: list[str] - success: bool - message: str - - -@data_gen_template.register( - name="community/topic_generator_success", - input_schema=TopicGeneratorInput, - output_schema=TopicGeneratorOutput, - description="Generates relevant topics for community discussions using AI models", - author="Your Name", - starfish_version="0.1.0", - dependencies=["posthog>=3.11.0"], -) -def topic_generator(input_data: TopicGeneratorInput) -> TopicGeneratorOutput: - try: - # Step 1: Generate initial topics - generated_topics = generate_initial_topics(input_data) - - # Step 2: Process topics in parallel - @data_factory(max_concurrency=10) - async def process_topics(topics: list[str]) -> list[str]: - return [{"ans": refine_topic(topic) for topic in topics}] - - refined_topics = process_topics.run(topics=generated_topics) - - return TopicGeneratorOutput(generated_topics=[topic["ans"] for topic in refined_topics], success=True, message="Topics generated successfully") - except Exception as e: - return TopicGeneratorOutput(generated_topics=[], success=False, message=str(e)) - - -# Helper functions -def generate_initial_topics(input_data: TopicGeneratorInput) -> list[str]: - # Implement your topic generation logic here - # This could use AI models or 
other algorithms - # ... existing code ... - return [{"topic": "Topic 1"}, {"topic": "Topic 2"}, {"topic": "Topic 3"}] # Placeholder - - -def refine_topic(topic: str) -> str: - # Implement topic refinement logic here - # ... existing code ... - return topic.upper() # Placeholder - - -# print(result) diff --git a/src/starfish/data_template/templates/starfish/func_call/APIGen.md b/src/starfish/data_template/templates/starfish/func_call/APIGen.md new file mode 100644 index 0000000..576e820 --- /dev/null +++ b/src/starfish/data_template/templates/starfish/func_call/APIGen.md @@ -0,0 +1,127 @@ + +# APIGen: Automated Pipeline for Generating Verifiable and Diverse Function-Calling Datasets +**Zuxin Liu, Thai Hoang, Jianguo Zhang, Ming Zhu, Tian Lan, Shirley Kokane, Juntao Tan, Weiran Yao, Zhiwei Liu, Yihao Feng, Rithesh Murthy, Liangwei Yang, Silvio Savarese, Juan Carlos Niebles, Huan Wang, Shelby Heinecke, Caiming Xiong** +Salesforce AI Research, USA +{zuxin.liu, thai.hoang, jianguozhang}@salesforce.com + +--- + +## Abstract +APIGen is an automated pipeline for generating high-quality, verified function-calling datasets. It leverages **3,673 executable APIs** (3,539 REST APIs from ToolBench and 134 Python functions) across **21 categories**, ensuring diversity through structured JSON formatting, **multi-stage verification** (format, execution, semantic checks), and randomized sampling. Models trained with APIGen’s **60,000-entry dataset** achieve state-of-the-art performance: the **7B-parameter model** ranks 6th (88.24% accuracy) on the Berkeley Function-Calling Leaderboard (BFCL), while the **1B model** surpasses GPT-3.5-Turbo. The dataset is publicly available on [Huggingface](https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k) and the [project homepage](https://apigen-pipeline.github.io/). + +--- + +## 1. Introduction +Function-calling agents enable LLMs to execute API calls based on natural language instructions (e.g., `get_weather("Palo Alto")`). 
However, existing datasets are static and lack verification. APIGen addresses this by: +- Generating **diverse, verified datasets** through LLM-driven synthesis. +- Implementing **three-stage verification** to ensure data quality. +- Supporting **parallel/multiple function calls** and complex scenarios. + +**Key Contributions**: +- APIGen framework for scalable, verifiable dataset generation. +- Two trained models: **xLAM-7B (FC)** (6th on BFCL) and **xLAM-1B (FC)** (outperforms GPT-3.5-Turbo). +- Public release of **60,000 high-quality entries**. + +--- + +## 2. APIGen Framework +### 2.1 Data Generation +1. **Sampling**: APIs and seed QA pairs are sampled from libraries. +2. **Prompt Templates**: Steer LLMs to generate query-answer pairs in JSON format. +3. **Multi-Stage Verification**: + - **Stage 1 (Format Checker)**: Validates JSON structure and required fields. + - **Stage 2 (Execution Checker)**: Executes function calls and checks for errors. + - **Stage 3 (Semantic Checker)**: LLM evaluates alignment between query intent and execution results. + +**JSON Advantages**: +- Structural verification of fields. +- Scalable integration of REST APIs/Python functions. + +### 2.2 Dataset Diversity +- **Query Styles**: Simple, Multiple, Parallel, Parallel-Multiple. +- **Sampling Strategies**: API Sampler, Example Sampler, Prompt Sampler. +- **API Sources**: 21 consolidated categories (e.g., Technology, Finance). + +--- + +## 3. Dataset Preparation +### 3.1 API Sources +- **ToolBench**: 16,464 REST APIs filtered to **3,539** executable APIs. +- **Python Functions**: 134 functions (math, finance, data management). + +**Filtering Steps**: +- Remove APIs with poor documentation or no parameters. +- Test API accessibility via endpoint requests. +- Regenerate noisy docstrings. 
+ +### 3.2 Collection Setup +| **Model** | Verified Data | Pass Rate | +|--------------------------|---------------|-----------| +| DeepSeek-V2-Chat (236B) | 33,659 | 84.15% | +| Mixtral-8x22B-Inst | 26,384 | 65.96% | +| Mixtral-8x7B-Inst | 15,385 | 38.46% | + +--- + +## 4. Experiments +### 4.1 Benchmark Results (BFCL) +| **Model** | Overall Accuracy | Rank | +|----------------------|------------------|------| +| xLAM-7B (FC) | 88.24% | 6th | +| GPT-4-0125-Preview | 88.26% | 1st | +| xLAM-1B (FC) | 74.41% | 24th | +| GPT-3.5-Turbo-0125 | 63.88% | 33rd | + +**Key Findings**: +- APIGen-trained models excel in **parallel/multiple function calls**. +- Smaller models (1B) outperform larger ones (e.g., Claude-3 Haiku). + +### 4.2 Ablation Study +Including low-quality data (filtered by Stage 2/3) degrades performance: +- xLAM-1B accuracy drops by **12%**. + +--- + +## 5. Conclusion +APIGen enables **small models** to rival larger ones in function-calling tasks through high-quality data. Future work includes support for multi-turn interactions and additional API types. + +--- + +## Appendix +### A. Dataset Examples +**Simple Query**: +```json +{ + "query": "What is the weather in Palo Alto?", + "tools": [{ + "name": "weather_api.get_current_weather", + "parameters": {"location": "string"} + }], + "answers": [{ + "name": "weather_api.get_current_weather", + "arguments": {"location": "Palo Alto"} + }] +} +``` + +**Parallel Query**: +```json +{ + "query": "Sum multiples of 3 and 5 below 1000. Product of first 5 primes.", + "tools": ["math_toolkit.sum_of_multiples", "math_toolkit.product_of_primes"], + "answers": [ + {"name": "sum_of_multiples", "arguments": {"multiples": [3,5]}}, + {"name": "product_of_primes", "arguments": {"count": 5}} + ] +} +``` + +### B. Training Details +- **Hyperparameters**: Learning rate = 5e-6, 4 epochs, AdamW optimizer. +- **Hardware**: 8× NVIDIA A100 40GB GPUs. 
+- **Prompt Template**: Combines task instructions, available tools, and format constraints. + +--- +**License**: CC BY 4.0. +**Dataset Link**: [Huggingface](https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k) +``` \ No newline at end of file diff --git a/src/starfish/data_template/templates/starfish/func_call/generate_func_call_dataset.py b/src/starfish/data_template/templates/starfish/func_call/generate_func_call_dataset.py new file mode 100644 index 0000000..8d0bc2f --- /dev/null +++ b/src/starfish/data_template/templates/starfish/func_call/generate_func_call_dataset.py @@ -0,0 +1,120 @@ +from starfish.components.prepare_topic import generate_topics +from starfish.data_template.template_gen import data_gen_template +from pydantic import BaseModel + +from typing import Optional, Dict, Any +import random + +from starfish.data_template.templates.starfish.func_call.utils import ( + generate_sub_topic, + generator_query_answer, + update_query_answer, + verify_queries_with_llm, + num_records_, + sub_topic_num, +) +from starfish.common.logger import get_logger + +logger = get_logger(__name__) + + +class ParameterDefinition(BaseModel): + """ + Pydantic model representing parameter definition in an API contract. + """ + + type: str + description: str + required: bool = True + + +class APIContract(BaseModel): + """ + Pydantic model representing an API contract structure. + """ + + name: str + description: str + parameters: Dict[str, ParameterDefinition] + + +## Pydantic Input Schema +class GenerateFuncCallDataSet(BaseModel): + """ + Input schema for the generate_by_topic template. + + IMPORTANT: This Pydantic model is the single source of truth for default values. + The validation and default values are controlled by this model, not the function signature. 
+ """ + + num_records: Optional[int] = 10 + api_contract: APIContract + topic_model_name: str = "openai/gpt-4o-mini" + topic_model_kwargs: Optional[Dict[str, Any]] = None + generation_model_name: str = "openai/gpt-4o-mini" + generation_model_kwargs: Optional[Dict[str, Any]] = None + data_factory_config: Optional[Dict[str, Any]] = {} + + +@data_gen_template.register( + name="starfish/generate_func_call_dataset", + input_schema=GenerateFuncCallDataSet, + output_schema=None, + description="""Generates diverse synthetic data across multiple topics based on user instructions. + Automatically creates relevant topics if not provided and handles deduplication across generated content. + """, + author="Wendao Liu", + starfish_version="0.1.3", + dependencies=[], +) +async def api_contract_workflow(input_data: GenerateFuncCallDataSet): + api_contract = input_data.api_contract.model_dump() + num_records = input_data.num_records + records_per_topic = sub_topic_num * num_records_ + num_topic = num_records // records_per_topic + num_records_reminder = num_records % records_per_topic + if num_records_reminder > 0 or num_topic == 0: + num_topic += 1 + user_instruction = api_contract["description"] + topic_model_name = input_data.topic_model_name + topic_model_kwargs = input_data.topic_model_kwargs + generation_model_name = input_data.generation_model_name + generated_topics = await generate_topics( + user_instruction=user_instruction, + num_topics=num_topic, + model_name=topic_model_name, + model_kwargs=topic_model_kwargs, + ) + sub_topic_input_data = [{"topic": topic, "user_instruction": user_instruction} for topic in generated_topics] + generate_sub_topic.factory.config.update(input_data.data_factory_config) + all_topics = generate_sub_topic.run(data=sub_topic_input_data, num_topics=sub_topic_num, model_name=topic_model_name, model_kwargs=topic_model_kwargs) + # generator_query_answer_input_data = random.sample(all_topics, num_records//num_records_) + 
generator_query_answer_input_data = all_topics + generator_query_answer.factory.config.update(input_data.data_factory_config) + query_answer_pairs = generator_query_answer.run(data=generator_query_answer_input_data, api_contract=api_contract, model_name=generation_model_name) + verify_queries_with_llm.factory.config.update(input_data.data_factory_config) + check_result = verify_queries_with_llm.run(data=query_answer_pairs, api_contract=api_contract, model_name=generation_model_name) + check_result_idx_arr = verify_queries_with_llm.get_index_completed() + failed_data_set = [] + result = [] + updated_data_set = [] + for i in range(0, len(check_result)): + item = check_result[i] + query_answer_pair = query_answer_pairs[check_result_idx_arr[i]] + if not item["format_checker_passed"] or not item["semantic_checker_passed"]: + query_answer_pair["failed_reason"] = item["reason"] + failed_data_set.append(query_answer_pair) + else: + result.append(query_answer_pair) + + if len(failed_data_set) > 0: + update_query_answer.factory.config.update(input_data.data_factory_config) + updated_data_set = update_query_answer.run(data=failed_data_set, api_contract=api_contract, model_name=generation_model_name) + check_result = verify_queries_with_llm.run(data=updated_data_set, api_contract=api_contract, model_name=generation_model_name) + filtered_check_result = [item for item in check_result if not item["format_checker_passed"] or not item["semantic_checker_passed"]] + if len(filtered_check_result) > 0: + logger.warning("verify_queries_with_llm promp need improve or the data set need improve") + + result.extend(updated_data_set) + # also try using duplicate-examples in the prompt + return result diff --git a/src/starfish/data_template/templates/starfish/func_call/utils.py b/src/starfish/data_template/templates/starfish/func_call/utils.py new file mode 100644 index 0000000..5265284 --- /dev/null +++ b/src/starfish/data_template/templates/starfish/func_call/utils.py @@ -0,0 +1,316 @@ 
+from starfish import data_factory, StructuredLLM +from starfish.components.prepare_topic import generate_topics +from typing import Optional, List, Dict, Any +from starfish.common.logger import get_logger + +logger = get_logger(__name__) +num_records_ = 5 +sub_topic_num = 5 + + +@data_factory(max_concurrency=20) +async def verify_queries_with_llm(model_name, query, answer, api_contract): + """ + Uses an LLM to verify if generated queries match the API contract. + + Args: + query_answer_pairs (list): A list of dictionaries with 'query' and 'answer' keys. + api_contract (dict): The API contract dictionary. + """ + + semantic_checker_llm = StructuredLLM( + model_name=model_name, # You can choose a different model + prompt="""Given this API contract: {{api_contract}}, this query: '{{query}}', and this answer: '{{answer}}'. + Here is an example for your reference ONLY - DO NOT compare it directly with the given query/answer pair: + Given the API contract as + { + "name": "weather_api.get_current_weather", + "description": "Retrieves the current weather conditions for a specified location .", + "parameters": { + "location": { + "type": "string", + "description": "The name of the city or geographic location .", + "required": True + }, + "units": { + "type": "string", + "description": "The units for temperature measurement( e.g., 'Celsius', 'Fahrenheit') .", + "required": False + }, + }, + }, + A valid query/answer pair would be something similar to this: + Query: "Could you check the weather in Nairobi, Buenos Aires, and Bangkok? Also, I'd like to know the wind speed in Jakarta." 
+ Answer: [ + {'name': 'weather_api.get_current_weather', 'arguments': {'location': 'Nairobi'}}, + {'name': 'weather_api.get_current_weather', 'arguments': {'location': 'Buenos Aires','units': 'Fahrenheit'}}, + {'name': 'weather_api.get_current_weather', 'arguments': {'location': 'Bangkok'}}, + {'name': 'weather_api.get_current_weather', 'arguments': {'location': 'Jakarta'}} + ] + + Analyze the following aspects of the given query/answer pair (NOT the example): + 1. Does the query contain all necessary information that the API contract requires? + 2. Does the answer contain the correct number of function calls matching the requests in the query? + 3. Does each function call in the answer: + - Use the correct API name as specified in the contract? + - Include all required parameters from the API contract (optional parameters are not necessary)? + - Use parameter values that semantically match the API contract's parameters description? + + Respond with 'Yes' or 'No', followed by a detailed reason explaining your analysis. 
+ If 'No', specify which aspect(s) failed and why.""", + output_schema=[ + {"name": "match", "type": "str"}, # e.g., "Yes" or "No" + {"name": "reason", "type": "str"}, + ], + model_kwargs={"temperature": 0.3}, # Lower temperature for more deterministic output + ) + + format_checker_passed = True + semantic_checker_passed = False + reason_arr = [] + if query and answer: + if isinstance(query, str) and isinstance(answer, (list, dict)): + if isinstance(answer, dict): + answer = [answer] + for item in answer: + # Basic structure validation + if not isinstance(item, dict): + format_checker_passed = False + reason_arr.append("Answer items must be dictionaries") + continue + + # Required keys check + if "name" not in item or "arguments" not in item: + format_checker_passed = False + reason_arr.append("Answer items must contain 'name' and 'arguments' keys") + continue + + # Arguments type check + if not isinstance(item["arguments"], dict): + format_checker_passed = False + reason_arr.append("Arguments must be a dictionary") + continue + + # Function name validation + if item["name"].strip() != api_contract["name"].strip(): + format_checker_passed = False + reason_arr.append("function name not match with the api_contract") + continue + + # Parameter validation + required_params = {param for param, details in api_contract["parameters"].items() if details.get("required", True)} + api_params = set(api_contract["parameters"].keys()) + answer_args = set(item["arguments"].keys()) + + # Required parameters check + if not required_params.issubset(answer_args): + format_checker_passed = False + reason_arr.append(f"Missing required parameters: {required_params - answer_args}") + continue + + # Parameter keys validation + if not answer_args.issubset(api_params): + format_checker_passed = False + reason_arr.append(f"Arguments {answer_args} must be subset of API parameters {api_params}") + continue + + # Type checking for each argument + for arg_name, arg_value in 
item["arguments"].items(): + expected_type = api_contract["parameters"][arg_name]["type"].lower() + + type_checks = { + "string": isinstance(arg_value, str), + "number": isinstance(arg_value, (int, float)), + "float": isinstance(arg_value, (int, float)), + "integer": isinstance(arg_value, int), + "boolean": isinstance(arg_value, bool), + "array": isinstance(arg_value, list), + "object": isinstance(arg_value, dict), + "null": arg_value is None, + } + + if not type_checks.get(expected_type, True): + format_checker_passed = False + reason_arr.append(f"Argument '{arg_name}' should be {expected_type}, got {type(arg_value)}") + continue + + # Add custom type checks here if needed + if format_checker_passed: + semantic_checker_llm_result = await semantic_checker_llm.run(api_contract=api_contract, query=query, answer=answer) + if semantic_checker_llm_result and hasattr(semantic_checker_llm_result, "data") and semantic_checker_llm_result.data: + result = semantic_checker_llm_result.data[0] # Assuming one output per run + match_status = result.get("match") + reason = result.get("reason") + logger.info(f"Query: '{query}'") + logger.info(f"Answer: '{answer}'") + logger.info(f" semantic checker Result: {match_status}") + logger.info(f" Reason: {reason}") + reason_arr.append(reason) + if match_status == "Yes": + semantic_checker_passed = True + else: + logger.info(f"Query: '{query}' - LLM semantic checker failed.") + return [{"format_checker_passed": format_checker_passed, "semantic_checker_passed": semantic_checker_passed, "reason": "; ".join(reason_arr)}] + + +@data_factory(max_concurrency=10) +async def generate_sub_topic( + user_instruction: str, + num_topics: int, + topic: str, + model_name: str = "openai/gpt-4o-mini", + model_kwargs: Optional[Dict[str, Any]] = None, + existing_topics: Optional[List[str]] = None, +): + user_instruction = f"{user_instruction}, generate sub topics with different location in each sub topic based on the topic of {topic}" + sub_topics = await 
generate_topics( + user_instruction=user_instruction, num_topics=num_topics, model_name=model_name, model_kwargs=model_kwargs, existing_topics=existing_topics + ) + result = [{"topic": topic} for topic in sub_topics] + return result + + +# here is an example to follow, giving the function name/description/parameters. watch for the parameter which is optional in {{func_params}} +# function name as "weather_api.get_current_weather", +# function description as "Retrieves the current weather conditions for a specified location .", +# function parameters as { +# "location": { +# "type": "string", +# "description": "The name of the city or geographic location .", +# "required": True +# }, +# "units": { +# "type": "string", +# "description": "The units for temperature measurement( e.g., 'Celsius', 'Fahrenheit') .", +# "required": False +# }, +# }, +# this is the output given the conditions above. +# { +# "query" : "Could you check the weather in Nairobi, Buenos Aires, and Bangkok? Also, I'd like to know the wind speed in Jakarta.", +# "answer" : [ +# {'name': 'weather_api.get_current_weather', 'arguments': {'location': 'Nairobi'}}, +# {'name': 'weather_api.get_current_weather', 'arguments': {'location': 'Buenos Aires'}}, +# {'name': 'weather_api.get_current_weather', 'arguments': {'location': 'Bangkok', "units":"Celsius"}}, +# {'name': 'weather_api.get_current_weather', 'arguments': {'location': 'Jakarta'}} +# ] +# } +@data_factory(max_concurrency=30) +async def generator_query_answer(model_name, api_contract: dict, topic: str): + query_answer_generator_prompt = """ + You are a data labeler. The responsibility for you is to generate a set of diverse queries and corresponding answers for the given functions in JSON format. + Construct queries and answers that exemplifies how to use these functions in a practical scenario. Include in each query specific, plausible values for each parameter. For instance, if the function requires a date, use a typical and reasonable date.
+ Ensure the query: + − Is clear and concise + − Contain multiple parallel queries in natural language for the given functions, they could use either the same function with different arguments or different functions + − Demonstrates typical use cases + − Includes all necessary parameters in a meaningful way. For numerical parameters, it could be either numerals or words + − Across a variety level of difficulties, ranging from beginner and advanced use cases + − The corresponding result's parameter types and ranges match with the functions descriptions. + Ensure the answer: + − Is a list of function calls in JSON format. + − The length of the answer list should be equal to the number of requests in the query + − Can solve all the requests in the query effectively + + Note that the query could be interpreted as a combination of several independent requests. + Based on these examples and the above instructions, generate {{num_records}} diverse query and answer pairs for the functions '{{func_name}}'. + The detailed functions description is as follows: + {{func_desc}} + The detailed functions paramters is as follows, the generated outputs shall have some records having the optional parameters: + {{func_params}} + The output MUST strictly adhere to the following JSON format, and NO other text MUST be included: + [ + { + "query": "The generated query.", + "answers": [ + { + "name": "api_name", + "arguments": { + "arg_name": "value", + ... (more arguments as required) + } + }, + ... (more API calls as required) + ] + } + ] + + Now please generate {{num_records}} diverse query and answer pairs following the above format. 
+ """ + query_answer_generator = StructuredLLM( + model_name=model_name, + # model_name="openai/gpt-4o-mini", + prompt=query_answer_generator_prompt, + output_schema=[ + {"name": "query", "type": "str"}, + {"name": "answer", "type": "str"}, + ], + model_kwargs={"temperature": 0.7}, + ) + query_answer_pairs = await query_answer_generator.run( + func_name=api_contract["name"], + func_desc=api_contract["description"] + " in this topic : " + topic, + func_params=api_contract["parameters"], + num_records=num_records_, + ) + return query_answer_pairs.data + + +@data_factory(max_concurrency=30) +async def update_query_answer(model_name: str, api_contract: dict, query, answer, failed_reason): + update_answer_generator_prompt = """ + Given this API contract: {{api_contract}}, this query: '{{query}}', and this answer: '{{answer}}'. It failed the format or semantic check + with this reason {{reason}}. + Please update the answer to pass the check. Here is the requirement + + Ensure the query be the same: + Ensure the answer: + − Is a list of function calls in JSON format. + − The length of the answer list should be equal to the number of requests in the query + − Can solve all the requests in the query effectively + + Based on these examples and the above instructions, update query and answer pair for the functions '{{func_name}}'. + The detailed functions description is as follows: + {{func_desc}} + The detailed functions paramters is as follows, the generated outputs shall have some records having the optional parameters: + {{func_params}} + The output MUST strictly adhere to the following JSON format, and NO other text MUST be included: + [ + { + "query": "The generated query.", + "answer": [ + { + "name": "api_name", + "arguments": { + "arg_name": "value", + ... (more arguments as required) + } + }, + ... (more API calls as required) + ] + } + ] + + update the answer to pass the check. 
+ """ + + query_answer_updator = StructuredLLM( + model_name=model_name, + prompt=update_answer_generator_prompt, + output_schema=[ + {"name": "query", "type": "str"}, + {"name": "answer", "type": "str"}, + ], + model_kwargs={"temperature": 0.7}, + ) + query_answer_pairs = await query_answer_updator.run( + api_contract=api_contract, + func_name=api_contract["name"], + func_desc=api_contract["description"], + func_params=api_contract["parameters"], + num_records=1, + query=query, + answer=answer, + reason=failed_reason, + ) + return query_answer_pairs.data diff --git a/src/starfish/data_template/templates/starfish/gen_topics/generate_by_topic.py b/src/starfish/data_template/templates/starfish/gen_topics/generate_by_topic.py new file mode 100644 index 0000000..edd75d6 --- /dev/null +++ b/src/starfish/data_template/templates/starfish/gen_topics/generate_by_topic.py @@ -0,0 +1,114 @@ +from starfish import data_factory, StructuredLLM +from starfish.components.prepare_topic import prepare_topic +from starfish.data_template.template_gen import data_gen_template +from pydantic import BaseModel + +from typing import Optional, List, Union, Dict, Any +import random + +from starfish.data_template.templates.starfish.gen_topics.utils import fetch_values_by_topic, save_value_by_topic + + +## Pydantic Input Schema +class GenerateByTopicInput(BaseModel): + """ + Input schema for the generate_by_topic template. + + IMPORTANT: This Pydantic model is the single source of truth for default values. + The validation and default values are controlled by this model, not the function signature. 
+ """ + + user_instruction: Optional[str] = None + num_records: Optional[int] = 10 + records_per_topic: int = 10 + topics: Optional[List[Union[str, Dict[str, int]]]] = None + topic_model_name: str = "openai/gpt-4o-mini" + topic_model_kwargs: Optional[Dict[str, Any]] = None + generation_model_name: str = "openai/gpt-4o-mini" + generation_model_kwargs: Optional[Dict[str, Any]] = None + output_schema: Optional[Union[List[Dict[str, Any]], Dict[str, Any], type]] = [{"name": "question", "type": "str"}, {"name": "answer", "type": "str"}] + data_factory_config: Optional[Dict[str, Any]] = {} + + +## Main +@data_gen_template.register( + name="starfish/generate_by_topic", + input_schema=GenerateByTopicInput, + output_schema=None, + description="""Generates diverse synthetic data across multiple topics based on user instructions. + Automatically creates relevant topics if not provided and handles deduplication across generated content. + """, + author="Wendao Liu", + starfish_version="0.1.3", + dependencies=[], +) +async def generate_by_topic(input_data: GenerateByTopicInput): + """ + Generates diverse synthetic data across multiple topics based on user instructions and defined output schema. + + If topics are not provided, it automatically generates relevant topics based on the instruction. + The function reduces duplication by tracking previously generated examples for each topic and + avoids repeating similar content. + + The data generation process has two main phases: first, topics are prepared (either using + provided topics or generating them); second, data is generated for each topic. + Topics are shuffled to ensure even distribution and better deduplication in the output data. + + Each generated record includes the topic it belongs to.
+ """ + topic_list = await prepare_topic( + topics=input_data.topics, + num_records=input_data.num_records, + records_per_topic=input_data.records_per_topic, + user_instruction=input_data.user_instruction, + model_name=input_data.topic_model_name, + model_kwargs=input_data.topic_model_kwargs, + ) + ## Shuffle the topic list to be more evenly distributed for better deduplication + random.shuffle(topic_list) + + @data_factory(**input_data.data_factory_config) + async def batch_generate_record(topic: str): + ## duplicate_example + generated_data = fetch_values_by_topic(batch_generate_record.state, topic) + if generated_data: + duplicate_example = random.choice(generated_data) + else: + duplicate_example = None + + prompt = """ + You are a helpful synthetic data generation assistant. + Your task is to generate synthetic data based on the provided information. + + User instructions are provided: + - Carefully consider the user instructions + - Create data that aligns with these instructions + + Please generate the synthetic data based on the given information and guidelines.
+ + here is user_instruction: {{user_instruction}} + + please generate sythentic data in this '{{topic}}' topics or themes + + {% if duplicate_example %} + To avoid duplication, here are samples of existing data and please do not repeat from this: + {{duplicate_example}} + {% endif %} + """ + + generation_llm = StructuredLLM( + prompt=prompt, model_name=input_data.generation_model_name, model_kwargs=input_data.generation_model_kwargs, output_schema=input_data.output_schema + ) + + generated_response = await generation_llm.run(user_instruction=input_data.user_instruction, topic=topic, duplicate_example=duplicate_example) + + save_value_by_topic(batch_generate_record.state, topic, str(generated_response.data)) + + ## Adding topic to the data and there is only one record so it is safe to use index 0 + generated_response.data[0]["topic"] = topic + + return generated_response.data + + data = batch_generate_record.run(topic_list) + + return data diff --git a/src/starfish/data_template/templates/starfish/gen_topics/utils.py b/src/starfish/data_template/templates/starfish/gen_topics/utils.py new file mode 100644 index 0000000..9a2eb9c --- /dev/null +++ b/src/starfish/data_template/templates/starfish/gen_topics/utils.py @@ -0,0 +1,34 @@ +from starfish.data_factory.utils.state import MutableSharedState +from typing import Any + + +## Helper Functions +def save_value_by_topic(state: MutableSharedState, topic: str, value: Any) -> None: + """Saves a value indexed by topic in the shared state.""" + # Get current state data + topic_collections = state.get("topic_data") + if topic_collections is None: + topic_collections = {} + else: + topic_collections = topic_collections.copy() + + # Initialize topic collection if needed + with state._lock: + if topic not in topic_collections or not isinstance(topic_collections[topic], list): + topic_collections[topic] = [] + # Append the value and update state + topic_collections[topic].append(value) + state.set("topic_data", topic_collections) 
+ + +def fetch_values_by_topic(state: MutableSharedState, topic: str) -> list: + """Fetches all values indexed by a topic from the shared state.""" + topic_collections = state.get("topic_data") + if not topic_collections: + return [] + + topic_values = topic_collections.get(topic) + if not topic_values or not isinstance(topic_values, list): + return [] + + return topic_values.copy() diff --git a/src/starfish/data_template/templates/starfish/generate_city_info.py b/src/starfish/data_template/templates/starfish/generate_city_info.py deleted file mode 100644 index 0b2f302..0000000 --- a/src/starfish/data_template/templates/starfish/generate_city_info.py +++ /dev/null @@ -1,64 +0,0 @@ -from pydantic import BaseModel -from starfish import data_factory -from starfish.data_template.template_gen import data_gen_template -from starfish.data_factory.utils.mock import mock_llm_call -import random -from typing import Any, List, Dict, Callable - -from starfish.common.logger import get_logger -from starfish.data_factory.constants import ( - STATUS_COMPLETED, - STATUS_DUPLICATE, - STATUS_FAILED, - STORAGE_TYPE_LOCAL, -) -from starfish.data_factory.factory import data_factory, resume_from_checkpoint -from starfish.data_factory.utils.mock import mock_llm_call -from starfish.data_factory.utils.state import MutableSharedState - -logger = get_logger(__name__) - - -# Define input schema -class CitiInfoGeneratorInput(BaseModel): - region_code: list[str] - city_name: list[str] - - -@data_gen_template.register( - name="starfish/generate_city_info", - input_schema=CitiInfoGeneratorInput, - output_schema=None, - description="Generates relevant topics for community discussions using AI models", - author="Your Name", - starfish_version="0.1.0", - dependencies=[], -) -def generate_city_info(input_data: dict = {}): - # city_name = input_data["city_name"] - # region_code = input_data["region_code"] - city_name = ["San Francisco", "New York", "Los Angeles"] * 5 - region_code = ["DE", "IT", "US"] 
* 5 - - @data_factory( - storage=STORAGE_TYPE_LOCAL, - max_concurrency=50, - initial_state_values={}, - on_record_complete=[], - on_record_error=[], - show_progress=True, - task_runner_timeout=60, - ) - async def get_city_info_wf(city_name: List[str], region_code: List[str]) -> List[Dict[str, Any]]: - """Retrieve information about cities using a workflow. - - Args: - city_name: Name(s) of the city/cities to get information for - region_code: Region code(s) associated with the city/cities - - Returns: - List[Dict[str, Any]]: Processed city information from the mock LLM call - """ - return await mock_llm_call(city_name, num_records_per_city=1, fail_rate=0.1, sleep_time=2) - - return get_city_info_wf.run(city_name=city_name, region_code=region_code) diff --git a/src/starfish/data_template/templates/starfish/get_city_info_wf.py b/src/starfish/data_template/templates/starfish/get_city_info_wf.py deleted file mode 100644 index 898fb98..0000000 --- a/src/starfish/data_template/templates/starfish/get_city_info_wf.py +++ /dev/null @@ -1,116 +0,0 @@ -from pydantic import BaseModel -from starfish import data_factory -from starfish.data_template.template_gen import data_gen_template -from starfish.data_factory.utils.mock import mock_llm_call -import random -from typing import Any, List, Dict, Callable - -from starfish.common.logger import get_logger -from starfish.data_factory.constants import ( - STATUS_COMPLETED, - STATUS_DUPLICATE, - STATUS_FAILED, - STORAGE_TYPE_LOCAL, -) -from starfish.data_factory.factory import data_factory, resume_from_checkpoint -from starfish.data_factory.utils.mock import mock_llm_call -from starfish.data_factory.utils.state import MutableSharedState - -logger = get_logger(__name__) - - -# def handle_error(data: Any, state: MutableSharedState): -# """Handle error cases during data processing. 
- -# Args: -# data: The data that caused the error -# state: Shared state object for tracking progress - -# Returns: -# str: STATUS_FAILED constant -# """ -# logger.error(f"Error occurred: {data}") -# return STATUS_FAILED - - -# def handle_record_complete(data: Any, state: MutableSharedState): -# """Handle successful completion of a record. - -# Args: -# data: The successfully processed data -# state: Shared state object for tracking progress - -# Returns: -# str: STATUS_COMPLETED constant -# """ -# # print(f"Record complete: {data}") - -# state.set("completed_count", 1) -# state.update({"completed_count": 2}) -# return STATUS_COMPLETED - - -# def handle_duplicate_record(data: Any, state: MutableSharedState): -# """Handle duplicate record detection. - -# Args: -# data: The duplicate data record -# state: Shared state object for tracking progress - -# Returns: -# str: Either STATUS_COMPLETED or STATUS_DUPLICATE based on random chance -# """ -# logger.debug(f"Record : {data}") -# state.set("completed_count", 1) -# state.update({"completed_count": 2}) -# # return STATUS_DUPLICATE -# if random.random() < 0.9: -# # print("going to return completed") -# return STATUS_COMPLETED -# # print("going to return duplicate") -# return STATUS_DUPLICATE - - -# Define input schema -class CitiInfoGeneratorInput(BaseModel): - region_code: list[str] - city_name: list[str] - - -# # Define output schema -# class TopicGeneratorOutput(BaseModel): -# generated_topics: list[str] -# success: bool -# message: str - - -# "transformers>=4.0.0", -@data_gen_template.register( - name="starfish/get_city_info_wf", - input_schema=None, - output_schema=None, - description="Generates relevant topics for community discussions using AI models", - author="Your Name", - starfish_version="0.1.0", - dependencies=[], -) -@data_factory( - storage=STORAGE_TYPE_LOCAL, - max_concurrency=50, - initial_state_values={}, - on_record_complete=[], - on_record_error=[], - show_progress=True, - task_runner_timeout=60, -) 
-async def get_city_info_wf(city_name: List[str], region_code: List[str]) -> List[Dict[str, Any]]: - """Retrieve information about cities using a workflow. - - Args: - city_name: Name(s) of the city/cities to get information for - region_code: Region code(s) associated with the city/cities - - Returns: - List[Dict[str, Any]]: Processed city information from the mock LLM call - """ - return await mock_llm_call(city_name, num_records_per_city=1, fail_rate=0.1, sleep_time=2) diff --git a/src/starfish/data_template/templates/starfish/math_problem_gen_wf.py b/src/starfish/data_template/templates/starfish/math_problem_gen_wf.py deleted file mode 100644 index 841a8d8..0000000 --- a/src/starfish/data_template/templates/starfish/math_problem_gen_wf.py +++ /dev/null @@ -1,222 +0,0 @@ -from pydantic import BaseModel -from starfish import data_factory -from starfish.data_template.template_gen import data_gen_template - -# import nest_asyncio -from starfish import data_factory, StructuredLLM -import os -from agents import Agent, Runner, function_tool -from agents.agent_output import AgentOutputSchema -from pydantic import BaseModel - -# nest_asyncio.apply() - -model_name_used = "openai/gpt-4.1-mini" -CONCURRENCY = 50 -TASK_RUNNER_TIMEOUT = 500 - - -# Define input schema -class TopicGeneratorInput(BaseModel): - community_name: str - seed_topics: list[str] - num_topics: int - language: str = "en" - - -# Define output schema -class TopicGeneratorOutput(BaseModel): - generated_topics: list[str] - success: bool - message: str - - -@data_gen_template.register( - name="starfish/math_problem_gen_wf", - input_schema=TopicGeneratorInput, - # optional - output_schema=TopicGeneratorOutput, - description="Generates relevant math problem-solutions using AI models", - author="Your Name", - starfish_version="0.1.0", - dependencies=["posthog>=3.11.0"], -) -def math_problem_gen_wf(): - @data_factory(max_concurrency=CONCURRENCY) - async def generate_topic(num_records): - prompt = """ - List 
unique math topics that are commonly tested on AIME (American Invitational Mathematics Examination) problems. - Focus on areas that appear frequently in recent years, especially 2020–2025. - Include both core topics and more niche subtopics. - """ - model = StructuredLLM(model_name=model_name_used, prompt=prompt, output_schema=[{"name": "topic", "type": "str", "required": True}]) - return (await model.run(num_records=num_records)).data - - @data_factory(max_concurrency=CONCURRENCY) - async def generate_problem(topic): - prompt = """ - Create a AIME-style math competition problem in the topic of {{topic}}. - - Requirements: - - 1. The problem should be original and adhere to AIME difficulty (appropriate for high school students aiming for USAMO qualification). - 2. It must be solvable in 3 to 6 logical steps, without requiring computational brute force. - 3. Emphasize creativity, clean setup, and an elegant path to the solution. - 4. Use clear and concise language. No extraneous details. - 5. Do not include the answer or any solution steps. - 6. Return only the problem text. - """ - model = StructuredLLM( - model_name=model_name_used, - prompt=prompt, - output_schema=[{"name": "problem", "type": "str", "required": True}, {"name": "topic", "type": "str", "required": True}], - ) - return (await model.run(topic=topic)).data - - # Step 1: Define your desired structured output - class CoTSchema(BaseModel): - cot: str - problem: str - topic: str - answer: str - - @data_factory(max_concurrency=CONCURRENCY, task_runner_timeout=TASK_RUNNER_TIMEOUT) - async def answer_long_cot(problem, topic): - prompt = f"Solve the following problem using detailed, step-by-step reasoning. Conclude with Final Answer: . 
Problem: {problem}" - - my_agent = Agent(name="Problem solver with detailed CoT", output_type=CoTSchema) - - sample_run = await Runner.run(my_agent, input=prompt) - - output = sample_run.final_output.model_dump() - output["cot_type"] = "long" - output["topic"] = topic - return [output] - - @data_factory(max_concurrency=CONCURRENCY) - async def generate_short_cot(problem): - prompt = f"Solve this problem using concise step-by-step reasoning. End with: Final Answer: . Problem: {problem}" - my_agent = Agent(name="Problem solver with concise CoT", output_type=CoTSchema) - - sample_run = await Runner.run(my_agent, input=prompt) - - output = sample_run.final_output.model_dump() - return [output] - - # Step 1: Define your desired structured output - class CodeExecutorSchema(BaseModel): - verified: str - correct_answer: int - topic: str - problem: str - cot: str - - @function_tool - def execute_python_code(code: str): - local_vars = {} - exec(code, {}, local_vars) - verified = local_vars.get("verified", None) - correct_solution = local_vars.get("correct_solution", None) - return {"verified": bool(verified), "correct_solution": correct_solution} - - @data_factory(max_concurrency=CONCURRENCY, task_runner_timeout=TASK_RUNNER_TIMEOUT) - async def data_factory_execute_cot_as_code(cot, answer, topic, problem, cot_type): - return await execute_cot_as_code(cot, answer, topic, problem, cot_type) - - async def execute_cot_as_code(cot, answer, topic, problem, cot_type): - prompt = f""" - Convert the following AIME-style math problem into precise, correct, and executable Python code. - - Instructions: - - Implement a complete solution that rigorously follows all mathematical constraints. - - Use Python libraries such as `itertools`, `math`, or `collections` when appropriate. - - If the problem involves enumeration, number theory, or digit-based logic, handle all edge cases carefully and thoroughly. - - Avoid floating-point operations when integer accuracy is required. 
- - Assign the final result to a variable named `correct_solution`. - - Compare `correct_solution` against the expected value `{answer}` and set `verified = True` if they match, otherwise `False`. - - After writing the code, call the tool to execute it. - - Problem: - {problem} - """ - - my_agent = Agent( - name="Tool caller", - output_type=CodeExecutorSchema, - tools=[execute_python_code], - ) - - sample_run = await Runner.run(my_agent, input=prompt) - - output = sample_run.final_output.model_dump() - output["problem"] = problem - output["cot"] = cot - output["cot_type"] = cot_type - output["topic"] = topic - - return [output] - - # Step 1: Define your desired structured output - class FeedbackAndRewriteSchema(BaseModel): - revised_cot: str - cot: str - topic: str - problem: str - - @data_factory(max_concurrency=CONCURRENCY, task_runner_timeout=TASK_RUNNER_TIMEOUT) - async def feedback_and_rewrite(topic, problem, cot, verified, correct_answer, cot_type): - prompt = f""" - Review the problem and the current solution attempt below. - - First, evaluate whether the reasoning in the solution leads to the correct answer. If it does not, identify any mistakes or incorrect steps. Then, rewrite the solution so that the logic is accurate and clearly leads to the correct, verified answer. - - Your rewritten solution should maintain a step-by-step explanation and ensure the final result matches the Correct Answer. 
- - Make sure that this revised and rewriteen chain of through is is returned in the variable `revised_cot` - - Problem: {problem} - Current Solution: {cot} - Verified Correct Answer: {correct_answer} - """ - my_agent = Agent(name="Feedback and Rewrite", output_type=FeedbackAndRewriteSchema) - - sample_run = await Runner.run(my_agent, input=prompt) - - output = sample_run.final_output.model_dump() - - feedbacked_output = await execute_cot_as_code(cot=output["revised_cot"], topic=topic, problem=problem, answer=correct_answer, cot_type=cot_type) - return feedbacked_output - - topics = generate_topic.run(num_records=10) - problem = generate_problem.run(topics) - long_cot = answer_long_cot.run(problem) - cot_as_code = data_factory_execute_cot_as_code.run(long_cot) - all_re_written_cots = [] - - # Start with only unverified entries (string comparison) - unverified_entries = [entry for entry in cot_as_code if entry.get("verified") == "False"] - - verified_entries = [entry for entry in cot_as_code if entry.get("verified") == "True"] - - print("VERIFIED ENTRIES: " + str(len(verified_entries))) - - print("UNVERIFIED ENTRIES: " + str(len(unverified_entries))) - - while unverified_entries: - # Run feedback and rewrite on the current batch of unverified entries - rewritten_batch = feedback_and_rewrite.run(unverified_entries) - - # Collect verified rewrites - verified_batch = [rewritten for rewritten in rewritten_batch if rewritten.get("verified") == "True"] - all_re_written_cots.extend(verified_batch) - - # Remove verified entries from the current unverified list - unverified_entries = [rewritten for rewritten in rewritten_batch if rewritten.get("verified") == "False"] - print("ALL REWRITTEN ENTRIES: " + str(len(all_re_written_cots))) - - verified_entries = verified_entries + all_re_written_cots - print(verified_entries) - - -# math_problem_gen_wf() diff --git a/tests/data_template/test_data_template.py b/tests/data_template/test_data_template.py index 352f678..855f35a 100644 
--- a/tests/data_template/test_data_template.py +++ b/tests/data_template/test_data_template.py @@ -1,29 +1,13 @@ import nest_asyncio import pytest -from pydantic import BaseModel +import os from starfish.common.env_loader import load_env_file from starfish.data_template.template_gen import data_gen_template -from starfish.data_template.utils.error import DataTemplateValueError, ImportPackageError nest_asyncio.apply() load_env_file() -# Define input schema -class TopicGeneratorInput(BaseModel): - community_name: str - seed_topics: list[str] - num_topics: int - language: str = "en" - - -# Define output schema -class TopicGeneratorOutput(BaseModel): - generated_topics: list[str] - success: bool - message: str - - @pytest.mark.asyncio async def test_list(): """Test with input data and broadcast variables @@ -47,123 +31,59 @@ async def test_list_detail(): @pytest.mark.asyncio -# @pytest.mark.skip(reason="Skipping this test case as not implementing data_factory decorator outside the function") -async def test_get_datafactory_run(): - """Test with input data and broadcast variables - - Input: List of dicts with city names - - Broadcast: num_records_per_city - - Expected: All cities processed successfully - """ - data_gen_template.list() - get_city_info_wf = data_gen_template.get("starfish/get_city_info_wf") - results = get_city_info_wf.run( - city_name=["San Francisco", "New York", "Los Angeles"] * 5, - region_code=["DE", "IT", "US"] * 5, - ) - assert len(results) == 15 - - -@pytest.mark.asyncio -async def test_get_run_dependencies_not_met(): - """Test with input data and broadcast variables - - Input: List of dicts with city names - - Broadcast: num_records_per_city - - Expected: All cities processed successfully - """ - data_gen_template.list() - with pytest.raises(ModuleNotFoundError): - topic_generator = data_gen_template.get("starfish/math_problem_gen_wf") - - -@pytest.mark.asyncio -async def test_get_run_dependencies_not_met(): - """Test with input data and 
broadcast variables - - Input: List of dicts with city names - - Broadcast: num_records_per_city - - Expected: All cities processed successfully - """ - data_gen_template.list() - with pytest.raises(ImportPackageError): - topic_generator = data_gen_template.get("community/topic_generator") - - -@pytest.mark.asyncio -async def test_get_run_template_Input_Success(): - """Test with input data and broadcast variables - - Input: List of dicts with city names - - Broadcast: num_records_per_city - - Expected: All cities processed successfully - """ +@pytest.mark.skipif(os.getenv("CI") == "true", reason="Skipping in CI environment") +async def test_get_generate_by_topic_Success(): data_gen_template.list() - template = data_gen_template.get("community/topic_generator_success") - - # input_data = TopicGeneratorInput( - # community_name="AI Enthusiasts", - # seed_topics=["Machine Learning", "Deep Learning"], - # num_topics=5 - # ) - input_data = {"community_name": "AI Enthusiasts", "seed_topics": ["Machine Learning", "Deep Learning"], "num_topics": 1} - # results = template.run(input_data.model_dump()) - results = template.run(input_data) - print(results) - - assert len(results.generated_topics) == 3 - - -@pytest.mark.asyncio -async def test_get_run_template_Input_Schema_Not_Match(): - """Test with input data and broadcast variables - - Input: List of dicts with city names - - Broadcast: num_records_per_city - - Expected: All cities processed successfully - """ - with pytest.raises(DataTemplateValueError) as exc_info: - topic_generator = data_gen_template.get("community/topic_generator_success_1") - - input_data = TopicGeneratorInput( - community_name="AI Enthusiasts", - ) - topic_generator.run(input_data) - - # Assert the error message - assert "Template community/topic_generator_success_1 not found" in str(exc_info.value) - - -@pytest.mark.asyncio -async def test_get_cities_info(): - """Get information about multiple cities using the city info workflow template. 
- - This function retrieves city information by executing the 'get_city_info_wf' template - with provided city names and region codes. - - Args: - city_name: List of city names to get information for - region_code: List of region codes corresponding to the cities (e.g., state codes for US) - - Returns: - Results from the city info workflow template execution containing city information - - Example: - >>> await test_get_cities_info( - ... city_name=["New York", "Los Angeles"], - ... region_code=["NY", "CA"] - ... ) - """ - data_gen_template.list() - template = data_gen_template.get("starfish/get_city_info_wf") - - city_name = ["San Francisco", "New York", "Los Angeles"] * 5 - region_code = ["DE", "IT", "US"] * 5 - results = template.run(city_name=city_name, region_code=region_code) - return results + topic_generator_temp = data_gen_template.get("starfish/generate_by_topic") + num_records = 20 + input_data = { + "user_instruction": "Generate Q&A pairs about machine learning concepts", + "num_records": num_records, + "records_per_topic": 5, + "topics": [ + "supervised learning", + "unsupervised learning", + {"reinforcement learning": 3}, # This means generate 3 records for this topic + "neural networks", + ], + "topic_model_name": "openai/gpt-4", + "topic_model_kwargs": {"temperature": 0.7}, + "generation_model_name": "openai/gpt-4", + "generation_model_kwargs": {"temperature": 0.8, "max_tokens": 200}, + "output_schema": [ + {"name": "question", "type": "str"}, + {"name": "answer", "type": "str"}, + {"name": "difficulty", "type": "str"}, # Added an additional field + ], + "data_factory_config": {"max_concurrency": 4, "task_runner_timeout": 60 * 2}, + } + # results = topic_generator_temp.run(input_data.model_dump()) + results = await topic_generator_temp.run(input_data) + + assert len(results) == num_records @pytest.mark.asyncio -async def test_gen_cities_info(): +@pytest.mark.skipif(os.getenv("CI") == "true", reason="Skipping in CI environment") +async def 
test_get_generate_func_call_dataset(): data_gen_template.list() - city_name = ["San Francisco", "New York", "Los Angeles"] * 5 - region_code = ["DE", "IT", "US"] * 5 - input_data = {"city_name": city_name, "region_code": region_code} - template = data_gen_template.get("starfish/generate_city_info") - results = template.run(input_data) - assert len(results) == 15 + generate_func_call_dataset = data_gen_template.get("starfish/generate_func_call_dataset") + input_data = { + "num_records": 4, + "api_contract": { + "name": "weather_api.get_current_weather", + "description": "Retrieves the current weather conditions for a specified location .", + "parameters": { + "location": {"type": "string", "description": "The name of the city or geographic location .", "required": True}, + "units": {"type": "string", "description": "The units for temperature measurement( e.g., 'Celsius', 'Fahrenheit') .", "required": False}, + }, + }, + "topic_model_name": "openai/gpt-4", + "topic_model_kwargs": {"temperature": 0.7}, + "generation_model_name": "openai/gpt-4o-mini", + "generation_model_kwargs": {"temperature": 0.8, "max_tokens": 200}, + "data_factory_config": {"max_concurrency": 24, "task_runner_timeout": 60 * 2}, + } + results = await generate_func_call_dataset.run(input_data) + + assert len(results) >= input_data["num_records"]