Skip to content
28 changes: 14 additions & 14 deletions src/starfish/components/prepare_topic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


async def generate_topics(
user_instructions: str,
user_instruction: str,
num_topics: int,
model_name: str = "openai/gpt-4o-mini",
model_kwargs: Optional[Dict[str, Any]] = None,
Expand All @@ -30,7 +30,7 @@ async def generate_topics(
for _ in range(num_batches):
topic_generator = StructuredLLM(
model_name=model_name,
prompt="""Can you generate a list of topics about {{user_instructions}}
prompt="""Can you generate a list of topics about {{user_instruction}}
{% if existing_topics_str %}
Please do not generate topics that are already in the list: {{existing_topics_str}}
Make sure the topics are unique and vary from each other
Expand All @@ -41,7 +41,7 @@ async def generate_topics(
)

all_existing = existing_topics + generated_topics
input_params = {"user_instructions": user_instructions, "num_records": min(llm_batch_size, num_topics - len(generated_topics))}
input_params = {"user_instruction": user_instruction, "num_records": min(llm_batch_size, num_topics - len(generated_topics))}

if all_existing:
input_params["existing_topics_str"] = ",".join(all_existing)
Expand All @@ -60,7 +60,7 @@ async def prepare_topic(
topics: Optional[List[Union[str, Dict[str, int]]]] = None,
num_records: Optional[int] = None,
records_per_topic: int = 20,
user_instructions: Optional[str] = None,
user_instruction: Optional[str] = None,
model_name: str = "openai/gpt-4o-mini",
model_kwargs: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, str]]:
Expand All @@ -70,13 +70,13 @@ async def prepare_topic(
1. String list: ['topic1', 'topic2'] - Topics with equal or calculated distribution
2. Dict list: [{'topic1': 20}, {'topic2': 30}] - Topics with specific counts
3. Mixed: ['topic1', {'topic2': 30}] - Combination of both formats
4. None: No topics provided, will generate based on user_instructions
4. None: No topics provided, will generate based on user_instruction

Args:
topics: Optional list of topics, either strings or {topic: count} dicts
num_records: Total number of records to split (required for dict topics or None topics)
records_per_topic: Number of records per topic (default: 20)
user_instructions: Topic generation instructions (required if topics is None)
user_instruction: Topic generation instructions (required if topics is None)
model_name: Model name for topic generation
model_kwargs: Model kwargs for topic generation

Expand All @@ -89,11 +89,11 @@ async def prepare_topic(
model_kwargs["temperature"] = 1
# --- STEP 1: Input validation and normalization ---
if topics is None:
# Must have num_records and user_instructions if no topics provided
# Must have num_records and user_instruction if no topics provided
if not num_records or num_records <= 0:
raise ValueError("num_records must be positive when topics are not provided")
if not user_instructions:
raise ValueError("user_instructions required when topics are not provided")
if not user_instruction:
raise ValueError("user_instruction required when topics are not provided")
topic_assignments = []
else:
# Validate topics is a non-empty list
Expand Down Expand Up @@ -181,11 +181,11 @@ async def prepare_topic(
raise ValueError("records_per_topic must be positive when generating topics")

# Generate topics with LLM if instructions provided
if user_instructions:
if user_instruction:
topics_needed = math.ceil(remaining_records / records_per_topic)

generated = await generate_topics(
user_instructions=user_instructions, num_topics=topics_needed, model_name=model_name, model_kwargs=model_kwargs, existing_topics=topic_names
user_instruction=user_instruction, num_topics=topics_needed, model_name=model_name, model_kwargs=model_kwargs, existing_topics=topic_names
)

# Assign counts to generated topics
Expand Down Expand Up @@ -237,7 +237,7 @@ async def prepare_topic(
# Example 1: Dictionary topics with additional generation
print("\nExample 1: Dictionary topics + generation")
topics1 = [{"topic1": 20}, {"topic2": 30}]
result1 = asyncio.run(prepare_topic(topics=topics1, num_records=100, records_per_topic=25, user_instructions="some context"))
result1 = asyncio.run(prepare_topic(topics=topics1, num_records=100, records_per_topic=25, user_instruction="some context"))
print(f"Result: {result1}")
print(f"Total: {len(result1)}")

Expand All @@ -251,7 +251,7 @@ async def prepare_topic(
# Example 3: Mixed string and dict topics
print("\nExample 3: Mixed string/dict topics")
topics3 = ["topicX", {"topicY": 10}]
result3 = asyncio.run(prepare_topic(topics=topics3, num_records=30, user_instructions="mixed topics"))
result3 = asyncio.run(prepare_topic(topics=topics3, num_records=30, user_instruction="mixed topics"))
print(f"Result: {result3}")
print(f"Total: {len(result3)}")

Expand All @@ -266,7 +266,7 @@ async def prepare_topic(
print("\nExample 5: No topics, generate all")

async def run_example5():
result = await prepare_topic(topics=None, num_records=10, records_per_topic=5, user_instructions="cloud computing")
result = await prepare_topic(topics=None, num_records=10, records_per_topic=5, user_instruction="cloud computing")
print(f"Result: {result}")
print(f"Total: {len(result)}")

Expand Down
39 changes: 39 additions & 0 deletions src/starfish/data_factory/utils/data_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,45 @@ def to_json(self):

return json.dumps(self.to_dict(), indent=2)

def update(self, data: dict):
"""Update the configuration from a dictionary.

This method updates the configuration fields from a dictionary. It handles
the deserialization of callable functions using cloudpickle.

Args:
data (dict): Dictionary containing the configuration data. Expected keys:
- storage: Storage type string
- master_job_id: Unique job identifier
- project_id: Project identifier
- batch_size: Number of records per batch
- target_count: Total records to process
- max_concurrency: Maximum concurrent tasks
- show_progress: Whether to show progress
- task_runner_timeout: Task timeout in seconds
- on_record_complete: List of callable strings for record completion
- on_record_error: List of callable strings for record errors
- run_mode: Execution mode string
- job_run_stop_threshold: Job retry threshold

Raises:
ValueError: If invalid fields are provided
cloudpickle.PickleError: If callable deserialization fails
"""
if not isinstance(data, dict):
raise ValueError("Input must be a dictionary")

# Handle callable deserialization
if "on_record_complete" in data:
self.on_record_complete = [cloudpickle.loads(bytes.fromhex(c)) if c else None for c in data["on_record_complete"]]
if "on_record_error" in data:
self.on_record_error = [cloudpickle.loads(bytes.fromhex(c)) if c else None for c in data["on_record_error"]]

# Update other fields
for key, value in data.items():
if key not in ["on_record_complete", "on_record_error"] and hasattr(self, key):
setattr(self, key, value)


@dataclass
class FactoryJobConfig:
Expand Down
2 changes: 1 addition & 1 deletion src/starfish/data_factory/utils/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, initial_data: Optional[Dict[str, Any]] = None):
the state will be initialized with a copy of this dictionary.
"""
super().__init__()
self._lock = threading.Lock() # Changed to threading.Lock
self._lock = threading.RLock() # Changed to threading.Lock
if initial_data is not None:
self._data = initial_data.copy()

Expand Down
90 changes: 69 additions & 21 deletions src/starfish/data_template/template_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import importlib.util
import ast
from typing import Any, Union, List, Dict
import inspect

from starfish.data_template.utils.error import DataTemplateValueError, ImportModuleError, ImportPackageError

Expand Down Expand Up @@ -53,41 +54,88 @@ def __init__(
# Add run method if not present
if not hasattr(self.func, "run"):
self.func.run = lambda *args, **kwargs: self.func(*args, **kwargs)
if not self.input_schema:
raise DataTemplateValueError("missing input_schema_validator")

# Detect if function expects a single input_data parameter
self.expects_model_input = False
try:
self.func_signature = inspect.signature(self.func)

# Check if the function has a single parameter that matches the input_schema type
if len(self.func_signature.parameters) == 1:
param_name, param = next(iter(self.func_signature.parameters.items()))
# Check if the parameter has an annotation matching input_schema
if param.annotation == self.input_schema:
self.expects_model_input = True
except (TypeError, ValueError):
# Can't get signature (e.g., for data_factory decorated functions)
# Default to False (use individual parameters) which is safer
self.expects_model_input = False

# Check dependencies on initialization
# if self.dependencies:
# _check_dependencies(self.dependencies)

def run(self, *args, **kwargs) -> Any:
"""Execute the wrapped function with schema validation."""
# Pre-run hook: Validate input schema
try:
# Validate input against schema
if self.input_schema:
if args:
self.input_schema.validate(args[0])
elif kwargs:
self.input_schema.validate(kwargs)
except pydantic.ValidationError as e:
raise DataTemplateValueError(f"Input validation failed: {str(e)}")
async def run(self, *args, **kwargs) -> Any:
"""Execute the wrapped function with schema validation.

This method supports multiple calling patterns:
- template.run(param1=val1, param2=val2) # Keyword arguments
- template.run({"param1": val1, "param2": val2}) # Dictionary
- template.run(model_instance) # Pydantic model instance

# Execute the function
The template function can be defined in two ways:
1. Taking a single Pydantic model parameter: func(input_data: Model)
2. Taking individual parameters: func(param1, param2, param3)

In all cases, validation happens through the Pydantic model.
"""
# STEP 1: Get a validated Pydantic model instance from the inputs
validated_model = self._get_validated_model(args, kwargs)

# STEP 2: Call the function with appropriate arguments based on its signature
try:
result = self.func.run(*args, **kwargs)
if self.expects_model_input:
# Pass the whole model to functions expecting a model parameter
result = await self.func.run(validated_model)
else:
# Expand model fields for functions expecting individual parameters
result = await self.func.run(**validated_model.model_dump())
except Exception as e:
raise DataTemplateValueError(f"Template execution failed: {str(e)}")

# Post-run hook: Validate output schema
# STEP 3: Validate the output if an output schema is provided
if self.output_schema is not None:
try:
self.output_schema.validate(result)
# Use model_validate instead of validate (which is deprecated in Pydantic v2)
self.output_schema.model_validate(result)
except pydantic.ValidationError as e:
raise DataTemplateValueError(f"Output validation failed: {str(e)}")

return result

def _get_validated_model(self, args, kwargs):
"""Convert input arguments into a validated Pydantic model instance."""
# Case 1: User passed a model instance directly
if len(args) == 1 and isinstance(args[0], self.input_schema):
return args[0]

# Case 2: User passed a dictionary as the first positional argument
if len(args) == 1 and isinstance(args[0], dict):
# Merge dictionary with any keyword arguments
input_data = {**args[0], **kwargs}
# Case 3: User passed keyword arguments only
elif not args:
input_data = kwargs
# Case 4: Invalid input (multiple positional args or wrong type)
else:
raise DataTemplateValueError("Invalid arguments: Please provide either keyword arguments, " "a single dictionary, or a model instance.")

# Validate and return a model instance
try:
return self.input_schema.model_validate(input_data)
except pydantic.ValidationError as e:
raise DataTemplateValueError(f"Input validation failed: {str(e)}")


# ====================
# Template Management Class
Expand All @@ -107,11 +155,11 @@ def list(is_detail: bool = False) -> list[Any]:
if len(result) == 0:
# Walk through all subdirectories in templates folder
for subdir in templates_dir.iterdir():
if subdir.is_dir():
for sub_subdir in subdir.iterdir():
# Find all .py files in the subdirectory
for template_file in subdir.glob("*.py"):
for template_file in sub_subdir.glob("*.py"):
try:
module_name = f"starfish.data_template.templates.{subdir.name}.{template_file.stem}"
module_name = f"starfish.data_template.templates.{subdir.name}.{sub_subdir.name}.{template_file.stem}"
# Parse the file's AST to extract decorator information
with open(template_file, "r") as f:
tree = ast.parse(f.read())
Expand Down
62 changes: 0 additions & 62 deletions src/starfish/data_template/templates/community/topic_generator.py

This file was deleted.

Loading