From c0b539515267e6b193dacf262cc459646015fcd5 Mon Sep 17 00:00:00 2001 From: "johnwayne.jiang" Date: Tue, 15 Apr 2025 14:55:35 -0700 Subject: [PATCH 1/7] data_factory : code enhancement and bug fix --- examples/simple_feedback_loop.py | 2 +- examples/test_langgraph.py | 2 +- examples/test_langgraph_structured_llm.py | 2 +- examples/trial_llm.py | 39 +++--- src/starfish/common/decorator.py | 13 -- src/starfish/data_factory/config.py | 2 + src/starfish/data_factory/constants.py | 5 - src/starfish/data_factory/event_loop.py | 19 --- src/starfish/data_factory/factory.py | 113 ++++++++++-------- src/starfish/data_factory/job_manager.py | 45 +++---- src/starfish/data_factory/state.py | 47 ++++---- .../storage/local/metadata_handler.py | 21 +--- src/starfish/data_factory/task_runner.py | 4 +- src/starfish/data_factory/utils/decorator.py | 22 ++++ .../data_factory/{ => utils}/enums.py | 0 .../data_factory/{ => utils}/errors.py | 0 src/starfish/data_factory/utils/mock.py | 16 +++ tests/__init__.py | 12 ++ .../{src => }/data_factory/storage/README.md | 0 .../data_factory/storage/__init__.py | 0 .../data_factory/storage/local/__init__.py | 0 .../storage/local/test_basic_storage.py | 0 .../storage/local/test_local_storage.py | 0 .../storage/local/test_performance.py | 0 .../data_factory/storage/test_storage_main.py | 0 .../data_factory/test_data_factory.py | 31 ++--- tests/{src => }/llm/prompt/test_prompt.py | 0 27 files changed, 203 insertions(+), 192 deletions(-) delete mode 100644 src/starfish/common/decorator.py create mode 100644 src/starfish/data_factory/config.py create mode 100644 src/starfish/data_factory/utils/decorator.py rename src/starfish/data_factory/{ => utils}/enums.py (100%) rename src/starfish/data_factory/{ => utils}/errors.py (100%) create mode 100644 src/starfish/data_factory/utils/mock.py create mode 100644 tests/__init__.py rename tests/{src => }/data_factory/storage/README.md (100%) rename tests/{src => }/data_factory/storage/__init__.py (100%) rename tests/{src => }/data_factory/storage/local/__init__.py (100%) rename tests/{src => }/data_factory/storage/local/test_basic_storage.py (100%) rename tests/{src => }/data_factory/storage/local/test_local_storage.py (100%) rename tests/{src => }/data_factory/storage/local/test_performance.py (100%) rename tests/{src => }/data_factory/storage/test_storage_main.py (100%) rename tests/{src => }/data_factory/test_data_factory.py (86%) rename tests/{src => }/llm/prompt/test_prompt.py (100%) diff --git a/examples/simple_feedback_loop.py b/examples/simple_feedback_loop.py index edc4979..d9758e8 100644 --- a/examples/simple_feedback_loop.py +++ b/examples/simple_feedback_loop.py @@ -4,7 +4,7 @@ from starfish import StructuredLLM, data_factory from starfish.data_factory.constants import RECORD_STATUS -from starfish.data_factory.enums import RecordStatus +from starfish.data_factory.utils.enums import RecordStatus # Create a StructuredLLM instance for city information generation city_facts_llm = StructuredLLM( diff --git a/examples/test_langgraph.py b/examples/test_langgraph.py index f91d481..521f725 100644 --- a/examples/test_langgraph.py +++ b/examples/test_langgraph.py @@ -8,7 +8,7 @@ from starfish import StructuredLLM, data_factory from starfish.data_factory.constants import RECORD_STATUS -from starfish.data_factory.enums import RecordStatus +from starfish.data_factory.utils.enums import RecordStatus # Define a simple tool @tool diff --git a/examples/test_langgraph_structured_llm.py b/examples/test_langgraph_structured_llm.py index 114303f..83c269f 100644 --- a/examples/test_langgraph_structured_llm.py +++ b/examples/test_langgraph_structured_llm.py @@ -8,7 +8,7 @@ from starfish import StructuredLLM, data_factory from starfish.data_factory.constants import RECORD_STATUS -from starfish.data_factory.enums import RecordStatus +from starfish.data_factory.utils.enums import RecordStatus # Define a simple tool @tool diff --git a/examples/trial_llm.py b/examples/trial_llm.py index 609d992..b916908 100644 --- a/examples/trial_llm.py +++ b/examples/trial_llm.py @@ -6,42 +6,33 @@ from starfish.data_factory.constants import STATUS_COMPLETED, STATUS_DUPLICATE, STATUS_FILTERED, STATUS_FAILED, STORAGE_TYPE_IN_MEMORY, STORAGE_TYPE_LOCAL from starfish.data_factory.state import MutableSharedState from starfish.common.logger import get_logger +from starfish.data_factory.utils.mock import mock_llm_call logger = get_logger(__name__) # Add callback for error handling # todo state is a class with thread safe dict -async def handle_error(data: Any, state: MutableSharedState): +def handle_error(data: Any, state: MutableSharedState): logger.error(f"Error occurred: {data}") return STATUS_FAILED -async def handle_record_complete(data: Any, state: MutableSharedState): - print(f"Record complete: {data}") +def handle_record_complete(data: Any, state: MutableSharedState): + #print(f"Record complete: {data}") - await state.set("completed_count", 1) - await state.data - await state.update({"completed_count": 2}) + state.set("completed_count", 1) + state_data = state.data + state.update({"completed_count": 2}) return STATUS_COMPLETED -async def handle_duplicate_record(data: Any, state: MutableSharedState): +def handle_duplicate_record(data: Any, state: MutableSharedState): logger.debug(f"Record duplicated: {data}") - await state.set("completed_count", 1) - await state.data - await state.update({"completed_count": 2}) + state.set("completed_count", 1) + state_data = state.data + state.update({"completed_count": 2}) #return STATUS_DUPLICATE if random.random() < 0.9: return STATUS_COMPLETED return STATUS_DUPLICATE -async def mock_llm_call(city_name, num_records_per_city, fail_rate=0.5, sleep_time=0.1): - await asyncio.sleep(sleep_time) - if random.random() < fail_rate: - logger.debug(f" {city_name}: Failed!") - raise ValueError(f"Mock LLM failed to process city: {city_name}") - - logger.debug(f"{city_name}: Successfully processed!") - - result = [{"answer": f"{city_name}_{random.randint(1, 5)}"} for _ in range(num_records_per_city)] - return result @data_factory( storage=STORAGE_TYPE_LOCAL, max_concurrency=50, initial_state_values={}, on_record_complete=[handle_record_complete, handle_duplicate_record], @@ -72,7 +63,7 @@ async def get_city_info_wf(city_name, region_code): # output = await validation_llm.run(data=output.data) #return output.data - return await mock_llm_call(city_name, num_records_per_city=3, fail_rate=0.5, sleep_time=0.01) + return await mock_llm_call(city_name, num_records_per_city=3, fail_rate=0.01, sleep_time=1) # Execute with batch processing @@ -87,8 +78,8 @@ async def get_city_info_wf(city_name, region_code): results = get_city_info_wf.run( #data=[{"city_name": "Berlin"}, {"city_name": "Rome"}], #[{"city_name": "Berlin"}, {"city_name": "Rome"}], - city_name=["San Francisco", "New York", "Los Angeles"]*10, - region_code=["DE", "IT", "US"]*10, + city_name=["San Francisco", "New York", "Los Angeles"]*50, + region_code=["DE", "IT", "US"]*50, # city_name="Beijing", ### Overwrite the data key # num_records_per_city = 3 ) @@ -102,7 +93,7 @@ async def get_city_info_wf(city_name, region_code): # num_records_per_city = 3 ) elif user_case == "re_run": - results = get_city_info_wf.re_run( master_job_id="e342bb94-3784-45c7-beab-4e01cb059f1c") + results = get_city_info_wf.re_run( master_job_id="05668e16-6f47-4ccf-9f25-4ff7b7030bdb") #logger.info(f"Results: {results}") diff --git a/src/starfish/common/decorator.py b/src/starfish/common/decorator.py deleted file mode 100644 index 7022561..0000000 --- a/src/starfish/common/decorator.py +++ /dev/null @@ -1,13 +0,0 @@ -from typing import Callable -import asyncio -from starfish.data_factory.constants import STORAGE_TYPE_LOCAL - -def storage_action(): - """Decorator to handle storage-specific async operations""" - # to be replaced by the registery pattern - def decorator(func: Callable): - def wrapper(self, *args, **kwargs): - if self.storage == STORAGE_TYPE_LOCAL: - return asyncio.run(func(self, *args, **kwargs)) - return wrapper - return decorator \ No newline at end of file diff --git a/src/starfish/data_factory/config.py b/src/starfish/data_factory/config.py new file mode 100644 index 0000000..8fe1148 --- /dev/null +++ b/src/starfish/data_factory/config.py @@ -0,0 +1,2 @@ +PROGRESS_LOG_INTERVAL = 3 +TASK_RUNNER_TIMEOUT = 30 \ No newline at end of file diff --git a/src/starfish/data_factory/constants.py b/src/starfish/data_factory/constants.py index 51d666d..ed0285d 100644 --- a/src/starfish/data_factory/constants.py +++ b/src/starfish/data_factory/constants.py @@ -1,9 +1,6 @@ import os import sys from pathlib import Path -from starfish.common.logger import get_logger - -logger = get_logger(__name__) RECORD_STATUS = "status" @@ -28,9 +25,7 @@ STORAGE_TYPE_LOCAL = "local" STORAGE_TYPE_IN_MEMORY = "in_memory" -PROGRESS_LOG_INTERVAL = 3 -TASK_RUNNER_TIMEOUT = 30 # Define the function directly in constants to avoid circular imports def get_app_data_dir(): diff --git a/src/starfish/data_factory/event_loop.py b/src/starfish/data_factory/event_loop.py index e2ee40c..904c5c9 100644 --- a/src/starfish/data_factory/event_loop.py +++ b/src/starfish/data_factory/event_loop.py @@ -4,25 +4,6 @@ def run_in_event_loop(coroutine): - """Run a coroutine in the current event loop or create a new one if there isn't one.""" - # First, clean up any existing Rich live displays - # This prevents "Only one live display may be active at once" errors - # especially after keyboard interrupts in environments like Colab - try: - # Get all live displays from all console instances - from rich import get_console - - console = get_console() - if hasattr(console, "_live") and console._live is not None: - try: - console._live.stop() - console._live = None - except Exception: - # If stopping fails, just set to None - console._live = None - except Exception: - # If any error occurs during cleanup, just continue - pass try: # This call will raise an RuntimError if there is no event loop running. diff --git a/src/starfish/data_factory/factory.py b/src/starfish/data_factory/factory.py index de9d7de..813529a 100644 --- a/src/starfish/data_factory/factory.py +++ b/src/starfish/data_factory/factory.py @@ -7,20 +7,18 @@ from queue import Queue from typing import Any, Callable, Dict, List from inspect import signature, Parameter -from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn, TaskProgressColumn -from starfish.data_factory.event_loop import run_in_event_loop +from rich.progress import Progress, TextColumn from starfish.data_factory.job_manager import JobManager -from starfish.data_factory.constants import RECORD_STATUS, RUN_MODE_DRY_RUN, STATUS_MOJO_MAP, STATUS_TOTAL, TASK_RUNNER_TIMEOUT, LOCAL_STORAGE_URI, STATUS_COMPLETED, STATUS_DUPLICATE, STATUS_FILTERED, STATUS_FAILED, RUN_MODE, RUN_MODE_RE_RUN, RUN_MODE_NORMAL, PROGRESS_LOG_INTERVAL +from starfish.data_factory.constants import RECORD_STATUS, RUN_MODE_DRY_RUN,LOCAL_STORAGE_URI, STATUS_COMPLETED, STATUS_DUPLICATE, STATUS_FILTERED, STATUS_FAILED, RUN_MODE, RUN_MODE_RE_RUN +from starfish.data_factory.config import PROGRESS_LOG_INTERVAL, TASK_RUNNER_TIMEOUT from starfish.data_factory.storage.local.local_storage import LocalStorage from starfish.data_factory.storage.in_memory.in_memory_storage import InMemoryStorage from starfish.data_factory.state import MutableSharedState from starfish.data_factory.storage.models import ( - GenerationJob, GenerationMasterJob, Project, - Record, ) -from starfish.common.decorator import storage_action +from starfish.data_factory.utils.decorator import async_wrapper from starfish.common.logger import get_logger logger = get_logger(__name__) @@ -57,6 +55,7 @@ def __init__( self.factory_storage = None self.func = None self.master_job_id = None + self.err = None def __call__(self, func: Callable): # self.job_manager.add_job(func) @@ -64,7 +63,6 @@ def __call__(self, func: Callable): @wraps(func) def wrapper(*args, **kwargs): run_mode = self.job_config.get(RUN_MODE) - err = None try: # Check for master_job_id in kwargs and assign if present @@ -72,7 +70,7 @@ def wrapper(*args, **kwargs): self._setup_storage_and_job_manager() self._set_input_data_from_master_job() - self._init_progress_bar_update_job_config() + self._update_job_config() elif run_mode == RUN_MODE_DRY_RUN: # dry run mode self.input_data = self.input_converter(*args, **kwargs) @@ -95,7 +93,7 @@ def wrapper(*args, **kwargs): self._log_master_job_start() # Start progress bar before any operations # Process batches and keep progress bar alive - self._init_progress_bar_update_job_config() + self._update_job_config() self._update_master_job_status() self._process_batches() @@ -104,20 +102,19 @@ def wrapper(*args, **kwargs): if len(result) == 0: raise ValueError("No records generated") return result - except TypeError as e: - err = e - logger.error(f"TypeError occurred: {str(e)}") - raise - except ValueError as e: - err = e - logger.error(f"ValueError occurred: {str(e)}") - raise + except (TypeError, ValueError, KeyboardInterrupt) as e: + self.err = e + raise e finally: + self._complete_master_job() + self._close_storage() # Only execute finally block if not TypeError - if err is None and run_mode != RUN_MODE_DRY_RUN: - self._complete_master_job() - self._close_storage() - self.show_job_progress_status() + if self.err: + logger.error(f"Error occurred: {self.err}") + logger.error(f"Please rerun the job with master_job_id {self.master_job_id}") + else: + if run_mode != RUN_MODE_DRY_RUN: + self.show_job_progress_status() # Add run method to the wrapped function def run(*args, **kwargs): if 'master_job_id' in kwargs: @@ -183,19 +180,19 @@ def _process_batches(self) -> List[Any]: """Process batches with asyncio""" logger.info( f"[JOB PROGRESS] " - f"\033[1mJob started:\033[0m " + f"\033[1mJob Started:\033[0m " f"\033[36mMaster Job ID: {self.master_job_id}\033[0m | " f"\033[33mLogging progress every {PROGRESS_LOG_INTERVAL} seconds\033[0m" ) return self.job_manager.run_orchestration() - @storage_action() + @async_wrapper() async def _save_project(self): project = Project(project_id=self.project_id, name="Test Project", description="A test project for storage layer testing") await self.factory_storage.save_project(project) - @storage_action() + @async_wrapper() async def _save_request_config(self): logger.debug("\n2. Creating master job...") # First save the request config @@ -204,7 +201,7 @@ async def _save_request_config(self): logger.debug(f" - Saved request config to: {self.config_ref}") - @storage_action() + @async_wrapper() async def _set_input_data_from_master_job(self): master_job = await self.factory_storage.get_master_job(self.master_job_id) if master_job: @@ -232,7 +229,7 @@ async def _set_input_data_from_master_job(self): self.job_config[RUN_MODE] = RUN_MODE_RE_RUN - @storage_action() + @async_wrapper() async def _log_master_job_start(self): # Now create the master job master_job = GenerationMasterJob( @@ -249,7 +246,7 @@ async def _log_master_job_start(self): logger.debug(f" - Created master job: {master_job.name} ({self.master_job_id})") - @storage_action() + @async_wrapper() async def _update_master_job_status(self): now = datetime.datetime.now(datetime.timezone.utc) await self.factory_storage.update_master_job_status(self.master_job_id, "running", now) @@ -257,23 +254,28 @@ async def _update_master_job_status(self): - @storage_action() + @async_wrapper() async def _complete_master_job(self): # Complete the master job - logger.debug("\n7. Completing master job...") + logger.debug("\n7. Stopping master job...") now = datetime.datetime.now(datetime.timezone.utc) - #todo : how to collect all the execution job status? - summary = {STATUS_COMPLETED: self.job_manager.completed_count, + status = STATUS_FAILED if self.err else STATUS_COMPLETED + if self.err: + summary = {} + else: + summary = {STATUS_COMPLETED: self.job_manager.completed_count, STATUS_FILTERED: self.job_manager.filtered_count, STATUS_DUPLICATE: self.job_manager.duplicate_count, STATUS_FAILED: self.job_manager.failed_count} - await self.factory_storage.log_master_job_end(self.master_job_id, STATUS_COMPLETED, summary, now, now) - logger.info(f"Master job {self.master_job_id} as completed") + if self.factory_storage: + await self.factory_storage.log_master_job_end(self.master_job_id, status, summary, now, now) + logger.info(f"Master Job {self.master_job_id} has been ended") - @storage_action() + @async_wrapper() async def _close_storage(self): - await self.factory_storage.close() + if self.factory_storage: + await self.factory_storage.close() def storage_setup(self): @@ -284,9 +286,6 @@ def storage_setup(self): self.factory_storage = InMemoryStorage() asyncio.run(self.factory_storage.setup()) - def _init_progress_bar_update_job_config(self): - self._update_job_config() - self._init_progress_bar() def _update_job_config(self): @@ -346,22 +345,30 @@ def _init_progress_bar(self): def show_job_progress_status(self): target_count = self.job_config.get("target_count") - #logger.info(f"Job finished. Final Stats: Completed: {self.job_manager.completed_count}/{target_count} | Attempted: {self.job_manager.total_count} (Failed: {self.job_manager.failed_count}, Filtered: {self.job_manager.filtered_count}, Duplicate: {self.job_manager.duplicate_count})") - if self.job_config.get("show_progress"): - self.progress.start() + logger.info( + f"[JOB PROGRESS] " + f"\033[1mJob Finished:\033[0m " + f"\033[32mCompleted: {self.job_manager.completed_count}/{target_count}\033[0m | " + f"\033[33mAttempted: {self.job_manager.total_count}\033[0m " + f"(Failed: {self.job_manager.failed_count}, " + f"Filtered: {self.job_manager.filtered_count}, " + f"Duplicate: {self.job_manager.duplicate_count})" + ) + # if self.job_config.get("show_progress"): + # self.progress.start() - for counter_type, task_id in self.progress_tasks.items(): - count = getattr(self.job_manager, f"{counter_type}_count") - emoji = STATUS_MOJO_MAP[counter_type] - percentage = int((count / target_count) * 100) if target_count > 0 else 0 - if counter_type != STATUS_COMPLETED: - target_count = self.job_manager.total_count - self.progress.update( - task_id, - completed=count, - status=f"{emoji} {count}/{target_count} ({percentage}%)" - ) - self.progress.stop() + # for counter_type, task_id in self.progress_tasks.items(): + # count = getattr(self.job_manager, f"{counter_type}_count") + # emoji = STATUS_MOJO_MAP[counter_type] + # percentage = int((count / target_count) * 100) if target_count > 0 else 0 + # if counter_type != STATUS_COMPLETED: + # target_count = self.job_manager.total_count + # self.progress.update( + # task_id, + # completed=count, + # status=f"{emoji} {count}/{target_count} ({percentage}%)" + # ) + # self.progress.stop() def default_input_converter(data : List[Dict[str, Any]]=[], **kwargs) -> Queue[Dict[str, Any]]: diff --git a/src/starfish/data_factory/job_manager.py b/src/starfish/data_factory/job_manager.py index a9f6129..17ada9c 100644 --- a/src/starfish/data_factory/job_manager.py +++ b/src/starfish/data_factory/job_manager.py @@ -5,8 +5,9 @@ from typing import Any, Dict, List import uuid from queue import Queue -from starfish.data_factory.errors import DuplicateRecordError, FilterRecordError, RecordError -from starfish.data_factory.constants import RECORD_STATUS, STATUS_MOJO_MAP, RUN_MODE, RUN_MODE_RE_RUN, RUN_MODE_DRY_RUN, PROGRESS_LOG_INTERVAL +from starfish.data_factory.utils.errors import DuplicateRecordError, FilterRecordError, RecordError +from starfish.data_factory.constants import RECORD_STATUS, RUN_MODE, RUN_MODE_RE_RUN, RUN_MODE_DRY_RUN +from starfish.data_factory.config import PROGRESS_LOG_INTERVAL from starfish.data_factory.event_loop import run_in_event_loop from starfish.data_factory.task_runner import TaskRunner from starfish.data_factory.constants import STATUS_COMPLETED, STATUS_DUPLICATE, STATUS_FILTERED, STATUS_FAILED @@ -18,9 +19,7 @@ ) from starfish.data_factory.storage.base import Storage from starfish.common.logger import get_logger - logger = get_logger(__name__) -# from starfish.common.logger_new import logger class JobManager: def __init__(self, job_config: Dict[str, Any], storage: Storage): @@ -125,7 +124,8 @@ async def _async_run_orchestration_re_run(self): records_metadata = await self.storage.list_record_metadata(self.job_config["master_job_id"], task.job_id) for record in records_metadata: record_data = await self.storage.get_record_data(record.output_ref) - self.job_output.put(record_data) + output_tmp = {RECORD_STATUS: STATUS_COMPLETED, "output": record_data} + self.job_output.put(output_tmp) self.total_count += 1 self.completed_count += 1 # run the rest of the tasks @@ -139,10 +139,10 @@ async def _progress_ticker(self): """Log a message every 5 seconds""" while not self.is_job_to_stop(): logger.info( - f"\033[1mProgress:\033[0m " + f"[JOB PROGRESS] " f"\033[32mCompleted: {self.completed_count}/{self.target_count}\033[0m | " - f"\033[33mRunning: {self.semaphore._value}\033[0m | " - f"\033[36mAttempted: {self.total_count}\033[0m\n" + f"\033[33mRunning: {self.job_config.get('max_concurrency') - self.semaphore._value}\033[0m | " + f"\033[36mAttempted: {self.total_count}\033[0m" f" (\033[32mCompleted: {self.completed_count}\033[0m, " f"\033[31mFailed: {self.failed_count}\033[0m, " f"\033[35mFiltered: {self.filtered_count}\033[0m, " @@ -154,6 +154,8 @@ async def _async_run_orchestration(self): """Main orchestration loop for the job""" # Start the ticker task _progress_ticker_task = asyncio.create_task(self._progress_ticker()) + # Store all running tasks + running_tasks = set() try: while not self.is_job_to_stop(): @@ -163,7 +165,9 @@ async def _async_run_orchestration(self): await self.semaphore.acquire() logger.debug(f"Semaphore acquired, waiting for task to complete") input_data = self.job_input_queue.get() - self._create_single_task(input_data) + task = self._create_single_task(input_data) + running_tasks.add(task) + task.add_done_callback(running_tasks.discard) else: await asyncio.sleep(1) finally: @@ -173,11 +177,19 @@ async def _async_run_orchestration(self): await _progress_ticker_task except asyncio.CancelledError: pass - - def _create_single_task(self, input_data): + + # Cancel all running tasks + # todo whether openai call will close + for task in running_tasks: + task.cancel() + # Wait for all tasks to be cancelled + await asyncio.gather(*running_tasks, return_exceptions=True) + + def _create_single_task(self, input_data) -> asyncio.Task: task = asyncio.create_task(self._run_single_task(input_data)) asyncio.create_task(self._handle_task_completion(task)) logger.debug(f"Task created, waiting for task to complete") + return task async def _run_single_task(self, input_data) -> List[Dict[str, Any]]: """Run a single task with error handling and storage""" @@ -190,7 +202,7 @@ async def _run_single_task(self, input_data) -> List[Dict[str, Any]]: hooks_output = [] # class based hooks. use semaphore to ensure thread safe for hook in self.job_config.get("on_record_complete", []): - hooks_output.append(await hook(output, self.state)) + hooks_output.append(hook(output, self.state)) if hooks_output.count(STATUS_DUPLICATE) > 0: # duplicate filtered need retry task_status = STATUS_DUPLICATE @@ -200,7 +212,7 @@ async def _run_single_task(self, input_data) -> List[Dict[str, Any]]: except Exception as e: logger.error(f"Error running task: {e}") for hook in self.job_config.get("on_record_error", []): - await hook(str(e), self.state) + hook(str(e), self.state) task_status = STATUS_FAILED finally: # if task is not completed, put the input data back to the job input queue @@ -238,10 +250,3 @@ def update_job_config(self, job_config: Dict[str, Any]): """Update job config by merging new values with existing config""" self.job_config = {**self.job_config, **job_config} - # async def _update_progress(self, counter_type: str, emoji: str): - # """Update counters without showing live progress""" - # if not self.show_progress: - # return - # # Update the internal counters - # async with self.progress_lock: - # self._counters[counter_type] = getattr(self, f"{counter_type}_count") diff --git a/src/starfish/data_factory/state.py b/src/starfish/data_factory/state.py index 70901c3..fcc64e0 100644 --- a/src/starfish/data_factory/state.py +++ b/src/starfish/data_factory/state.py @@ -1,58 +1,61 @@ +import threading from pydantic import BaseModel from typing import Dict, Any, Optional -import asyncio - +#from starfish.data_factory.utils.decorator import async_to_sync_event_loop class MutableSharedState(BaseModel): _data: Dict[str, Any] = {} #If you want each MutableSharedState instance to have its own independent # synchronization, you should move the lock initialization into __init__. def __init__(self, initial_data: Optional[Dict[str, Any]] = None): super().__init__() - self._lock = asyncio.Lock() # Instance-level lock + self._lock = threading.Lock() # Instance-level lock if initial_data is not None: self._data = initial_data.copy() # Use data when you want to emphasize you're accessing the current state @property - async def data(self) -> Dict[str, Any]: - return await self.to_dict() + def data(self) -> Dict[str, Any]: + return self.to_dict() @data.setter - async def data(self, value: Dict[str, Any]) -> None: - async with self._lock: + def data(self, value: Dict[str, Any]) -> None: + with self._lock: self._data = value.copy() - async def get(self, key: str) -> Any: - async with self._lock: + + def get(self, key: str) -> Any: + with self._lock: return self._data.get(key) - - async def set(self, key: str, value: Any) -> None: - async with self._lock: + + def set(self, key: str, value: Any) -> None: + with self._lock: self._data[key] = value - async def update(self, updates: Dict[str, Any]) -> None: - async with self._lock: + + def update(self, updates: Dict[str, Any]) -> None: + with self._lock: self._data.update(updates) # Use to_dict when you want to emphasize you're converting/serializing the state - async def to_dict(self) -> Dict[str, Any]: - async with self._lock: + + def to_dict(self) -> Dict[str, Any]: + with self._lock: return self._data.copy() # # Set the entire state -# await state.data = {"key": "value"} +# state.data = {"key": "value"} # # Get the entire state -# current_state = await state.data +# current_state = state.data # # Set a value -# await state.set("key", "value") +# state.set("key", "value") # # Get a value -# value = await state.get("key") +# value = state.get("key") # # Update multiple values -# await state.update({"key1": "value1", "key2": "value2"}) +# state.update({"key1": "value1", "key2": "value2"}) # # Get a copy of the entire state -# state_dict = await state.to_dict() +# state_dict = state.to_dict() diff --git a/src/starfish/data_factory/storage/local/metadata_handler.py b/src/starfish/data_factory/storage/local/metadata_handler.py index df05560..47f29a5 100644 --- a/src/starfish/data_factory/storage/local/metadata_handler.py +++ b/src/starfish/data_factory/storage/local/metadata_handler.py @@ -100,10 +100,9 @@ async def _execute_sql(self, sql: str, params: tuple = ()): async with self._write_lock: conn = await self.connect() try: - # Use context manager for cursor and transaction - async with conn.execute("BEGIN IMMEDIATE;") as cursor: - await cursor.execute(sql, params) - await conn.commit() + # Remove the explicit BEGIN IMMEDIATE since it's handled by the connection + async with conn.execute(sql, params) as _: + await conn.commit() logger.debug(f"Executed write SQL: {sql[:50]}... Params: {params}") except Exception as e: try: @@ -114,23 +113,13 @@ async def _execute_sql(self, sql: str, params: tuple = ()): raise e async def _execute_batch_sql(self, statements: List[Tuple[str, tuple]]): - """Execute multiple SQL statements in a single transaction. - - Args: - statements: List of (sql, params) tuples to execute in one transaction - """ - # Use the write lock to ensure only one write transaction at a time + """Execute multiple SQL statements in a single transaction.""" async with self._write_lock: conn = await self.connect() try: - # Begin a single transaction for all statements - await conn.execute("BEGIN IMMEDIATE") - - # Execute all statements + # Remove explicit BEGIN IMMEDIATE for sql, params in statements: await conn.execute(sql, params) - - # Commit the transaction await conn.commit() logger.debug(f"Executed batch SQL: {len(statements)} statements") except Exception as e: diff --git a/src/starfish/data_factory/task_runner.py b/src/starfish/data_factory/task_runner.py index c7e6da6..f956f2f 100644 --- a/src/starfish/data_factory/task_runner.py +++ b/src/starfish/data_factory/task_runner.py @@ -1,7 +1,7 @@ import asyncio import time from typing import Any, Callable, Dict, List -from starfish.data_factory.constants import TASK_RUNNER_TIMEOUT +from starfish.data_factory.config import TASK_RUNNER_TIMEOUT from starfish.common.logger import get_logger logger = get_logger(__name__) #from starfish.common.logger_new import logger @@ -28,6 +28,6 @@ async def run_task(self, func: Callable, input_data: Dict) -> List[Any]: except Exception as e: retries += 1 if retries > self.max_retries: - logger.error(f"Task execution failed after {self.max_retries} retries") + #logger.error(f"Task execution failed after {self.max_retries} retries") raise e await asyncio.sleep(2**retries) # exponential backoff diff --git a/src/starfish/data_factory/utils/decorator.py b/src/starfish/data_factory/utils/decorator.py new file mode 100644 index 0000000..4cb3282 --- /dev/null +++ b/src/starfish/data_factory/utils/decorator.py @@ -0,0 +1,22 @@ +from typing import Callable +import asyncio +#from starfish.data_factory.constants import STORAGE_TYPE_LOCAL +from starfish.data_factory.event_loop import run_in_event_loop + +def async_wrapper(): + """Decorator to handle storage-specific async operations""" + # to be replaced by the registery pattern + def decorator(func: Callable): + def wrapper(self, *args, **kwargs): + #if self.storage == STORAGE_TYPE_LOCAL: + return asyncio.run(func(self, *args, **kwargs)) + return wrapper + return decorator + +def async_to_sync_event_loop(): + """Decorator to handle storage-specific async operations""" + def decorator(func: Callable): + def wrapper(self, *args, **kwargs): + return run_in_event_loop(func(self, *args, **kwargs)) + return wrapper + return decorator \ No newline at end of file diff --git a/src/starfish/data_factory/enums.py b/src/starfish/data_factory/utils/enums.py similarity index 100% rename from src/starfish/data_factory/enums.py rename to src/starfish/data_factory/utils/enums.py diff --git a/src/starfish/data_factory/errors.py b/src/starfish/data_factory/utils/errors.py similarity index 100% rename from src/starfish/data_factory/errors.py rename to src/starfish/data_factory/utils/errors.py diff --git a/src/starfish/data_factory/utils/mock.py b/src/starfish/data_factory/utils/mock.py new file mode 100644 index 0000000..4a7749a --- /dev/null +++ b/src/starfish/data_factory/utils/mock.py @@ -0,0 +1,16 @@ +import random +import asyncio +from starfish.common.logger import get_logger +logger = get_logger(__name__) + +async def mock_llm_call(city_name, num_records_per_city, fail_rate=0.01, sleep_time=0.5): + await asyncio.sleep(sleep_time) + + if random.random() < fail_rate: + logger.debug(f" {city_name}: Failed!") + raise ValueError(f"Mock LLM failed to process city: {city_name}") + + logger.debug(f"{city_name}: Successfully processed!") + + result = [{"answer": f"{city_name}_{random.randint(1, 5)}"} for _ in range(num_records_per_city)] + return result \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..21f3a6b --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,12 @@ +import os +import sys + +# Add workspace folders to PYTHONPATH +workspace_folders = [ + os.path.join(os.path.dirname(__file__), '..', 'src'), + os.path.dirname(os.path.dirname(__file__)) +] + +for folder in workspace_folders: + if folder not in sys.path: + sys.path.insert(0, folder) diff --git a/tests/src/data_factory/storage/README.md b/tests/data_factory/storage/README.md similarity index 100% rename from tests/src/data_factory/storage/README.md rename to tests/data_factory/storage/README.md diff --git a/tests/src/data_factory/storage/__init__.py b/tests/data_factory/storage/__init__.py similarity index 100% rename from tests/src/data_factory/storage/__init__.py rename to tests/data_factory/storage/__init__.py diff --git a/tests/src/data_factory/storage/local/__init__.py b/tests/data_factory/storage/local/__init__.py similarity index 100% rename from tests/src/data_factory/storage/local/__init__.py rename to tests/data_factory/storage/local/__init__.py diff --git a/tests/src/data_factory/storage/local/test_basic_storage.py b/tests/data_factory/storage/local/test_basic_storage.py similarity index 100% rename from tests/src/data_factory/storage/local/test_basic_storage.py rename to tests/data_factory/storage/local/test_basic_storage.py diff --git a/tests/src/data_factory/storage/local/test_local_storage.py b/tests/data_factory/storage/local/test_local_storage.py similarity index 100% rename from tests/src/data_factory/storage/local/test_local_storage.py rename to tests/data_factory/storage/local/test_local_storage.py diff --git a/tests/src/data_factory/storage/local/test_performance.py b/tests/data_factory/storage/local/test_performance.py similarity index 100% rename from tests/src/data_factory/storage/local/test_performance.py rename to tests/data_factory/storage/local/test_performance.py diff --git a/tests/src/data_factory/storage/test_storage_main.py b/tests/data_factory/storage/test_storage_main.py similarity index 100% rename from tests/src/data_factory/storage/test_storage_main.py rename to tests/data_factory/storage/test_storage_main.py diff --git a/tests/src/data_factory/test_data_factory.py b/tests/data_factory/test_data_factory.py similarity index 86% rename from tests/src/data_factory/test_data_factory.py rename to tests/data_factory/test_data_factory.py index 74bd1e5..701c883 100644 --- a/tests/src/data_factory/test_data_factory.py +++ b/tests/data_factory/test_data_factory.py @@ -9,23 +9,24 @@ from starfish.common.env_loader import load_env_file from starfish import data_factory from starfish.data_factory.state import MutableSharedState +from starfish.data_factory.utils.mock import mock_llm_call load_env_file() ### Mock LLM call -async def mock_llm_call(city_name, num_records_per_city, fail_rate=0.05, sleep_time=0.01): - # Simulate a slight delay (optional, feels more async-realistic) - await asyncio.sleep(sleep_time) +# async def mock_llm_call(city_name, num_records_per_city, fail_rate=0.05, sleep_time=0.01): +# # Simulate a slight delay (optional, feels more async-realistic) +# await asyncio.sleep(sleep_time) - # 5% chance of failure - if random.random() < fail_rate: - print(f" {city_name}: Failed!") ## For debugging - raise ValueError(f"Mock LLM failed to process city: {city_name}") +# # 5% chance of failure +# if random.random() < fail_rate: +# print(f" {city_name}: Failed!") ## For debugging +# raise ValueError(f"Mock LLM failed to process city: {city_name}") - print(f"{city_name}: Successfully processed!") ## For debugging +# print(f"{city_name}: Successfully processed!") ## For debugging - result = [f"{city_name}_{random.randint(1, 5)}" for _ in range(num_records_per_city)] - return result +# result = [f"{city_name}_{random.randint(1, 5)}" for _ in range(num_records_per_city)] +# return result @pytest.mark.asyncio @@ -101,7 +102,7 @@ async def test_func(city_name, num_records_per_city, fail_rate=0.1, sleep_time=0 # Verify all results contain the override value for item in result: - assert ('override_city_name' in item) + assert ('override_city_name' in item['answer']) @pytest.mark.asyncio async def test_case_5(): @@ -120,7 +121,7 @@ async def test_func(city_name, num_records_per_city, fail_rate=0.1, sleep_time=0 ) # Verify each result contains the corresponding override value - assert any('1. override_city_name' in item or '2. override_city_name' in item for item in result) + assert any('1. override_city_name' in item["answer"] or '2. override_city_name' in item["answer"] for item in result) @pytest.mark.asyncio async def test_case_6(): @@ -163,8 +164,8 @@ async def test_case_8(): - Hook: test_hook modifies state - Expected: State variable should be modified by hook """ - async def test_hook(data, state): - await state.update({"variable": f'changed_state - {data}'}) + def test_hook(data, state): + state.update({"variable": f'changed_state - {data}'}) return STATUS_COMPLETED @@ -176,5 +177,5 @@ async def test1(city_name, num_records_per_city, fail_rate=0.1, sleep_time=0.05) {'city_name': '1. New York'}, {'city_name': '2. Los Angeles'}, ], num_records_per_city=1) - state_value = await test1.state.get('variable') + state_value = test1.state.get('variable') assert state_value.startswith('changed_state') \ No newline at end of file diff --git a/tests/src/llm/prompt/test_prompt.py b/tests/llm/prompt/test_prompt.py similarity index 100% rename from tests/src/llm/prompt/test_prompt.py rename to tests/llm/prompt/test_prompt.py From cba3b91c3a32e8589179f804e3b97c4fcd3f1522 Mon Sep 17 00:00:00 2001 From: "johnwayne.jiang" Date: Tue, 15 Apr 2025 15:41:56 -0700 Subject: [PATCH 2/7] cicd : add pre-commit and cicd --- Makefile | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 5d5891f..994f1eb 100644 --- a/Makefile +++ b/Makefile @@ -5,6 +5,11 @@ lint: poetry run ruff format tests starfish examples test: - poetry run pytest tests/ + poetry run pytest tests/ + +install: + @echo "Installing dependencies..." + poetry install + poetry run pre-commit install From c70d550a234f78203a062402fb722038b7ce58e2 Mon Sep 17 00:00:00 2001 From: "johnwayne.jiang" Date: Tue, 15 Apr 2025 15:45:53 -0700 Subject: [PATCH 3/7] cicd : add pre-commit and cicd --- .github/workflows/lint-and-test.yaml | 55 ++++++++++++++++++++++++++++ .github/workflows/pre-commit.yml | 14 +++++++ .pre-commit-config.yaml | 25 +++++++++++++ 3 files changed, 94 insertions(+) create mode 100644 .github/workflows/lint-and-test.yaml create mode 100644 .github/workflows/pre-commit.yml create mode 100644 .pre-commit-config.yaml diff --git a/.github/workflows/lint-and-test.yaml b/.github/workflows/lint-and-test.yaml new file mode 100644 index 0000000..395cabb --- /dev/null +++ b/.github/workflows/lint-and-test.yaml @@ -0,0 +1,55 @@ +name: Curator testing workflow + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + test-integration: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.11' + + - name: Load cached Poetry installation + uses: actions/cache@v3 + with: + path: ~/.local + key: poetry-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }} + + - name: Load cached venv + uses: actions/cache@v3 + with: + path: .venv + key: venv-${{ runner.os }}-python-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }} + + - name: Set Locale + run: | + sudo locale-gen "en_US.UTF-8" + export LC_ALL=en_US.UTF-8 + export LANG=en_US.UTF-8 + export TELEMETRY_ENABLED=false + + - name: Install dependencies + run: | + pip install poetry + poetry install --with dev --extras "vllm code_execution" + + - name: Run ruff + run: | + poetry run ruff check . --output-format=github + poetry run ruff format . --check + + - name: Run tests with coverage + run: | + poetry run pytest --cov='bespokelabs' --cov-report=html --cov-fail-under=80 tests/ diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml new file mode 100644 index 0000000..00ad2e6 --- /dev/null +++ b/.github/workflows/pre-commit.yml @@ -0,0 +1,14 @@ +name: Pre-Commit + +on: [push, pull_request] + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: "3.x" + - run: pip install pre-commit + - run: pre-commit run --all-files \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..6f580ac --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,25 @@ +repos: + # - repo: local + # hooks: + # - id: pytest + # name: Run pytest + # entry: poetry run pytest tests/ + # language: system + # types: [python] + # pass_filenames: false + # always_run: true + + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.8.6 + hooks: + # Run the linter. + - id: ruff + args: [ --fix ] + types: [python] + # Run the formatter. + - id: ruff-format + args: [ --fix ] + #run even when no Python files are staged + #always_run: true + types: [python] From 89267edb80d825dd93a84f0afa4f5bef8d5b56d7 Mon Sep 17 00:00:00 2001 From: "johnwayne.jiang" Date: Tue, 15 Apr 2025 15:55:18 -0700 Subject: [PATCH 4/7] cicd : add pre-commit and cicd --- .github/workflows/lint-and-test.yaml | 6 +- .github/workflows/pre-commit.yml | 24 +- poetry.lock | 427 ++++++++++++++++----------- pyproject.toml | 2 +- 4 files changed, 274 insertions(+), 185 deletions(-) diff --git a/.github/workflows/lint-and-test.yaml b/.github/workflows/lint-and-test.yaml index 395cabb..db0e764 100644 --- a/.github/workflows/lint-and-test.yaml +++ b/.github/workflows/lint-and-test.yaml @@ -1,4 +1,4 @@ -name: Curator testing workflow +name: Starfish testing workflow on: push: @@ -43,7 +43,7 @@ jobs: - name: Install dependencies run: | pip install poetry - poetry install --with dev --extras "vllm code_execution" + poetry install --with dev" - name: Run ruff run: | @@ -52,4 +52,4 @@ jobs: - name: Run tests with coverage run: | - poetry run pytest --cov='bespokelabs' --cov-report=html --cov-fail-under=80 tests/ + poetry run pytest --cov='src' --cov-report=html --cov-fail-under=20 tests/ diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 00ad2e6..2995a09 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -1,14 +1,14 @@ -name: Pre-Commit +# name: Pre-Commit -on: [push, pull_request] +# on: [push, pull_request] -jobs: - pre-commit: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: "3.x" - - run: pip install pre-commit - - run: pre-commit run --all-files \ No newline at end of file +# jobs: +# pre-commit: +# runs-on: ubuntu-latest +# steps: +# - uses: actions/checkout@v3 +# - uses: actions/setup-python@v4 +# with: +# python-version: "3.x" +# - run: pip install pre-commit +# - run: pre-commit run --all-files \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 5a542b9..dee4102 100644 --- a/poetry.lock +++ b/poetry.lock @@ -7,7 +7,6 @@ description = "File support for asyncio." optional = false python-versions = ">=3.8" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "aiofiles-24.1.0-py3-none-any.whl", hash = "sha256:b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5"}, {file = "aiofiles-24.1.0.tar.gz", hash = "sha256:22a075c9e5a3810f0c2e48f3008c94d68c65d763b9b03857924c99e57355166c"}, @@ -20,7 +19,6 @@ description = "Happy Eyeballs for asyncio" optional = false python-versions = ">=3.9" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "aiohappyeyeballs-2.6.1-py3-none-any.whl", hash = "sha256:f349ba8f4b75cb25c99c5c2d84e997e485204d2902a9597802b0371f09331fb8"}, {file = "aiohappyeyeballs-2.6.1.tar.gz", hash = "sha256:c3f9d0113123803ccadfdf3f0faa505bc78e6a72d1cc4806cbd719826e943558"}, @@ -33,7 +31,6 @@ description = "Async http client/server framework (asyncio)" optional = false python-versions = ">=3.9" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "aiohttp-3.11.16-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:fb46bb0f24813e6cede6cc07b1961d4b04f331f7112a23b5e21f567da4ee50aa"}, {file = "aiohttp-3.11.16-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:54eb3aead72a5c19fad07219acd882c1643a1027fbcdefac9b502c267242f955"}, @@ -121,6 +118,7 @@ files = [ [package.dependencies] aiohappyeyeballs = ">=2.3.0" aiosignal = ">=1.1.2" +async-timeout = {version = ">=4.0,<6.0", markers = "python_version < \"3.11\""} attrs = ">=17.3.0" frozenlist = ">=1.1.1" multidict = ">=4.5,<7.0" @@ -137,7 +135,6 @@ description = "aiosignal: a list of registered asynchronous callbacks" optional = false python-versions = ">=3.9" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "aiosignal-1.3.2-py2.py3-none-any.whl", hash = "sha256:45cde58e409a301715980c2b01d0c28bdde3770d8290b5eb2173759d9acb31a5"}, {file = "aiosignal-1.3.2.tar.gz", hash = "sha256:a8c255c66fafb1e499c9351d0bf32ff2d8a0321595ebac3b93713656d2436f54"}, @@ -153,7 +150,6 @@ description = "asyncio bridge to the standard sqlite3 module" optional = false python-versions = ">=3.9" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0"}, {file = "aiosqlite-0.21.0.tar.gz", hash = "sha256:131bb8056daa3bc875608c631c678cda73922a2d4ba8aec373b19f18c17e7aa3"}, @@ -173,7 +169,6 @@ description = "Reusable constraint types to use with typing.Annotated" optional = false python-versions = ">=3.8" groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"}, {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, @@ -186,13 +181,13 @@ description = "High level compatibility layer for multiple asynchronous event lo optional = false python-versions = ">=3.9" groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "anyio-4.9.0-py3-none-any.whl", hash = "sha256:9f76d541cad6e36af7beb62e978876f3b41e3e04f2c1fbf0884604c0a9c4d93c"}, {file = "anyio-4.9.0.tar.gz", hash = "sha256:673c0c244e15788651a4ff38710fea9675823028a6f08a5eda409e0c9840a028"}, ] [package.dependencies] +exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""} idna = ">=2.8" sniffio = ">=1.1" typing_extensions = {version = ">=4.5", markers = "python_version < \"3.13\""} @@ -209,7 +204,7 @@ description = "Disable App Nap on macOS >= 10.9" optional = false python-versions = ">=3.6" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\" and platform_system == \"Darwin\"" +markers = "platform_system == \"Darwin\"" files = [ {file = "appnope-0.1.4-py2.py3-none-any.whl", hash = "sha256:502575ee11cd7a28c0205f379b525beefebab9d161b7c964670864014ed7213c"}, {file = "appnope-0.1.4.tar.gz", hash = "sha256:1de3860566df9caf38f01f86f65e0e13e379af54f9e4bee1e66b48f2efffd1ee"}, @@ -222,7 +217,6 @@ description = "Annotate AST trees with source code positions" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "asttokens-3.0.0-py3-none-any.whl", hash = "sha256:e3078351a059199dd5138cb1c706e6430c05eff2ff136af5eb4790f9d28932e2"}, {file = "asttokens-3.0.0.tar.gz", hash = "sha256:0dcd8baa8d62b0c1d118b399b2ddba3c4aff271d0d7a9e0d4c1681c79035bbc7"}, @@ -232,6 +226,19 @@ files = [ astroid = ["astroid (>=2,<4)"] test = ["astroid (>=2,<4)", "pytest", "pytest-cov", "pytest-xdist"] +[[package]] +name = "async-timeout" +version = "5.0.1" +description = "Timeout context manager for asyncio programs" +optional = false +python-versions = ">=3.8" +groups = ["main"] +markers = "python_version <= \"3.10\" or platform_python_implementation == \"PyPy\" and python_version < \"3.11\"" +files = [ + {file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"}, + {file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"}, +] + [[package]] name = "attrs" version = "25.3.0" @@ -239,7 +246,6 @@ description = "Classes Without Boilerplate" optional = false python-versions = ">=3.8" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3"}, {file = "attrs-25.3.0.tar.gz", hash = "sha256:75d7cefc7fb576747b2c81b4442d4d4a1ce0900973527c011d1030fd3bf4af1b"}, @@ -260,7 +266,7 @@ description = "Backport of CPython tarfile module" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\" and python_version < \"3.12\"" +markers = "python_version < \"3.12\"" files = [ {file = "backports.tarfile-1.2.0-py3-none-any.whl", hash = "sha256:77e284d754527b01fb1e6fa8a1afe577858ebe4e9dad8919e34c862cb399bc34"}, {file = "backports_tarfile-1.2.0.tar.gz", hash = "sha256:d75e02c268746e1b8144c278978b6e98e85de6ad16f8e4b0844a154557eca991"}, @@ -277,7 +283,6 @@ description = "Extensible memoizing collections and decorators" optional = false python-versions = ">=3.7" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "cachetools-5.5.2-py3-none-any.whl", hash = "sha256:d26a22bcc62eb95c3beabd9f1ee5e820d3d2704fe2967cbe350e20c8ffcd3f0a"}, {file = "cachetools-5.5.2.tar.gz", hash = "sha256:1a661caa9175d26759571b2e19580f9d6393969e5dfca11fdb1f947a23e640d4"}, @@ -290,7 +295,6 @@ description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe"}, {file = "certifi-2025.1.31.tar.gz", hash = "sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651"}, @@ -303,7 +307,7 @@ description = "Foreign Function Interface for Python calling C code." optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "sys_platform == \"linux\" or implementation_name == \"pypy\" or platform_python_implementation == \"PyPy\"" +markers = "platform_python_implementation == \"PyPy\" or sys_platform == \"linux\" or implementation_name == \"pypy\"" files = [ {file = "cffi-1.17.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:df8b1c11f177bc2313ec4b2d46baec87a5f3e71fc8b45dab2ee7cae86d9aba14"}, {file = "cffi-1.17.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f2cdc858323644ab277e9bb925ad72ae0e67f69e804f4898c070998d50b1a67"}, @@ -384,7 +388,6 @@ description = "Validate configuration and produce human readable error messages. optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"}, {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"}, @@ -397,7 +400,6 @@ description = "The Real First Universal Charset Detector. Open, modern and activ optional = false python-versions = ">=3.7" groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "charset_normalizer-3.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:91b36a978b5ae0ee86c394f5a54d6ef44db1de0815eb43de826d41d21e4af3de"}, {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7461baadb4dc00fd9e0acbe254e3d7d2112e7f92ced2adc96e54ef6501c5f176"}, @@ -499,8 +501,7 @@ version = "8.1.8" description = "Composable command line interface toolkit" optional = false python-versions = ">=3.7" -groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" +groups = ["main", "dev"] files = [ {file = "click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2"}, {file = "click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a"}, @@ -516,11 +517,11 @@ description = "Cross-platform colored terminal text." optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" groups = ["main", "dev"] +markers = "platform_system == \"Windows\" or sys_platform == \"win32\"" files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] -markers = {main = "platform_python_implementation == \"PyPy\" and (platform_system == \"Windows\" or sys_platform == \"win32\")", dev = "platform_python_implementation == \"PyPy\" and sys_platform == \"win32\""} [[package]] name = "comm" @@ -529,7 +530,6 @@ description = "Jupyter Python Comm implementation, for usage in ipykernel, xeus- optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "comm-0.2.2-py3-none-any.whl", hash = "sha256:e6fb86cb70ff661ee8c9c14e7d36d6de3b4066f1441be4063df9c5009f0a64d3"}, {file = "comm-0.2.2.tar.gz", hash = "sha256:3fd7a84065306e07bea1773df6eb8282de51ba82f77c72f9c85716ab11fe980e"}, @@ -548,7 +548,6 @@ description = "Code coverage measurement for Python" optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "coverage-7.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2931f66991175369859b5fd58529cd4b73582461877ecfd859b6549869287ffe"}, {file = "coverage-7.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:52a523153c568d2c0ef8826f6cc23031dc86cffb8c6aeab92c4ff776e7951b28"}, @@ -615,9 +614,63 @@ files = [ {file = "coverage-7.8.0.tar.gz", hash = "sha256:7a3d62b3b03b4b6fd41a085f3574874cf946cb4604d2b4d3e8dca8cd570ca501"}, ] +[package.dependencies] +tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""} + [package.extras] toml = ["tomli ; python_full_version <= \"3.11.0a6\""] +[[package]] +name = "cryptography" +version = "43.0.3" +description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." +optional = false +python-versions = ">=3.7" +groups = ["dev"] +markers = "(python_version < \"3.10\" or platform_python_implementation == \"PyPy\") and python_full_version < \"3.12.4\" and sys_platform == \"linux\"" +files = [ + {file = "cryptography-43.0.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:bf7a1932ac4176486eab36a19ed4c0492da5d97123f1406cf15e41b05e787d2e"}, + {file = "cryptography-43.0.3-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63efa177ff54aec6e1c0aefaa1a241232dcd37413835a9b674b6e3f0ae2bfd3e"}, + {file = "cryptography-43.0.3-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e1ce50266f4f70bf41a2c6dc4358afadae90e2a1e5342d3c08883df1675374f"}, + {file = "cryptography-43.0.3-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:443c4a81bb10daed9a8f334365fe52542771f25aedaf889fd323a853ce7377d6"}, + {file = "cryptography-43.0.3-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:74f57f24754fe349223792466a709f8e0c093205ff0dca557af51072ff47ab18"}, + {file = "cryptography-43.0.3-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9762ea51a8fc2a88b70cf2995e5675b38d93bf36bd67d91721c309df184f49bd"}, + {file = "cryptography-43.0.3-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:81ef806b1fef6b06dcebad789f988d3b37ccaee225695cf3e07648eee0fc6b73"}, + {file = "cryptography-43.0.3-cp37-abi3-win32.whl", hash = "sha256:cbeb489927bd7af4aa98d4b261af9a5bc025bd87f0e3547e11584be9e9427be2"}, + {file = "cryptography-43.0.3-cp37-abi3-win_amd64.whl", hash = "sha256:f46304d6f0c6ab8e52770addfa2fc41e6629495548862279641972b6215451cd"}, + {file = "cryptography-43.0.3-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:8ac43ae87929a5982f5948ceda07001ee5e83227fd69cf55b109144938d96984"}, + {file = "cryptography-43.0.3-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:846da004a5804145a5f441b8530b4bf35afbf7da70f82409f151695b127213d5"}, + {file = "cryptography-43.0.3-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f996e7268af62598f2fc1204afa98a3b5712313a55c4c9d434aef49cadc91d4"}, + {file = "cryptography-43.0.3-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:f7b178f11ed3664fd0e995a47ed2b5ff0a12d893e41dd0494f406d1cf555cab7"}, + {file = "cryptography-43.0.3-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:c2e6fc39c4ab499049df3bdf567f768a723a5e8464816e8f009f121a5a9f4405"}, + {file = "cryptography-43.0.3-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:e1be4655c7ef6e1bbe6b5d0403526601323420bcf414598955968c9ef3eb7d16"}, + {file = "cryptography-43.0.3-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:df6b6c6d742395dd77a23ea3728ab62f98379eff8fb61be2744d4679ab678f73"}, + {file = "cryptography-43.0.3-cp39-abi3-win32.whl", hash = "sha256:d56e96520b1020449bbace2b78b603442e7e378a9b3bd68de65c782db1507995"}, + {file = "cryptography-43.0.3-cp39-abi3-win_amd64.whl", hash = "sha256:0c580952eef9bf68c4747774cde7ec1d85a6e61de97281f2dba83c7d2c806362"}, + {file = "cryptography-43.0.3-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:d03b5621a135bffecad2c73e9f4deb1a0f977b9a8ffe6f8e002bf6c9d07b918c"}, + {file = "cryptography-43.0.3-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:a2a431ee15799d6db9fe80c82b055bae5a752bef645bba795e8e52687c69efe3"}, + {file = "cryptography-43.0.3-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:281c945d0e28c92ca5e5930664c1cefd85efe80e5c0d2bc58dd63383fda29f83"}, + {file = "cryptography-43.0.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:f18c716be16bc1fea8e95def49edf46b82fccaa88587a45f8dc0ff6ab5d8e0a7"}, + {file = "cryptography-43.0.3-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:4a02ded6cd4f0a5562a8887df8b3bd14e822a90f97ac5e544c162899bc467664"}, + {file = "cryptography-43.0.3-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:53a583b6637ab4c4e3591a15bc9db855b8d9dee9a669b550f311480acab6eb08"}, + {file = "cryptography-43.0.3-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:1ec0bcf7e17c0c5669d881b1cd38c4972fade441b27bda1051665faaa89bdcaa"}, + {file = "cryptography-43.0.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2ce6fae5bdad59577b44e4dfed356944fbf1d925269114c28be377692643b4ff"}, + {file = "cryptography-43.0.3.tar.gz", hash = "sha256:315b9001266a492a6ff443b61238f956b214dbec9910a081ba5b6646a055a805"}, +] + +[package.dependencies] +cffi = {version = ">=1.12", markers = "platform_python_implementation != \"PyPy\""} + +[package.extras] +docs = ["sphinx (>=5.3.0)", "sphinx-rtd-theme (>=1.1.1)"] +docstest = ["pyenchant (>=1.6.11)", "readme-renderer", "sphinxcontrib-spelling (>=4.0.1)"] +nox = ["nox"] +pep8test = ["check-sdist", "click", "mypy", "ruff"] +sdist = ["build"] +ssh = ["bcrypt (>=3.1.5)"] +test = ["certifi", "cryptography-vectors (==43.0.3)", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-xdist"] +test-randomorder = ["pytest-randomly"] + [[package]] name = "cryptography" version = "44.0.2" @@ -625,7 +678,7 @@ description = "cryptography is a package which provides cryptographic recipes an optional = false python-versions = "!=3.9.0,!=3.9.1,>=3.7" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\" and sys_platform == \"linux\"" +markers = "python_version >= \"3.10\" and sys_platform == \"linux\" and (python_full_version >= \"3.12.4\" or platform_python_implementation != \"PyPy\")" files = [ {file = "cryptography-44.0.2-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:efcfe97d1b3c79e486554efddeb8f6f53a4cdd4cf6086642784fa31fc384e1d7"}, {file = "cryptography-44.0.2-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29ecec49f3ba3f3849362854b7253a9f59799e3763b0c9d0826259a88efa02f1"}, @@ -684,7 +737,6 @@ description = "An implementation of the Debug Adapter Protocol for Python" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "debugpy-1.8.14-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:93fee753097e85623cab1c0e6a68c76308cd9f13ffdf44127e6fab4fbf024339"}, {file = "debugpy-1.8.14-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d937d93ae4fa51cdc94d3e865f535f185d5f9748efb41d0d49e33bf3365bd79"}, @@ -721,7 +773,6 @@ description = "Decorators for Humans" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a"}, {file = "decorator-5.2.1.tar.gz", hash = "sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360"}, @@ -734,7 +785,6 @@ description = "Distribution utilities" optional = false python-versions = "*" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "distlib-0.3.9-py2.py3-none-any.whl", hash = "sha256:47f8c22fd27c27e25a65601af709b38e4f0a45ea4fc2e710f65755fa8caaaf87"}, {file = "distlib-0.3.9.tar.gz", hash = "sha256:a60f20dea646b8a33f3e7772f74dc0b2d0772d2837ee1342a00645c81edf9403"}, @@ -747,7 +797,6 @@ description = "Distro - an OS platform information API" optional = false python-versions = ">=3.6" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, @@ -760,12 +809,27 @@ description = "Docutils -- Python Documentation Utilities" optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "docutils-0.21.2-py3-none-any.whl", hash = "sha256:dafca5b9e384f0e419294eb4d2ff9fa826435bf15f15b7bd45723e8ad76811b2"}, {file = "docutils-0.21.2.tar.gz", hash = "sha256:3a6b18732edf182daa3cd12775bbb338cf5691468f91eeeb109deff6ebfa986f"}, ] +[[package]] +name = "exceptiongroup" +version = "1.2.2" +description = "Backport of PEP 654 (exception groups)" +optional = false +python-versions = ">=3.7" +groups = ["main", "dev"] +markers = "python_version <= \"3.10\" or platform_python_implementation == \"PyPy\" and python_version < \"3.11\"" +files = [ + {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, + {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, +] + +[package.extras] +test = ["pytest (>=6)"] + [[package]] name = "executing" version = "2.2.0" @@ -773,7 +837,6 @@ description = "Get the currently executing AST node of a frame, and other inform optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa"}, {file = "executing-2.2.0.tar.gz", hash = "sha256:5d108c028108fe2551d1a7b2e8b713341e2cb4fc0aa7dcf966fa4327a5226755"}, @@ -788,8 +851,7 @@ version = "0.115.12" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" optional = false python-versions = ">=3.8" -groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" +groups = ["dev"] files = [ {file = "fastapi-0.115.12-py3-none-any.whl", hash = "sha256:e94613d6c05e27be7ffebdd6ea5f388112e5e430c8f7d6494a9d1d88d43e814d"}, {file = "fastapi-0.115.12.tar.gz", hash = "sha256:1e2c2a2646905f9e83d32f04a3f86aff4a286669c6c950ca95b5fd68c2602681"}, @@ -811,7 +873,6 @@ description = "A platform independent file lock." optional = false python-versions = ">=3.9" groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "filelock-3.18.0-py3-none-any.whl", hash = "sha256:c401f4f8377c4464e6db25fff06205fd89bdd83b65eb0488ed1b160f780e21de"}, {file = "filelock-3.18.0.tar.gz", hash = "sha256:adbc88eabb99d2fec8c9c1b229b171f18afa655400173ddc653d5d01501fb9f2"}, @@ -829,7 +890,6 @@ description = "A list-like structure which implements collections.abc.MutableSeq optional = false python-versions = ">=3.8" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "frozenlist-1.5.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:5b6a66c18b5b9dd261ca98dffcb826a525334b2f29e7caa54e182255c5f6a65a"}, {file = "frozenlist-1.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d1b3eb7b05ea246510b43a7e53ed1653e55c2121019a97e60cad7efb881a97bb"}, @@ -932,7 +992,6 @@ description = "File-system specification" optional = false python-versions = ">=3.9" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "fsspec-2025.3.2-py3-none-any.whl", hash = "sha256:2daf8dc3d1dfa65b6aa37748d112773a7a08416f6c70d96b264c96476ecaf711"}, {file = "fsspec-2025.3.2.tar.gz", hash = "sha256:e52c77ef398680bbd6a98c0e628fbc469491282981209907bbc8aea76a04fdc6"}, @@ -973,7 +1032,6 @@ description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" optional = false python-versions = ">=3.7" groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, @@ -986,7 +1044,6 @@ description = "A minimal low-level HTTP client." optional = false python-versions = ">=3.8" groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "httpcore-1.0.8-py3-none-any.whl", hash = "sha256:5254cf149bcb5f75e9d1b2b9f729ea4a4b883d1ad7379fc632b727cec23674be"}, {file = "httpcore-1.0.8.tar.gz", hash = "sha256:86e94505ed24ea06514883fd44d2bc02d90e77e7979c8eb71b90f41d364a1bad"}, @@ -1009,7 +1066,6 @@ description = "The next generation HTTP client." optional = false python-versions = ">=3.8" groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad"}, {file = "httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc"}, @@ -1035,7 +1091,6 @@ description = "Client library to download and publish models, datasets and other optional = false python-versions = ">=3.8.0" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "huggingface_hub-0.30.2-py3-none-any.whl", hash = "sha256:68ff05969927058cfa41df4f2155d4bb48f5f54f719dd0390103eefa9b191e28"}, {file = "huggingface_hub-0.30.2.tar.gz", hash = "sha256:9a7897c5b6fd9dad3168a794a8998d6378210f5b9688d0dfc180b1a228dc2466"}, @@ -1072,7 +1127,6 @@ description = "File identification library for Python" optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "identify-2.6.9-py2.py3-none-any.whl", hash = "sha256:c98b4322da415a8e5a70ff6e51fbc2d2932c015532d77e9f8537b4ba7813b150"}, {file = "identify-2.6.9.tar.gz", hash = "sha256:d40dfe3142a1421d8518e3d3985ef5ac42890683e32306ad614a29490abeb6bf"}, @@ -1088,7 +1142,6 @@ description = "Internationalized Domain Names in Applications (IDNA)" optional = false python-versions = ">=3.6" groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, @@ -1104,7 +1157,6 @@ description = "Read metadata from Python packages" optional = false python-versions = ">=3.9" groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "importlib_metadata-8.6.1-py3-none-any.whl", hash = "sha256:02a89390c1e15fdfdc0d7c6b25cb3e62650d0494005c97d6f148bf5b9787525e"}, {file = "importlib_metadata-8.6.1.tar.gz", hash = "sha256:310b41d755445d74569f993ccfc22838295d9fe005425094fad953d7f15c8580"}, @@ -1129,7 +1181,6 @@ description = "brain-dead simple config-ini parsing" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760"}, {file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"}, @@ -1142,7 +1193,6 @@ description = "IPython Kernel for Jupyter" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "ipykernel-6.29.5-py3-none-any.whl", hash = "sha256:afdb66ba5aa354b09b91379bac28ae4afebbb30e8b39510c9690afb7a10421b5"}, {file = "ipykernel-6.29.5.tar.gz", hash = "sha256:f093a22c4a40f8828f8e330a9c297cb93dcab13bd9678ded6de8e5cf81c56215"}, @@ -1170,6 +1220,85 @@ pyqt5 = ["pyqt5"] pyside6 = ["pyside6"] test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio (>=0.23.5)", "pytest-cov", "pytest-timeout"] +[[package]] +name = "ipython" +version = "8.18.1" +description = "IPython: Productive Interactive Computing" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +markers = "python_version < \"3.10\" or platform_python_implementation == \"PyPy\" and python_full_version < \"3.12.4\"" +files = [ + {file = "ipython-8.18.1-py3-none-any.whl", hash = "sha256:e8267419d72d81955ec1177f8a29aaa90ac80ad647499201119e2f05e99aa397"}, + {file = "ipython-8.18.1.tar.gz", hash = "sha256:ca6f079bb33457c66e233e4580ebfc4128855b4cf6370dddd73842a9563e8a27"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +decorator = "*" +exceptiongroup = {version = "*", markers = "python_version < \"3.11\""} +jedi = ">=0.16" +matplotlib-inline = "*" +pexpect = {version = ">4.3", markers = "sys_platform != \"win32\""} +prompt-toolkit = ">=3.0.41,<3.1.0" +pygments = ">=2.4.0" +stack-data = "*" +traitlets = ">=5" +typing-extensions = {version = "*", markers = "python_version < \"3.10\""} + +[package.extras] +all = ["black", "curio", "docrepr", "exceptiongroup", "ipykernel", "ipyparallel", "ipywidgets", "matplotlib", "matplotlib (!=3.2.0)", "nbconvert", "nbformat", "notebook", "numpy (>=1.22)", "pandas", "pickleshare", "pytest (<7)", "pytest (<7.1)", "pytest-asyncio (<0.22)", "qtconsole", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "stack-data", "testpath", "trio", "typing-extensions"] +black = ["black"] +doc = ["docrepr", "exceptiongroup", "ipykernel", "matplotlib", "pickleshare", "pytest (<7)", "pytest (<7.1)", "pytest-asyncio (<0.22)", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "stack-data", "testpath", "typing-extensions"] +kernel = ["ipykernel"] +nbconvert = ["nbconvert"] +nbformat = ["nbformat"] +notebook = ["ipywidgets", "notebook"] +parallel = ["ipyparallel"] +qtconsole = ["qtconsole"] +test = ["pickleshare", "pytest (<7.1)", "pytest-asyncio (<0.22)", "testpath"] +test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.22)", "pandas", "pickleshare", "pytest (<7.1)", "pytest-asyncio (<0.22)", "testpath", "trio"] + +[[package]] +name = "ipython" +version = "8.35.0" +description = "IPython: Productive Interactive Computing" +optional = false +python-versions = ">=3.10" +groups = ["dev"] +markers = "platform_python_implementation != \"PyPy\" and python_version >= \"3.10\" and python_full_version < \"3.12.4\"" +files = [ + {file = "ipython-8.35.0-py3-none-any.whl", hash = "sha256:e6b7470468ba6f1f0a7b116bb688a3ece2f13e2f94138e508201fad677a788ba"}, + {file = "ipython-8.35.0.tar.gz", hash = "sha256:d200b7d93c3f5883fc36ab9ce28a18249c7706e51347681f80a0aef9895f2520"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +decorator = "*" +exceptiongroup = {version = "*", markers = "python_version < \"3.11\""} +jedi = ">=0.16" +matplotlib-inline = "*" +pexpect = {version = ">4.3", markers = "sys_platform != \"win32\" and sys_platform != \"emscripten\""} +prompt_toolkit = ">=3.0.41,<3.1.0" +pygments = ">=2.4.0" +stack_data = "*" +traitlets = ">=5.13.0" +typing_extensions = {version = ">=4.6", markers = "python_version < \"3.12\""} + +[package.extras] +all = ["ipython[black,doc,kernel,matplotlib,nbconvert,nbformat,notebook,parallel,qtconsole]", "ipython[test,test-extra]"] +black = ["black"] +doc = ["docrepr", "exceptiongroup", "intersphinx_registry", "ipykernel", "ipython[test]", "matplotlib", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "sphinxcontrib-jquery", "tomli ; python_version < \"3.11\"", "typing_extensions"] +kernel = ["ipykernel"] +matplotlib = ["matplotlib"] +nbconvert = ["nbconvert"] +nbformat = ["nbformat"] +notebook = ["ipywidgets", "notebook"] +parallel = ["ipyparallel"] +qtconsole = ["qtconsole"] +test = ["packaging", "pickleshare", "pytest", "pytest-asyncio (<0.22)", "testpath"] +test-extra = ["curio", "ipython[test]", "jupyter_ai", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.23)", "pandas", "trio"] + [[package]] name = "ipython" version = "9.1.0" @@ -1177,7 +1306,7 @@ description = "IPython: Productive Interactive Computing" optional = false python-versions = ">=3.11" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" +markers = "python_full_version >= \"3.12.4\"" files = [ {file = "ipython-9.1.0-py3-none-any.whl", hash = "sha256:2df07257ec2f84a6b346b8d83100bcf8fa501c6e01ab75cd3799b0bb253b3d2a"}, {file = "ipython-9.1.0.tar.gz", hash = "sha256:a47e13a5e05e02f3b8e1e7a0f9db372199fe8c3763532fe7a1e0379e4e135f16"}, @@ -1194,7 +1323,6 @@ prompt_toolkit = ">=3.0.41,<3.1.0" pygments = ">=2.4.0" stack_data = "*" traitlets = ">=5.13.0" -typing_extensions = {version = ">=4.6", markers = "python_version < \"3.12\""} [package.extras] all = ["ipython[doc,matplotlib,test,test-extra]"] @@ -1211,7 +1339,7 @@ description = "Defines a variety of Pygments lexers for highlighting IPython cod optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" +markers = "python_full_version >= \"3.12.4\"" files = [ {file = "ipython_pygments_lexers-1.1.1-py3-none-any.whl", hash = "sha256:a9462224a505ade19a605f71f8fa63c2048833ce50abc86768a0d81d876dc81c"}, {file = "ipython_pygments_lexers-1.1.1.tar.gz", hash = "sha256:09c0138009e56b6854f9535736f4171d855c8c08a563a0dcd8022f78355c7e81"}, @@ -1227,7 +1355,6 @@ description = "A Python utility / library to sort Python imports." optional = false python-versions = ">=3.8.0" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "isort-5.13.2-py3-none-any.whl", hash = "sha256:8ca5e72a8d85860d5a3fa69b8745237f2939afe12dbf656afbcb47fe72d947a6"}, {file = "isort-5.13.2.tar.gz", hash = "sha256:48fdfcb9face5d58a4f6dde2e72a1fb8dcaf8ab26f95ab49fab84c2ddefb0109"}, @@ -1243,7 +1370,6 @@ description = "Utility functions for Python class constructs" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "jaraco.classes-3.4.0-py3-none-any.whl", hash = "sha256:f662826b6bed8cace05e7ff873ce0f9283b5c924470fe664fff1c2f00f581790"}, {file = "jaraco.classes-3.4.0.tar.gz", hash = "sha256:47a024b51d0239c0dd8c8540c6c7f484be3b8fcf0b2d85c13825780d3b3f3acd"}, @@ -1263,7 +1389,6 @@ description = "Useful decorators and context managers" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "jaraco.context-6.0.1-py3-none-any.whl", hash = "sha256:f797fc481b490edb305122c9181830a3a5b76d84ef6d1aef2fb9b47ab956f9e4"}, {file = "jaraco_context-6.0.1.tar.gz", hash = "sha256:9bae4ea555cf0b14938dc0aee7c9f32ed303aa20a3b73e7dc80111628792d1b3"}, @@ -1283,7 +1408,6 @@ description = "Functools like those found in stdlib" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "jaraco.functools-4.1.0-py3-none-any.whl", hash = "sha256:ad159f13428bc4acbf5541ad6dec511f91573b90fba04df61dafa2a1231cf649"}, {file = "jaraco_functools-4.1.0.tar.gz", hash = "sha256:70f7e0e2ae076498e212562325e805204fc092d7b4c17e0e86c959e249701a9d"}, @@ -1307,7 +1431,6 @@ description = "An autocompletion tool for Python that can be used for text edito optional = false python-versions = ">=3.6" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9"}, {file = "jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0"}, @@ -1328,7 +1451,7 @@ description = "Low-level, pure Python DBus protocol wrapper." optional = false python-versions = ">=3.7" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\" and sys_platform == \"linux\"" +markers = "sys_platform == \"linux\"" files = [ {file = "jeepney-0.9.0-py3-none-any.whl", hash = "sha256:97e5714520c16fc0a45695e5365a2e11b81ea79bba796e26f9f1d178cb182683"}, {file = "jeepney-0.9.0.tar.gz", hash = "sha256:cf0e9e845622b81e4a28df94c40345400256ec608d0e55bb8a3feaa9163f5732"}, @@ -1345,7 +1468,6 @@ description = "A very fast and expressive template engine." optional = false python-versions = ">=3.7" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67"}, {file = "jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d"}, @@ -1364,7 +1486,6 @@ description = "Fast iterable JSON parser." optional = false python-versions = ">=3.8" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "jiter-0.9.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:816ec9b60fdfd1fec87da1d7ed46c66c44ffec37ab2ef7de5b147b2fce3fd5ad"}, {file = "jiter-0.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9b1d3086f8a3ee0194ecf2008cf81286a5c3e540d977fa038ff23576c023c0ea"}, @@ -1451,7 +1572,6 @@ description = "Apply JSON-Patches (RFC 6902)" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade"}, {file = "jsonpatch-1.33.tar.gz", hash = "sha256:9fcd4009c41e6d12348b4a0ff2563ba56a2923a7dfee731d004e212e1ee5030c"}, @@ -1467,7 +1587,6 @@ description = "Identify specific nodes in a JSON document (RFC 6901)" optional = false python-versions = ">=3.7" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "jsonpointer-3.0.0-py2.py3-none-any.whl", hash = "sha256:13e088adc14fca8b6aa8177c044e12701e6ad4b28ff10e65f2267a90109c9942"}, {file = "jsonpointer-3.0.0.tar.gz", hash = "sha256:2b2d729f2091522d61c3b31f82e11870f60b68f43fbc705cb76bf4b832af59ef"}, @@ -1480,7 +1599,6 @@ description = "An implementation of JSON Schema validation for Python" optional = false python-versions = ">=3.8" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "jsonschema-4.23.0-py3-none-any.whl", hash = "sha256:fbadb6f8b144a8f8cf9f0b89ba94501d143e50411a1278633f56a7acf7fd5566"}, {file = "jsonschema-4.23.0.tar.gz", hash = "sha256:d71497fef26351a33265337fa77ffeb82423f3ea21283cd9467bb03999266bc4"}, @@ -1503,7 +1621,6 @@ description = "The JSON Schema meta-schemas and vocabularies, exposed as a Regis optional = false python-versions = ">=3.9" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "jsonschema_specifications-2024.10.1-py3-none-any.whl", hash = "sha256:a09a0680616357d9a0ecf05c12ad234479f549239d0f5b55f3deea67475da9bf"}, {file = "jsonschema_specifications-2024.10.1.tar.gz", hash = "sha256:0f38b83639958ce1152d02a7f062902c41c8fd20d558b0c34344292d417ae272"}, @@ -1519,13 +1636,13 @@ description = "Jupyter protocol implementation and client libraries" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "jupyter_client-8.6.3-py3-none-any.whl", hash = "sha256:e8a19cc986cc45905ac3362915f410f3af85424b4c0905e94fa5f2cb08e8f23f"}, {file = "jupyter_client-8.6.3.tar.gz", hash = "sha256:35b3a0947c4a6e9d589eb97d7d4cd5e90f910ee73101611f01283732bd6d9419"}, ] [package.dependencies] +importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""} jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" python-dateutil = ">=2.8.2" pyzmq = ">=23.0" @@ -1543,7 +1660,6 @@ description = "Jupyter core package. A base package on which Jupyter projects re optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "jupyter_core-5.7.2-py3-none-any.whl", hash = "sha256:4f7315d2f6b4bcf2e3e7cb6e46772eba760ae459cd1f59d29eb57b0a01bd7409"}, {file = "jupyter_core-5.7.2.tar.gz", hash = "sha256:aa5f8d32bbf6b431ac830496da7392035d6f61b4f54872f15c4bd2a9c3f536d9"}, @@ -1565,7 +1681,6 @@ description = "Store and access your passwords safely." optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "keyring-25.6.0-py3-none-any.whl", hash = "sha256:552a3f7af126ece7ed5c89753650eec89c7eaae8617d0aa4d9ad2b75111266bd"}, {file = "keyring-25.6.0.tar.gz", hash = "sha256:0b39998aa941431eb3d9b0d4b2460bc773b9df6fed7621c2dfb291a7e0187a66"}, @@ -1591,15 +1706,14 @@ type = ["pygobject-stubs", "pytest-mypy", "shtab", "types-pywin32"] [[package]] name = "langchain-core" -version = "0.3.51" +version = "0.3.52" description = "Building applications with LLMs through composability" optional = false python-versions = "<4.0,>=3.9" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ - {file = "langchain_core-0.3.51-py3-none-any.whl", hash = "sha256:4bd71e8acd45362aa428953f2a91d8162318014544a2216e4b769463caf68e13"}, - {file = "langchain_core-0.3.51.tar.gz", hash = "sha256:db76b9cc331411602cb40ba0469a161febe7a0663fbcaddbc9056046ac2d22f4"}, + {file = "langchain_core-0.3.52-py3-none-any.whl", hash = "sha256:cd137109c1e3d04f5a582c2cae9539b2cd5e4b795f486b58969dbc3d0387fe7c"}, + {file = "langchain_core-0.3.52.tar.gz", hash = "sha256:f1981ec9efa4fceb11ff5ca57f5f9c8e22859cea3a94f8a044e6de8815afbd57"}, ] [package.dependencies] @@ -1621,7 +1735,6 @@ description = "Building stateful, multi-actor applications with LLMs" optional = false python-versions = "<4.0,>=3.9.0" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "langgraph-0.3.30-py3-none-any.whl", hash = "sha256:879bd683248911e6a3c15a694256577c3335d68c1dce4ff5c7cc858fa5e9489a"}, {file = "langgraph-0.3.30.tar.gz", hash = "sha256:c1bc664072468d90cb27544cbc958117fca0c16bada20ff578817dacf63d941c"}, @@ -1641,7 +1754,6 @@ description = "Library with base interfaces for LangGraph checkpoint savers." optional = false python-versions = "<4.0.0,>=3.9.0" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "langgraph_checkpoint-2.0.24-py3-none-any.whl", hash = "sha256:3836e2909ef2387d1fa8d04ee3e2a353f980d519fd6c649af352676dc73d66b8"}, {file = "langgraph_checkpoint-2.0.24.tar.gz", hash = "sha256:9596dad332344e7e871257be464df8a07c2e9bac66143081b11b9422b0167e5b"}, @@ -1658,7 +1770,6 @@ description = "Library with high-level APIs for creating and executing LangGraph optional = false python-versions = "<4.0.0,>=3.9.0" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "langgraph_prebuilt-0.1.8-py3-none-any.whl", hash = "sha256:ae97b828ae00be2cefec503423aa782e1bff165e9b94592e224da132f2526968"}, {file = "langgraph_prebuilt-0.1.8.tar.gz", hash = "sha256:4de7659151829b2b955b6798df6800e580e617782c15c2c5b29b139697491831"}, @@ -1675,7 +1786,6 @@ description = "SDK for interacting with LangGraph API" optional = false python-versions = "<4.0.0,>=3.9.0" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "langgraph_sdk-0.1.61-py3-none-any.whl", hash = "sha256:f2d774b12497c428862993090622d51e0dbc3f53e0cee3d74a13c7495d835cc6"}, {file = "langgraph_sdk-0.1.61.tar.gz", hash = "sha256:87dd1f07ab82da8875ac343268ece8bf5414632017ebc9d1cef4b523962fd601"}, @@ -1687,15 +1797,14 @@ orjson = ">=3.10.1" [[package]] name = "langsmith" -version = "0.3.30" +version = "0.3.31" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." optional = false python-versions = "<4.0,>=3.9" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ - {file = "langsmith-0.3.30-py3-none-any.whl", hash = "sha256:80d591a4c62c14950ba497bb8b565ad9bd8d07e102b643916f0d2af1a7b2daaf"}, - {file = "langsmith-0.3.30.tar.gz", hash = "sha256:4588aad24623320cdf355f7594e583874c27e70460e6e6446a416ebb702b8cf7"}, + {file = "langsmith-0.3.31-py3-none-any.whl", hash = "sha256:ee780ae3eac69998c336817c0b9f5ccfecaaaa3e67d94b7ef726b58ab3e72a25"}, + {file = "langsmith-0.3.31.tar.gz", hash = "sha256:8d20bd08fa6c3bce54cb600ddc521cd218a1c3410f90d9266179bf83a7ff0897"}, ] [package.dependencies] @@ -1718,15 +1827,14 @@ pytest = ["pytest (>=7.0.0)", "rich (>=13.9.4,<14.0.0)"] [[package]] name = "litellm" -version = "1.66.0" +version = "1.66.1" description = "Library to easily interface with LLM API providers" optional = false python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ - {file = "litellm-1.66.0-py3-none-any.whl", hash = "sha256:1e4a2a9e023b12a385e4283b2f7922f7bd9c15b93ee9e64e203333cf8ddbc895"}, - {file = "litellm-1.66.0.tar.gz", hash = "sha256:15f592bab604233083dc8b79e1e510e7e234f06525efe4c4255732bfc7ceb219"}, + {file = "litellm-1.66.1-py3-none-any.whl", hash = "sha256:1f601fea3f086c1d2d91be60b9db115082a2f3a697e4e0def72f8b9c777c7232"}, + {file = "litellm-1.66.1.tar.gz", hash = "sha256:98f7add913e5eae2131dd412ee27532d9a309defd9dbb64f6c6c42ea8a2af068"}, ] [package.dependencies] @@ -1753,7 +1861,6 @@ description = "Python logging made (stupidly) simple" optional = false python-versions = "<4.0,>=3.5" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "loguru-0.7.3-py3-none-any.whl", hash = "sha256:31a33c10c8e1e10422bfd431aeb5d351c7cf7fa671e3c4df004162264b28220c"}, {file = "loguru-0.7.3.tar.gz", hash = "sha256:19480589e77d47b8d85b2c827ad95d49bf31b0dcde16593892eb51dd18706eb6"}, @@ -1772,8 +1879,7 @@ version = "3.0.0" description = "Python port of markdown-it. Markdown parsing, done right!" optional = false python-versions = ">=3.8" -groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" +groups = ["dev"] files = [ {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"}, {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"}, @@ -1799,7 +1905,6 @@ description = "Safely add untrusted strings to HTML/XML markup." optional = false python-versions = ">=3.9" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "MarkupSafe-3.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7e94c425039cde14257288fd61dcfb01963e658efbc0ff54f5306b06054700f8"}, {file = "MarkupSafe-3.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9e2d922824181480953426608b81967de705c3cef4d1af983af849d7bd619158"}, @@ -1871,7 +1976,6 @@ description = "Inline Matplotlib backend for Jupyter" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca"}, {file = "matplotlib_inline-0.1.7.tar.gz", hash = "sha256:8423b23ec666be3d16e16b60bdd8ac4e86e840ebd1dd11a30b9f117f2fa0ab90"}, @@ -1886,8 +1990,7 @@ version = "0.1.2" description = "Markdown URL utilities" optional = false python-versions = ">=3.7" -groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" +groups = ["dev"] files = [ {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, @@ -1900,7 +2003,6 @@ description = "More routines for operating on iterables, beyond itertools" optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "more-itertools-10.6.0.tar.gz", hash = "sha256:2cd7fad1009c31cc9fb6a035108509e6547547a7a738374f10bd49a09eb3ee3b"}, {file = "more_itertools-10.6.0-py3-none-any.whl", hash = "sha256:6eb054cb4b6db1473f6e15fcc676a08e4732548acd47c708f0e179c2c7c01e89"}, @@ -1913,7 +2015,6 @@ description = "multidict implementation" optional = false python-versions = ">=3.9" groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "multidict-6.4.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:32a998bd8a64ca48616eac5a8c1cc4fa38fb244a3facf2eeb14abe186e0f6cc5"}, {file = "multidict-6.4.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a54ec568f1fc7f3c313c2f3b16e5db346bf3660e1309746e7fccbbfded856188"}, @@ -2021,6 +2122,9 @@ files = [ {file = "multidict-6.4.3.tar.gz", hash = "sha256:3ada0b058c9f213c5f95ba301f922d402ac234f1111a7d8fd70f1b99f3c281ec"}, ] +[package.dependencies] +typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.11\""} + [[package]] name = "nest-asyncio" version = "1.6.0" @@ -2028,7 +2132,6 @@ description = "Patch asyncio to allow nested event loops" optional = false python-versions = ">=3.5" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c"}, {file = "nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe"}, @@ -2041,7 +2144,6 @@ description = "Python binding to Ammonia HTML sanitizer Rust crate" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "nh3-0.2.21-cp313-cp313t-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:fcff321bd60c6c5c9cb4ddf2554e22772bb41ebd93ad88171bbbb6f271255286"}, {file = "nh3-0.2.21-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31eedcd7d08b0eae28ba47f43fd33a653b4cdb271d64f1aeda47001618348fde"}, @@ -2076,7 +2178,6 @@ description = "Node.js virtual environment builder" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9"}, {file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"}, @@ -2089,7 +2190,6 @@ description = "The official Python client for Ollama." optional = false python-versions = "<4.0,>=3.8" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "ollama-0.4.7-py3-none-any.whl", hash = "sha256:85505663cca67a83707be5fb3aeff0ea72e67846cea5985529d8eca4366564a1"}, {file = "ollama-0.4.7.tar.gz", hash = "sha256:891dcbe54f55397d82d289c459de0ea897e103b86a3f1fad0fdb1895922a75ff"}, @@ -2106,7 +2206,6 @@ description = "The official Python library for the openai API" optional = false python-versions = ">=3.8" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "openai-1.74.0-py3-none-any.whl", hash = "sha256:aff3e0f9fb209836382ec112778667027f4fd6ae38bdb2334bc9e173598b092a"}, {file = "openai-1.74.0.tar.gz", hash = "sha256:592c25b8747a7cad33a841958f5eb859a785caea9ee22b9e4f4a2ec062236526"}, @@ -2134,7 +2233,6 @@ description = "Fast, correct Python JSON library supporting dataclasses, datetim optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "orjson-3.10.16-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:4cb473b8e79154fa778fb56d2d73763d977be3dcc140587e07dbc545bbfc38f8"}, {file = "orjson-3.10.16-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:622a8e85eeec1948690409a19ca1c7d9fd8ff116f4861d261e6ae2094fe59a00"}, @@ -2213,7 +2311,6 @@ description = "Fast, correct Python msgpack library supporting dataclasses, date optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "ormsgpack-1.9.1-cp310-cp310-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:f1f804fd9c0fd84213a6022c34172f82323b34afa7052a4af18797582cf56365"}, {file = "ormsgpack-1.9.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eab5cec99c46276b37071d570aab98603f3d0309b3818da3247eb64bb95e5cfc"}, @@ -2265,7 +2362,6 @@ description = "Core utilities for Python packages" optional = false python-versions = ">=3.8" groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"}, {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"}, @@ -2278,7 +2374,6 @@ description = "A Python Parser" optional = false python-versions = ">=3.6" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18"}, {file = "parso-0.8.4.tar.gz", hash = "sha256:eb3a7b58240fb99099a345571deecc0f9540ea5f4dd2fe14c2a99d6b281ab92d"}, @@ -2295,7 +2390,7 @@ description = "Pexpect allows easy control of interactive console applications." optional = false python-versions = "*" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\" and sys_platform != \"win32\" and sys_platform != \"emscripten\"" +markers = "platform_python_implementation == \"PyPy\" and python_full_version < \"3.12.4\" and sys_platform != \"win32\" or sys_platform != \"win32\" and sys_platform != \"emscripten\" or python_version < \"3.10\" and sys_platform != \"win32\"" files = [ {file = "pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523"}, {file = "pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f"}, @@ -2311,7 +2406,6 @@ description = "Query metadata from sdists / bdists / installed packages." optional = false python-versions = ">=3.6" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "pkginfo-1.10.0-py3-none-any.whl", hash = "sha256:889a6da2ed7ffc58ab5b900d888ddce90bce912f2d2de1dc1c26f4cb9fe65097"}, {file = "pkginfo-1.10.0.tar.gz", hash = "sha256:5df73835398d10db79f8eecd5cd86b1f6d29317589ea70796994d49399af6297"}, @@ -2327,7 +2421,6 @@ description = "A small Python package for determining appropriate platform-speci optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "platformdirs-4.3.7-py3-none-any.whl", hash = "sha256:a03875334331946f13c549dbd8f4bac7a13a50a895a0eb1e8c6a8ace80d40a94"}, {file = "platformdirs-4.3.7.tar.gz", hash = "sha256:eb437d586b6a0986388f0d6f74aa0cde27b48d0e3d66843640bfb6bdcdb6e351"}, @@ -2345,7 +2438,6 @@ description = "plugin and hook calling mechanisms for python" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, @@ -2362,7 +2454,6 @@ description = "A framework for managing and maintaining multi-language pre-commi optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "pre_commit-4.2.0-py2.py3-none-any.whl", hash = "sha256:a009ca7205f1eb497d10b845e52c838a98b6cdd2102a6c8e4540e94ee75c58bd"}, {file = "pre_commit-4.2.0.tar.gz", hash = "sha256:601283b9757afd87d40c4c4a9b2b5de9637a8ea02eaff7adc2d0fb4e04841146"}, @@ -2377,15 +2468,14 @@ virtualenv = ">=20.10.0" [[package]] name = "prompt-toolkit" -version = "3.0.50" +version = "3.0.51" description = "Library for building powerful interactive command lines in Python" optional = false -python-versions = ">=3.8.0" +python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ - {file = "prompt_toolkit-3.0.50-py3-none-any.whl", hash = "sha256:9b6427eb19e479d98acff65196a307c555eb567989e6d88ebbb1b509d9779198"}, - {file = "prompt_toolkit-3.0.50.tar.gz", hash = "sha256:544748f3860a2623ca5cd6d2795e7a14f3d0e1c3c9728359013f79877fc89bab"}, + {file = "prompt_toolkit-3.0.51-py3-none-any.whl", hash = "sha256:52742911fde84e2d423e2f9a4cf1de7d7ac4e51958f648d9540e0fb8db077b07"}, + {file = "prompt_toolkit-3.0.51.tar.gz", hash = "sha256:931a162e3b27fc90c86f1b48bb1fb2c528c2761475e57c9c06de13311c7b54ed"}, ] [package.dependencies] @@ -2398,7 +2488,6 @@ description = "Accelerated property cache" optional = false python-versions = ">=3.9" groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "propcache-0.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f27785888d2fdd918bc36de8b8739f2d6c791399552333721b58193f68ea3e98"}, {file = "propcache-0.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4e89cde74154c7b5957f87a355bb9c8ec929c167b59c83d90654ea36aeb6180"}, @@ -2507,7 +2596,6 @@ description = "Cross-platform lib for process and system monitoring in Python. optional = false python-versions = ">=3.6" groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "psutil-7.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:101d71dc322e3cffd7cea0650b09b3d08b8e7c4109dd6809fe452dfd00e58b25"}, {file = "psutil-7.0.0-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:39db632f6bb862eeccf56660871433e111b6ea58f2caea825571951d4b6aa3da"}, @@ -2532,7 +2620,7 @@ description = "Run a subprocess in a pseudo terminal" optional = false python-versions = "*" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\" and sys_platform != \"win32\" and sys_platform != \"emscripten\"" +markers = "platform_python_implementation == \"PyPy\" and python_full_version < \"3.12.4\" and sys_platform != \"win32\" or sys_platform != \"win32\" and sys_platform != \"emscripten\" or python_version < \"3.10\" and sys_platform != \"win32\"" files = [ {file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"}, {file = "ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220"}, @@ -2545,7 +2633,6 @@ description = "Safely evaluate AST nodes without side effects" optional = false python-versions = "*" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "pure_eval-0.2.3-py3-none-any.whl", hash = "sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0"}, {file = "pure_eval-0.2.3.tar.gz", hash = "sha256:5f4e983f40564c576c7c8635ae88db5956bb2229d7e9237d03b3c0b0190eaf42"}, @@ -2561,7 +2648,7 @@ description = "C parser in Python" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "sys_platform == \"linux\" or implementation_name == \"pypy\" or platform_python_implementation == \"PyPy\"" +markers = "platform_python_implementation == \"PyPy\" or sys_platform == \"linux\" or implementation_name == \"pypy\"" files = [ {file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"}, {file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"}, @@ -2574,7 +2661,6 @@ description = "Data validation using Python type hints" optional = false python-versions = ">=3.9" groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "pydantic-2.11.3-py3-none-any.whl", hash = "sha256:a082753436a07f9ba1289c6ffa01cd93db3548776088aa917cc43b63f68fa60f"}, {file = "pydantic-2.11.3.tar.gz", hash = "sha256:7471657138c16adad9322fe3070c0116dd6c3ad8d649300e3cbdfe91f4db4ec3"}, @@ -2597,7 +2683,6 @@ description = "Core functionality for Pydantic validation and serialization" optional = false python-versions = ">=3.9" groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "pydantic_core-2.33.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3077cfdb6125cc8dab61b155fdd714663e401f0e6883f9632118ec12cf42df26"}, {file = "pydantic_core-2.33.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8ffab8b2908d152e74862d276cf5017c81a2f3719f14e8e3e8d6b83fda863927"}, @@ -2709,8 +2794,7 @@ version = "2.19.1" description = "Pygments is a syntax highlighting package written in Python." optional = false python-versions = ">=3.8" -groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" +groups = ["dev"] files = [ {file = "pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c"}, {file = "pygments-2.19.1.tar.gz", hash = "sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f"}, @@ -2726,7 +2810,6 @@ description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820"}, {file = "pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845"}, @@ -2734,9 +2817,11 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} iniconfig = "*" packaging = "*" pluggy = ">=1.5,<2" +tomli = {version = ">=1", markers = "python_version < \"3.11\""} [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] @@ -2748,7 +2833,6 @@ description = "Pytest support for asyncio" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "pytest_asyncio-0.24.0-py3-none-any.whl", hash = "sha256:a811296ed596b69bf0b6f3dc40f83bcaf341b155a269052d82efa2b25ac7037b"}, {file = "pytest_asyncio-0.24.0.tar.gz", hash = "sha256:d081d828e576d85f875399194281e92bf8a68d60d72d1a2faf2feddb6c46b276"}, @@ -2768,7 +2852,6 @@ description = "Pytest plugin for measuring coverage." optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "pytest_cov-6.1.1-py3-none-any.whl", hash = "sha256:bddf29ed2d0ab6f4df17b4c55b0a657287db8684af9c42ea546b21b1041b3dde"}, {file = "pytest_cov-6.1.1.tar.gz", hash = "sha256:46935f7aaefba760e716c2ebfbe1c216240b9592966e7da99ea8292d4d3e2a0a"}, @@ -2788,7 +2871,6 @@ description = "Manage dependencies of tests" optional = false python-versions = ">=3.4" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "pytest-dependency-0.6.0.tar.gz", hash = "sha256:934b0e6a39d95995062c193f7eaeed8a8ffa06ff1bcef4b62b0dc74a708bacc1"}, ] @@ -2804,7 +2886,6 @@ description = "pytest plugin to abort hanging tests" optional = false python-versions = ">=3.7" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "pytest-timeout-2.3.1.tar.gz", hash = "sha256:12397729125c6ecbdaca01035b9e5239d4db97352320af155b3f5de1ba5165d9"}, {file = "pytest_timeout-2.3.1-py3-none-any.whl", hash = "sha256:68188cb703edfc6a18fad98dc25a3c61e9f24d644b0b70f33af545219fc7813e"}, @@ -2820,7 +2901,6 @@ description = "Extensions to the standard Python datetime module" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, @@ -2836,7 +2916,6 @@ description = "Read key-value pairs from a .env file and set them as environment optional = false python-versions = ">=3.9" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "python_dotenv-1.1.0-py3-none-any.whl", hash = "sha256:d7c01d9e2293916c18baf562d95698754b0dbbb5e74d457c45d4f6561fb9d55d"}, {file = "python_dotenv-1.1.0.tar.gz", hash = "sha256:41f90bc6f5f177fb41f53e87666db362025010eb28f60a01c9143bfa33a2b2d5"}, @@ -2879,7 +2958,7 @@ description = "A (partial) reimplementation of pywin32 using ctypes/cffi" optional = false python-versions = ">=3.6" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\" and sys_platform == \"win32\"" +markers = "sys_platform == \"win32\"" files = [ {file = "pywin32-ctypes-0.2.3.tar.gz", hash = "sha256:d162dc04946d704503b2edc4d55f3dba5c1d539ead017afa00142c38b9885755"}, {file = "pywin32_ctypes-0.2.3-py3-none-any.whl", hash = "sha256:8a1513379d709975552d202d942d9837758905c8d01eb82b8bcc30918929e7b8"}, @@ -2892,7 +2971,6 @@ description = "YAML parser and emitter for Python" optional = false python-versions = ">=3.8" groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"}, {file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"}, @@ -2956,7 +3034,6 @@ description = "Python bindings for 0MQ" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "pyzmq-26.4.0-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:0329bdf83e170ac133f44a233fc651f6ed66ef8e66693b5af7d54f45d1ef5918"}, {file = "pyzmq-26.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:398a825d2dea96227cf6460ce0a174cf7657d6f6827807d4d1ae9d0f9ae64315"}, @@ -3063,7 +3140,6 @@ description = "readme_renderer is a library for rendering readme descriptions fo optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "readme_renderer-44.0-py3-none-any.whl", hash = "sha256:2fbca89b81a08526aadf1357a8c2ae889ec05fb03f5da67f9769c9a592166151"}, {file = "readme_renderer-44.0.tar.gz", hash = "sha256:8712034eabbfa6805cacf1402b4eeb2a73028f72d1166d6f5cb7f9c047c5d1e1"}, @@ -3084,7 +3160,6 @@ description = "JSON Referencing + Python" optional = false python-versions = ">=3.9" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "referencing-0.36.2-py3-none-any.whl", hash = "sha256:e8699adbbf8b5c7de96d8ffa0eb5c158b3beafce084968e2ea8bb08c6794dcd0"}, {file = "referencing-0.36.2.tar.gz", hash = "sha256:df2e89862cd09deabbdba16944cc3f10feb6b3e6f18e902f7cc25609a34775aa"}, @@ -3102,7 +3177,6 @@ description = "Alternative regular expression module, to replace re." optional = false python-versions = ">=3.8" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "regex-2024.11.6-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ff590880083d60acc0433f9c3f713c51f7ac6ebb9adf889c79a261ecf541aa91"}, {file = "regex-2024.11.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:658f90550f38270639e83ce492f27d2c8d2cd63805c65a13a14d36ca126753f0"}, @@ -3207,7 +3281,6 @@ description = "Python HTTP for Humans." optional = false python-versions = ">=3.8" groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"}, {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"}, @@ -3230,7 +3303,6 @@ description = "A utility belt for advanced users of python-requests" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "requests-toolbelt-1.0.0.tar.gz", hash = "sha256:7681a0a3d047012b5bdc0ee37d7f8f07ebe76ab08caeccfc3921ce23c88d5bc6"}, {file = "requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06"}, @@ -3246,7 +3318,6 @@ description = "Validating URI References per RFC 3986" optional = false python-versions = ">=3.7" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "rfc3986-2.0.0-py2.py3-none-any.whl", hash = "sha256:50b1502b60e289cb37883f3dfd34532b8873c7de9f49bb546641ce9cbd256ebd"}, {file = "rfc3986-2.0.0.tar.gz", hash = "sha256:97aacf9dbd4bfd829baad6e6309fa6573aaf1be3f6fa735c8ab05e46cecb261c"}, @@ -3261,8 +3332,7 @@ version = "14.0.0" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" optional = false python-versions = ">=3.8.0" -groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" +groups = ["dev"] files = [ {file = "rich-14.0.0-py3-none-any.whl", hash = "sha256:1c9491e1951aac09caffd42f448ee3d04e58923ffe14993f6e83068dc395d7e0"}, {file = "rich-14.0.0.tar.gz", hash = "sha256:82f1bc23a6a21ebca4ae0c45af9bdbc492ed20231dcb63f297d6d1021a9d5725"}, @@ -3271,6 +3341,7 @@ files = [ [package.dependencies] markdown-it-py = ">=2.2.0" pygments = ">=2.13.0,<3.0.0" +typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.11\""} [package.extras] jupyter = ["ipywidgets (>=7.5.1,<9)"] @@ -3282,7 +3353,6 @@ description = "Python bindings to Rust's persistent data structures (rpds)" optional = false python-versions = ">=3.9" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "rpds_py-0.24.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:006f4342fe729a368c6df36578d7a348c7c716be1da0a1a0f86e3021f8e98724"}, {file = "rpds_py-0.24.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2d53747da70a4e4b17f559569d5f9506420966083a31c5fbd84e764461c4444b"}, @@ -3407,7 +3477,6 @@ description = "An extremely fast Python linter and code formatter, written in Ru optional = false python-versions = ">=3.7" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "ruff-0.8.6-py3-none-linux_armv6l.whl", hash = "sha256:defed167955d42c68b407e8f2e6f56ba52520e790aba4ca707a9c88619e580e3"}, {file = "ruff-0.8.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:54799ca3d67ae5e0b7a7ac234baa657a9c1784b48ec954a094da7c206e0365b1"}, @@ -3436,7 +3505,7 @@ description = "Python bindings to FreeDesktop.org Secret Service API" optional = false python-versions = ">=3.6" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\" and sys_platform == \"linux\"" +markers = "sys_platform == \"linux\"" files = [ {file = "SecretStorage-3.3.3-py3-none-any.whl", hash = "sha256:f356e6628222568e3af06f2eba8df495efa13b3b63081dafd4f7d9a7b7bc9f99"}, {file = "SecretStorage-3.3.3.tar.gz", hash = "sha256:2403533ef369eca6d2ba81718576c5e0f564d5cca1b58f73a8b23e7d4eeebd77"}, @@ -3453,7 +3522,6 @@ description = "Easily download, build, install, upgrade, and uninstall Python pa optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "setuptools-78.1.0-py3-none-any.whl", hash = "sha256:3e386e96793c8702ae83d17b853fb93d3e09ef82ec62722e61da5cd22376dcd8"}, {file = "setuptools-78.1.0.tar.gz", hash = "sha256:18fd474d4a82a5f83dac888df697af65afa82dec7323d09c3e37d1f14288da54"}, @@ -3474,8 +3542,7 @@ version = "1.5.4" description = "Tool to Detect Surrounding Shell" optional = false python-versions = ">=3.7" -groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" +groups = ["dev"] files = [ {file = "shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686"}, {file = "shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de"}, @@ -3488,7 +3555,6 @@ description = "Python 2 and 3 compatibility utilities" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"}, {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, @@ -3501,7 +3567,6 @@ description = "Sniff out which async library your code is running under" optional = false python-versions = ">=3.7" groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, @@ -3514,7 +3579,6 @@ description = "Extract data from python stack frames and tracebacks for informat optional = false python-versions = "*" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695"}, {file = "stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9"}, @@ -3534,8 +3598,7 @@ version = "0.46.2" description = "The little ASGI library that shines." optional = false python-versions = ">=3.9" -groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" +groups = ["dev"] files = [ {file = "starlette-0.46.2-py3-none-any.whl", hash = "sha256:595633ce89f8ffa71a015caed34a5b2dc1c0cdb3f0f1fbd1e69339cf2abeec35"}, {file = "starlette-0.46.2.tar.gz", hash = "sha256:7f7361f34eed179294600af672f565727419830b54b7b084efe44bb82d2fccd5"}, @@ -3543,6 +3606,7 @@ files = [ [package.dependencies] anyio = ">=3.6.2,<5" +typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""} [package.extras] full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.18)", "pyyaml"] @@ -3554,7 +3618,6 @@ description = "Retry code until it succeeds" optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "tenacity-9.1.2-py3-none-any.whl", hash = "sha256:f77bf36710d8b73a50b2dd155c97b870017ad21afe6ab300326b0371b3b05138"}, {file = "tenacity-9.1.2.tar.gz", hash = "sha256:1169d376c297e7de388d18b4481760d478b0e99a777cad3a9c86e556f4b697cb"}, @@ -3571,7 +3634,6 @@ description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models" optional = false python-versions = ">=3.9" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "tiktoken-0.9.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:586c16358138b96ea804c034b8acf3f5d3f0258bd2bc3b0227af4af5d622e382"}, {file = "tiktoken-0.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d9c59ccc528c6c5dd51820b3474402f69d9a9e1d656226848ad68a8d5b2e5108"}, @@ -3620,7 +3682,6 @@ description = "" optional = false python-versions = ">=3.9" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "tokenizers-0.21.1-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:e78e413e9e668ad790a29456e677d9d3aa50a9ad311a40905d6861ba7692cf41"}, {file = "tokenizers-0.21.1-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:cd51cd0a91ecc801633829fcd1fda9cf8682ed3477c6243b9a095539de4aecf3"}, @@ -3647,6 +3708,49 @@ dev = ["tokenizers[testing]"] docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests", "ruff"] +[[package]] +name = "tomli" +version = "2.2.1" +description = "A lil' TOML parser" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +markers = "python_full_version <= \"3.11.0a6\"" +files = [ + {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, + {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"}, + {file = "tomli-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a"}, + {file = "tomli-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee"}, + {file = "tomli-2.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e"}, + {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4"}, + {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106"}, + {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8"}, + {file = "tomli-2.2.1-cp311-cp311-win32.whl", hash = "sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff"}, + {file = "tomli-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b"}, + {file = "tomli-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea"}, + {file = "tomli-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8"}, + {file = "tomli-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192"}, + {file = "tomli-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222"}, + {file = "tomli-2.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77"}, + {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6"}, + {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd"}, + {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e"}, + {file = "tomli-2.2.1-cp312-cp312-win32.whl", hash = "sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98"}, + {file = "tomli-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4"}, + {file = "tomli-2.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f4039b9cbc3048b2416cc57ab3bda989a6fcf9b36cf8937f01a6e731b64f80d7"}, + {file = "tomli-2.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:286f0ca2ffeeb5b9bd4fcc8d6c330534323ec51b2f52da063b11c502da16f30c"}, + {file = "tomli-2.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a92ef1a44547e894e2a17d24e7557a5e85a9e1d0048b0b5e7541f76c5032cb13"}, + {file = "tomli-2.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9316dc65bed1684c9a98ee68759ceaed29d229e985297003e494aa825ebb0281"}, + {file = "tomli-2.2.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e85e99945e688e32d5a35c1ff38ed0b3f41f43fad8df0bdf79f72b2ba7bc5272"}, + {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ac065718db92ca818f8d6141b5f66369833d4a80a9d74435a268c52bdfa73140"}, + {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:d920f33822747519673ee656a4b6ac33e382eca9d331c87770faa3eef562aeb2"}, + {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a198f10c4d1b1375d7687bc25294306e551bf1abfa4eace6650070a5c1ae2744"}, + {file = "tomli-2.2.1-cp313-cp313-win32.whl", hash = "sha256:d3f5614314d758649ab2ab3a62d4f2004c825922f9e370b29416484086b264ec"}, + {file = "tomli-2.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:a38aa0308e754b0e3c67e344754dff64999ff9b513e691d0e786265c93583c69"}, + {file = "tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc"}, + {file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"}, +] + [[package]] name = "tornado" version = "6.4.2" @@ -3654,7 +3758,6 @@ description = "Tornado is a Python web framework and asynchronous networking lib optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "tornado-6.4.2-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:e828cce1123e9e44ae2a50a9de3055497ab1d0aeb440c5ac23064d9e44880da1"}, {file = "tornado-6.4.2-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:072ce12ada169c5b00b7d92a99ba089447ccc993ea2143c9ede887e0937aa803"}, @@ -3676,7 +3779,6 @@ description = "Fast, Extensible Progress Meter" optional = false python-versions = ">=3.7" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2"}, {file = "tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2"}, @@ -3699,7 +3801,6 @@ description = "Traitlets Python configuration system" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f"}, {file = "traitlets-5.14.3.tar.gz", hash = "sha256:9ed0579d3502c94b4b3732ac120375cda96f923114522847de4b3bb98b96b6b7"}, @@ -3716,7 +3817,6 @@ description = "Collection of utilities for publishing packages on PyPI" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "twine-5.1.1-py3-none-any.whl", hash = "sha256:215dbe7b4b94c2c50a7315c0275d2258399280fbb7d04182c7e55e24b5f93997"}, {file = "twine-5.1.1.tar.gz", hash = "sha256:9aa0825139c02b3434d913545c7b847a21c835e11597f5255842d457da2322db"}, @@ -3739,8 +3839,7 @@ version = "0.15.2" description = "Typer, build great CLIs. Easy to code. Based on Python type hints." optional = false python-versions = ">=3.7" -groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" +groups = ["dev"] files = [ {file = "typer-0.15.2-py3-none-any.whl", hash = "sha256:46a499c6107d645a9c13f7ee46c5d5096cae6f5fc57dd11eccbbb9ae3e44ddfc"}, {file = "typer-0.15.2.tar.gz", hash = "sha256:ab2fab47533a813c49fe1f16b1a370fd5819099c00b119e0633df65f22144ba5"}, @@ -3759,7 +3858,6 @@ description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "typing_extensions-4.13.2-py3-none-any.whl", hash = "sha256:a439e7c04b49fec3e5d3e2beaa21755cadbbdc391694e28ccdd36ca4a1408f8c"}, {file = "typing_extensions-4.13.2.tar.gz", hash = "sha256:e6c81219bd689f51865d9e372991c540bda33a0379d5573cddb9a3a23f7caaef"}, @@ -3772,7 +3870,6 @@ description = "Runtime typing introspection tools" optional = false python-versions = ">=3.9" groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "typing_inspection-0.4.0-py3-none-any.whl", hash = "sha256:50e72559fcd2a6367a19f7a7e610e6afcb9fac940c650290eed893d61386832f"}, {file = "typing_inspection-0.4.0.tar.gz", hash = "sha256:9765c87de36671694a67904bf2c96e395be9c6439bb6c87b5142569dcdd65122"}, @@ -3788,7 +3885,7 @@ description = "HTTP library with thread-safe connection pooling, file post, and optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" +markers = "python_version < \"3.10\" or platform_python_implementation == \"PyPy\"" files = [ {file = "urllib3-1.26.20-py2.py3-none-any.whl", hash = "sha256:0ed14ccfbf1c30a9072c7ca157e4319b70d65f623e91e7b32fadb2853431016e"}, {file = "urllib3-1.26.20.tar.gz", hash = "sha256:40c2dc0c681e47eb8f90e7e27bf6ff7df2e677421fd46756da1161c39ca70d32"}, @@ -3806,7 +3903,7 @@ description = "HTTP library with thread-safe connection pooling, file post, and optional = false python-versions = ">=3.9" groups = ["main", "dev"] -markers = "platform_python_implementation != \"PyPy\"" +markers = "platform_python_implementation != \"PyPy\" and python_version >= \"3.10\"" files = [ {file = "urllib3-2.4.0-py3-none-any.whl", hash = "sha256:4e16665048960a0900c702d4a66415956a584919c03361cac9f1df5c5dd7e813"}, {file = "urllib3-2.4.0.tar.gz", hash = "sha256:414bc6535b787febd7567804cc015fee39daab8ad86268f1310a9250697de466"}, @@ -3824,8 +3921,7 @@ version = "0.34.1" description = "The lightning-fast ASGI server." optional = false python-versions = ">=3.9" -groups = ["main"] -markers = "platform_python_implementation == \"PyPy\"" +groups = ["dev"] files = [ {file = "uvicorn-0.34.1-py3-none-any.whl", hash = "sha256:984c3a8c7ca18ebaad15995ee7401179212c59521e67bfc390c07fa2b8d2e065"}, {file = "uvicorn-0.34.1.tar.gz", hash = "sha256:af981725fc4b7ffc5cb3b0e9eda6258a90c4b52cb2a83ce567ae0a7ae1757afc"}, @@ -3834,6 +3930,7 @@ files = [ [package.dependencies] click = ">=7.0" h11 = ">=0.8" +typing-extensions = {version = ">=4.0", markers = "python_version < \"3.11\""} [package.extras] standard = ["colorama (>=0.4) ; sys_platform == \"win32\"", "httptools (>=0.6.3)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1) ; sys_platform != \"win32\" and sys_platform != \"cygwin\" and platform_python_implementation != \"PyPy\"", "watchfiles (>=0.13)", "websockets (>=10.4)"] @@ -3845,7 +3942,6 @@ description = "Automatically mock your HTTP interactions to simplify and speed u optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "vcrpy-7.0.0-py2.py3-none-any.whl", hash = "sha256:55791e26c18daa363435054d8b35bd41a4ac441b6676167635d1b37a71dbe124"}, {file = "vcrpy-7.0.0.tar.gz", hash = "sha256:176391ad0425edde1680c5b20738ea3dc7fb942520a48d2993448050986b3a50"}, @@ -3854,7 +3950,7 @@ files = [ [package.dependencies] PyYAML = "*" urllib3 = [ - {version = "<2", markers = "platform_python_implementation == \"PyPy\""}, + {version = "<2", markers = "python_version < \"3.10\" or platform_python_implementation == \"PyPy\""}, {version = "*", markers = "platform_python_implementation != \"PyPy\" and python_version >= \"3.10\""}, ] wrapt = "*" @@ -3870,7 +3966,6 @@ description = "Virtual Python Environment builder" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "virtualenv-20.30.0-py3-none-any.whl", hash = "sha256:e34302959180fca3af42d1800df014b35019490b119eba981af27f2fa486e5d6"}, {file = "virtualenv-20.30.0.tar.gz", hash = "sha256:800863162bcaa5450a6e4d721049730e7f2dae07720e0902b0e4040bd6f9ada8"}, @@ -3892,7 +3987,6 @@ description = "Measures the displayed width of unicode strings in a terminal" optional = false python-versions = "*" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859"}, {file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"}, @@ -3905,7 +3999,7 @@ description = "A small Python utility to set file creation time on Windows" optional = false python-versions = ">=3.5" groups = ["main"] -markers = "platform_python_implementation == \"PyPy\" and sys_platform == \"win32\"" +markers = "sys_platform == \"win32\"" files = [ {file = "win32_setctime-1.2.0-py3-none-any.whl", hash = "sha256:95d644c4e708aba81dc3704a116d8cbc974d70b3bdb8be1d150e36be6e9d1390"}, {file = "win32_setctime-1.2.0.tar.gz", hash = "sha256:ae1fdf948f5640aae05c511ade119313fb6a30d7eabe25fef9764dca5873c4c0"}, @@ -3921,7 +4015,6 @@ description = "Module for decorators, wrappers and monkey patching." optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "wrapt-1.17.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3d57c572081fed831ad2d26fd430d565b76aa277ed1d30ff4d40670b1c0dd984"}, {file = "wrapt-1.17.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5e251054542ae57ac7f3fba5d10bfff615b6c2fb09abeb37d2f1463f841ae22"}, @@ -4011,7 +4104,6 @@ description = "Python binding for xxHash" optional = false python-versions = ">=3.7" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "xxhash-3.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ece616532c499ee9afbb83078b1b952beffef121d989841f7f4b3dc5ac0fd212"}, {file = "xxhash-3.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3171f693dbc2cef6477054a665dc255d996646b4023fe56cb4db80e26f4cc520"}, @@ -4145,7 +4237,6 @@ description = "Yet another URL library" optional = false python-versions = ">=3.9" groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "yarl-1.19.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0bae32f8ebd35c04d6528cedb4a26b8bf25339d3616b04613b97347f919b76d3"}, {file = "yarl-1.19.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8015a076daf77823e7ebdcba474156587391dab4e70c732822960368c01251e6"}, @@ -4248,7 +4339,6 @@ description = "Backport of pathlib-compatible object wrapper for zip files" optional = false python-versions = ">=3.9" groups = ["main", "dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "zipp-3.21.0-py3-none-any.whl", hash = "sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931"}, {file = "zipp-3.21.0.tar.gz", hash = "sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4"}, @@ -4269,7 +4359,6 @@ description = "Zstandard bindings for Python" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "platform_python_implementation == \"PyPy\"" files = [ {file = "zstandard-0.23.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bf0a05b6059c0528477fba9054d09179beb63744355cab9f38059548fedd46a9"}, {file = "zstandard-0.23.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fc9ca1c9718cb3b06634c7c8dec57d24e9438b2aa9a0f02b8bb36bf478538880"}, @@ -4378,5 +4467,5 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.1" -python-versions = ">=3.11,<4.0" -content-hash = "943286c51e5a94f47216d4965e1795d268995daa0f8b83ebfba48ef763d757b5" +python-versions = ">=3.9,<4.0" +content-hash = "8f4a7ee1d86e9fd480a85e4121ca42a544dce54803131b1e9d88a23794c65ff2" diff --git a/pyproject.toml b/pyproject.toml index d363b53..cf0c11a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ packages = [ ] [tool.poetry.dependencies] -python = ">=3.8,<4.0" +python = ">=3.9,<4.0" litellm = ">=1.65.1,<2.0.0" loguru = ">=0.7.3,<0.8.0" cachetools = ">=5.5.2,<6.0.0" From 4d719e9121d8c5af25529ea1a3c82d9f698f60d1 Mon Sep 17 00:00:00 2001 From: "johnwayne.jiang" Date: Tue, 15 Apr 2025 15:57:07 -0700 Subject: [PATCH 5/7] cicd : add pre-commit and cicd --- .github/workflows/lint-and-test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/lint-and-test.yaml b/.github/workflows/lint-and-test.yaml index db0e764..9f8f908 100644 --- a/.github/workflows/lint-and-test.yaml +++ b/.github/workflows/lint-and-test.yaml @@ -43,7 +43,7 @@ jobs: - name: Install dependencies run: | pip install poetry - poetry install --with dev" + poetry install --with dev - name: Run ruff run: | From 20eabe46c6d77248c5f941285e73e779512bc43f Mon Sep 17 00:00:00 2001 From: "johnwayne.jiang" Date: Tue, 15 Apr 2025 16:19:21 -0700 Subject: [PATCH 6/7] cicd : add pre-commit and cicd --- .github/workflows/lint-and-test.yaml | 8 +- .github/workflows/pre-commit.yml | 14 -- .pre-commit-config.yaml | 8 +- Makefile | 8 +- examples/datafactor_input_examples.py | 20 +- examples/datafactory_test.ipynb | 156 +++++++------- examples/generate_data_with_topic.ipynb | 44 ++-- examples/llm_builder_workflow.py | 2 +- examples/simple_feedback_loop.py | 110 ++++------ examples/structured_llm.ipynb | 56 ++--- examples/test.py | 32 +-- examples/test_langgraph.py | 87 ++++---- examples/test_langgraph_structured_llm.py | 111 +++++----- examples/trial_llm.py | 70 +++--- pyproject.toml | 4 + src/starfish/__init__.py | 10 +- src/starfish/common/exceptions.py | 11 +- src/starfish/components/prepare_topic.py | 155 ++++++-------- src/starfish/data_factory/config.py | 2 +- src/starfish/data_factory/constants.py | 16 +- src/starfish/data_factory/event_loop.py | 1 - src/starfish/data_factory/factory.py | 199 +++++++++--------- src/starfish/data_factory/job_manager.py | 160 +++++++------- src/starfish/data_factory/state.py | 17 +- src/starfish/data_factory/storage/base.py | 2 +- .../storage/in_memory/in_memory_storage.py | 6 +- .../storage/local/data_handler.py | 2 - .../storage/local/local_storage.py | 66 +++--- .../storage/local/metadata_handler.py | 12 +- src/starfish/data_factory/storage/models.py | 4 +- src/starfish/data_factory/task_runner.py | 13 +- src/starfish/data_factory/utils/decorator.py | 21 +- src/starfish/data_factory/utils/enums.py | 3 +- src/starfish/data_factory/utils/errors.py | 7 +- src/starfish/data_factory/utils/mock.py | 9 +- src/starfish/llm/backend/ollama_adapter.py | 20 +- .../llm/model_hub/huggingface_adapter.py | 32 +-- src/starfish/llm/parser/json_builder.py | 4 +- src/starfish/llm/prompt/prompt_loader.py | 8 +- src/starfish/llm/prompt/prompt_template.py | 18 +- src/starfish/llm/proxy/litellm_adapter.py | 6 +- src/starfish/llm/proxy/litellm_adapter_ext.py | 2 +- src/starfish/llm/structured_llm.py | 3 +- tests/__init__.py | 5 +- .../storage/local/test_basic_storage.py | 10 +- .../storage/local/test_local_storage.py | 40 ++-- .../storage/local/test_performance.py | 23 +- .../data_factory/storage/test_storage_main.py | 21 +- tests/data_factory/test_data_factory.py | 131 +++++++----- tests/llm/prompt/test_prompt.py | 185 ++++++++-------- 50 files changed, 980 insertions(+), 974 deletions(-) delete mode 100644 .github/workflows/pre-commit.yml diff --git a/.github/workflows/lint-and-test.yaml b/.github/workflows/lint-and-test.yaml index 9f8f908..983d0da 100644 --- a/.github/workflows/lint-and-test.yaml +++ b/.github/workflows/lint-and-test.yaml @@ -45,10 +45,10 @@ jobs: pip install poetry poetry install --with dev - - name: Run ruff - run: | - poetry run ruff check . --output-format=github - poetry run ruff format . --check + # - name: Run ruff + # run: | + # poetry run ruff check . --output-format=github + # poetry run ruff format . --check - name: Run tests with coverage run: | diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml deleted file mode 100644 index 2995a09..0000000 --- a/.github/workflows/pre-commit.yml +++ /dev/null @@ -1,14 +0,0 @@ -# name: Pre-Commit - -# on: [push, pull_request] - -# jobs: -# pre-commit: -# runs-on: ubuntu-latest -# steps: -# - uses: actions/checkout@v3 -# - uses: actions/setup-python@v4 -# with: -# python-version: "3.x" -# - run: pip install pre-commit -# - run: pre-commit run --all-files \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6f580ac..286f684 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,12 +14,12 @@ repos: rev: v0.8.6 hooks: # Run the linter. - - id: ruff - args: [ --fix ] - types: [python] + # - id: ruff + # args: [ --fix ] + # types: [python] # Run the formatter. - id: ruff-format - args: [ --fix ] + # args: [ --fix ] #run even when no Python files are staged #always_run: true types: [python] diff --git a/Makefile b/Makefile index 994f1eb..d9f19f3 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,8 @@ lint: @echo "Running Linter (Ruff)..." - isort tests/ starfish/ examples -# poetry run ruff check tests starfish examples --fix - poetry run ruff format tests starfish examples + poetry run isort tests/ src/ examples --check-only || poetry run isort tests/ src/ examples + poetry run ruff check src examples --fix --unsafe-fixes --exit-zero + poetry run ruff format src examples --check || poetry run ruff format src examples test: poetry run pytest tests/ @@ -10,6 +10,6 @@ test: install: @echo "Installing dependencies..." poetry install - poetry run pre-commit install + poetry run pre-commit install --install-hooks diff --git a/examples/datafactor_input_examples.py b/examples/datafactor_input_examples.py index 6d14d6a..fb5101b 100644 --- a/examples/datafactor_input_examples.py +++ b/examples/datafactor_input_examples.py @@ -33,17 +33,11 @@ def workflow(city_name, num_records_per_city): data = workflow.run(city_name=["San Francisco", "New York", "Los Angeles"], num_records_per_city=3) -### Use Case 3: data=List[Dict] Only -data = workflow.run(data = [ - {'city_name': 'Paris'}, - {'city_name': 'Tokyo'} -]) +### Use Case 3: data=List[Dict] Only +data = workflow.run(data=[{"city_name": "Paris"}, {"city_name": "Tokyo"}]) ### Use Case 4: data=List[Dict] + Broadcast Kwarg -data = workflow.run(data = [ - {'city_name': 'Paris'}, - {'city_name': 'Tokyo'} -], num_records_per_city = 3) +data = workflow.run(data=[{"city_name": "Paris"}, {"city_name": "Tokyo"}], num_records_per_city=3) ### Use Case 5: data (List[Dict]) + Parallel Kwarg (Matching Lengths) @@ -95,22 +89,22 @@ def get_city_info_wf(city_name, region_code, num_records_per_city): ## Invoke sequence # [ -# {'city_name': 'Berlin'}, +# {'city_name': 'Berlin'}, # {'city_name': 'Rome'} # ] # [ -# {'city_name': 'Berlin', 'region_code': 'DE'}, +# {'city_name': 'Berlin', 'region_code': 'DE'}, # {'city_name': 'Rome', 'region_code': 'IT'} # ] # [ -# {'city_name': 'Beijing', 'region_code': 'DE'}, +# {'city_name': 'Beijing', 'region_code': 'DE'}, # {'city_name': 'Beijing', 'region_code': 'IT'} # ] # [ -# {'city_name': 'Beijing', 'region_code': 'DE', 'num_records_per_city': 3}, +# {'city_name': 'Beijing', 'region_code': 'DE', 'num_records_per_city': 3}, # {'city_name': 'Beijing', 'region_code': 'IT', 'num_records_per_city': 3} # ] diff --git a/examples/datafactory_test.ipynb b/examples/datafactory_test.ipynb index 271da6e..daa66a3 100644 --- a/examples/datafactory_test.ipynb +++ b/examples/datafactory_test.ipynb @@ -14,6 +14,7 @@ "outputs": [], "source": [ "import nest_asyncio\n", + "\n", "nest_asyncio.apply()" ] }, @@ -41,9 +42,9 @@ } ], "source": [ - "from starfish import StructuredLLM, data_factory\n", + "from starfish import data_factory\n", "from starfish.common.env_loader import load_env_file\n", - "from starfish.llm.utils import merge_structured_outputs\n", + "\n", "load_env_file()" ] }, @@ -54,8 +55,9 @@ "outputs": [], "source": [ "### Mock LLM call\n", - "import random\n", "import asyncio\n", + "import random\n", + "\n", "\n", "async def mock_llm_call(city_name, num_records_per_city, fail_rate=0.05, sleep_time=0.01):\n", " # Simulate a slight delay (optional, feels more async-realistic)\n", @@ -63,10 +65,10 @@ "\n", " # 5% chance of failure\n", " if random.random() < fail_rate:\n", - " print(f\" {city_name}: Failed!\") ## For debugging\n", + " print(f\" {city_name}: Failed!\") ## For debugging\n", " raise ValueError(f\"Mock LLM failed to process city: {city_name}\")\n", - " \n", - " print(f\"{city_name}: Successfully processed!\") ## For debugging\n", + "\n", + " print(f\"{city_name}: Successfully processed!\") ## For debugging\n", "\n", " result = [f\"{city_name}_{random.randint(1, 5)}\" for _ in range(num_records_per_city)]\n", " return result" @@ -183,18 +185,15 @@ } ], "source": [ - "\n", "@data_factory(max_concurrency=5)\n", - "async def test1(city_name, num_records_per_city, fail_rate = 0.5, sleep_time = 1):\n", - " return await mock_llm_call(city_name, num_records_per_city, fail_rate = fail_rate, sleep_time = sleep_time)\n", - "\n", - "test1.run(data = [\n", - " {'city_name': '1. New York'},\n", - " {'city_name': '2. Los Angeles'},\n", - " {'city_name': '3. Chicago'},\n", - " {'city_name': '4. Houston'},\n", - " {'city_name': '5. Miami'}\n", - "], num_records_per_city=5)" + "async def test1(city_name, num_records_per_city, fail_rate=0.5, sleep_time=1):\n", + " return await mock_llm_call(city_name, num_records_per_city, fail_rate=fail_rate, sleep_time=sleep_time)\n", + "\n", + "\n", + "test1.run(\n", + " data=[{\"city_name\": \"1. New York\"}, {\"city_name\": \"2. Los Angeles\"}, {\"city_name\": \"3. Chicago\"}, {\"city_name\": \"4. Houston\"}, {\"city_name\": \"5. Miami\"}],\n", + " num_records_per_city=5,\n", + ")" ] }, { @@ -275,12 +274,12 @@ } ], "source": [ - "\n", "@data_factory(max_concurrency=2)\n", - "async def test1(city_name, num_records_per_city, fail_rate = 0.5, sleep_time = 1):\n", - " return await mock_llm_call(city_name, num_records_per_city, fail_rate = fail_rate, sleep_time = sleep_time)\n", + "async def test1(city_name, num_records_per_city, fail_rate=0.5, sleep_time=1):\n", + " return await mock_llm_call(city_name, num_records_per_city, fail_rate=fail_rate, sleep_time=sleep_time)\n", + "\n", "\n", - "test1.run(city_name = [\"1. New York\", \"2. Los Angeles\", \"3. Chicago\", \"4. Houston\", \"5. Miami\"], num_records_per_city=5)" + "test1.run(city_name=[\"1. New York\", \"2. Los Angeles\", \"3. Chicago\", \"4. Houston\", \"5. Miami\"], num_records_per_city=5)" ] }, { @@ -549,18 +548,15 @@ } ], "source": [ - "\n", "@data_factory(max_concurrency=2)\n", - "async def test1(city_name, num_records_per_city, fail_rate = 1, sleep_time = 0.05):\n", - " return await mock_llm_call(city_name, num_records_per_city, fail_rate = fail_rate, sleep_time = sleep_time)\n", - "\n", - "test1.run(data = [\n", - " {'city_name': '1. New York'},\n", - " {'city_name': '2. Los Angeles'},\n", - " {'city_name': '3. Chicago'},\n", - " {'city_name': '4. Houston'},\n", - " {'city_name': '5. Miami'}\n", - "], num_records_per_city=5)" + "async def test1(city_name, num_records_per_city, fail_rate=1, sleep_time=0.05):\n", + " return await mock_llm_call(city_name, num_records_per_city, fail_rate=fail_rate, sleep_time=sleep_time)\n", + "\n", + "\n", + "test1.run(\n", + " data=[{\"city_name\": \"1. New York\"}, {\"city_name\": \"2. Los Angeles\"}, {\"city_name\": \"3. Chicago\"}, {\"city_name\": \"4. Houston\"}, {\"city_name\": \"5. Miami\"}],\n", + " num_records_per_city=5,\n", + ")" ] }, { @@ -614,16 +610,19 @@ } ], "source": [ - "\n", "@data_factory(max_concurrency=2)\n", - "async def test1(city_name, num_records_per_city, fail_rate = 0.1, sleep_time = 0.05):\n", - " return await mock_llm_call(city_name, num_records_per_city, fail_rate = fail_rate, sleep_time = sleep_time)\n", - "\n", - "test1.run(data = [\n", - " {'city_name': '1. New York'},\n", - " {'city_name': '2. Los Angeles'},\n", - "], city_name = 'override_city_name', \n", - "num_records_per_city = 1)" + "async def test1(city_name, num_records_per_city, fail_rate=0.1, sleep_time=0.05):\n", + " return await mock_llm_call(city_name, num_records_per_city, fail_rate=fail_rate, sleep_time=sleep_time)\n", + "\n", + "\n", + "test1.run(\n", + " data=[\n", + " {\"city_name\": \"1. New York\"},\n", + " {\"city_name\": \"2. Los Angeles\"},\n", + " ],\n", + " city_name=\"override_city_name\",\n", + " num_records_per_city=1,\n", + ")" ] }, { @@ -681,14 +680,18 @@ ], "source": [ "@data_factory(max_concurrency=2)\n", - "async def test1(city_name, num_records_per_city, fail_rate = 0.1, sleep_time = 0.05):\n", - " return await mock_llm_call(city_name, num_records_per_city, fail_rate = fail_rate, sleep_time = sleep_time)\n", - "\n", - "test1.run(data = [\n", - " {'city_name': '1. New York'},\n", - " {'city_name': '2. Los Angeles'},\n", - "], city_name = ['1. override_city_name', '2. override_city_name'], \n", - "num_records_per_city = 1)" + "async def test1(city_name, num_records_per_city, fail_rate=0.1, sleep_time=0.05):\n", + " return await mock_llm_call(city_name, num_records_per_city, fail_rate=fail_rate, sleep_time=sleep_time)\n", + "\n", + "\n", + "test1.run(\n", + " data=[\n", + " {\"city_name\": \"1. New York\"},\n", + " {\"city_name\": \"2. Los Angeles\"},\n", + " ],\n", + " city_name=[\"1. override_city_name\", \"2. override_city_name\"],\n", + " num_records_per_city=1,\n", + ")" ] }, { @@ -736,15 +739,17 @@ } ], "source": [ - "\n", "@data_factory(max_concurrency=2)\n", - "async def test1(city_name, num_records_per_city, fail_rate = 0.1, sleep_time = 0.05):\n", - " return await mock_llm_call(city_name, num_records_per_city, fail_rate = fail_rate, sleep_time = sleep_time)\n", + "async def test1(city_name, num_records_per_city, fail_rate=0.1, sleep_time=0.05):\n", + " return await mock_llm_call(city_name, num_records_per_city, fail_rate=fail_rate, sleep_time=sleep_time)\n", "\n", - "test1.run(data = [\n", - " {'city_name': '1. New York'},\n", - " {'city_name': '2. Los Angeles'},\n", - "], city_name = 'override_city_name', \n", + "\n", + "test1.run(\n", + " data=[\n", + " {\"city_name\": \"1. New York\"},\n", + " {\"city_name\": \"2. Los Angeles\"},\n", + " ],\n", + " city_name=\"override_city_name\",\n", ")" ] }, @@ -793,15 +798,19 @@ } ], "source": [ - "\n", "@data_factory(max_concurrency=2)\n", - "async def test1(city_name, num_records_per_city, fail_rate = 0.1, sleep_time = 0.05):\n", - " return await mock_llm_call(city_name, num_records_per_city, fail_rate = fail_rate, sleep_time = sleep_time)\n", + "async def test1(city_name, num_records_per_city, fail_rate=0.1, sleep_time=0.05):\n", + " return await mock_llm_call(city_name, num_records_per_city, fail_rate=fail_rate, sleep_time=sleep_time)\n", "\n", - "test1.run(data = [\n", - " {'city_name': '1. New York'},\n", - " {'city_name': '2. Los Angeles'},\n", - "], num_records_per_city = 1, random_param = 'random_param')" + "\n", + "test1.run(\n", + " data=[\n", + " {\"city_name\": \"1. New York\"},\n", + " {\"city_name\": \"2. Los Angeles\"},\n", + " ],\n", + " num_records_per_city=1,\n", + " random_param=\"random_param\",\n", + ")" ] }, { @@ -907,21 +916,24 @@ } ], "source": [ - "\n", "def test_hook(data, SharedState):\n", - " SharedState['variable'] = f'changed_state - {data}'\n", + " SharedState[\"variable\"] = f\"changed_state - {data}\"\n", " return SharedState\n", "\n", "\n", - "@data_factory(max_concurrency=2, on_record_complete=[test_hook], initial_state_values = {'variable': 'initial_state'})\n", - "async def test1(city_name, num_records_per_city, fail_rate = 0.1, sleep_time = 0.05):\n", + "@data_factory(max_concurrency=2, on_record_complete=[test_hook], initial_state_values={\"variable\": \"initial_state\"})\n", + "async def test1(city_name, num_records_per_city, fail_rate=0.1, sleep_time=0.05):\n", " print(f\"Checking state: {test1.SharedState['variable']}\")\n", - " return await mock_llm_call(city_name, num_records_per_city, fail_rate = fail_rate, sleep_time = sleep_time)\n", + " return await mock_llm_call(city_name, num_records_per_city, fail_rate=fail_rate, sleep_time=sleep_time)\n", + "\n", "\n", - "test1.run(data = [\n", - " {'city_name': '1. New York'},\n", - " {'city_name': '2. Los Angeles'},\n", - "], num_records_per_city = 1)" + "test1.run(\n", + " data=[\n", + " {\"city_name\": \"1. New York\"},\n", + " {\"city_name\": \"2. Los Angeles\"},\n", + " ],\n", + " num_records_per_city=1,\n", + ")" ] }, { diff --git a/examples/generate_data_with_topic.ipynb b/examples/generate_data_with_topic.ipynb index 4828562..cc7677b 100644 --- a/examples/generate_data_with_topic.ipynb +++ b/examples/generate_data_with_topic.ipynb @@ -7,6 +7,7 @@ "outputs": [], "source": [ "import nest_asyncio\n", + "\n", "nest_asyncio.apply()" ] }, @@ -38,13 +39,13 @@ ], "source": [ "facts_generator = StructuredLLM(\n", - " model_name=\"openai/gpt-4o-mini\",\n", - " prompt=\"\"\"Generate facts about {{city_name}} on {{topic}}\"\"\" , \n", - " output_schema=[{'name': 'question', 'type': 'str'}, {'name': 'answer', 'type': 'str'}],\n", - " model_kwargs={\"temperature\": 0.7}\n", - " )\n", + " model_name=\"openai/gpt-4o-mini\",\n", + " prompt=\"\"\"Generate facts about {{city_name}} on {{topic}}\"\"\",\n", + " output_schema=[{\"name\": \"question\", \"type\": \"str\"}, {\"name\": \"answer\", \"type\": \"str\"}],\n", + " model_kwargs={\"temperature\": 0.7},\n", + ")\n", "response = await facts_generator.run(city_name=\"San Francisco\", topic=\"history\")\n", - "response.data " + "response.data" ] }, { @@ -53,8 +54,7 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "# data_market.run(name = 'generator_with_topic', \n", + "# data_market.run(name = 'generator_with_topic',\n", "# user_instructions = \"\"\"generate facts about san francisco\"\"\",\n", "# num_records = 100,\n", "# topics = [{'history': 10, 'culture': 10, 'food': 10}])" @@ -68,7 +68,7 @@ "source": [ "user_instructions = \"\"\"generate facts about san francisco\"\"\"\n", "num_records = 100\n", - "topics = [{'history': 40}, {'culture': 10}, {'food': 10}]" + "topics = [{\"history\": 40}, {\"culture\": 10}, {\"food\": 10}]" ] }, { @@ -77,9 +77,9 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "from starfish.components import prepare_topic\n", - "topic_list = await prepare_topic(num_records=num_records, topics = topics, user_instructions=user_instructions, records_per_topic=10)" + "\n", + "topic_list = await prepare_topic(num_records=num_records, topics=topics, user_instructions=user_instructions, records_per_topic=10)" ] }, { @@ -106,7 +106,8 @@ ], "source": [ "from collections import Counter\n", - "Counter([topic['topic'] for topic in topic_list])" + "\n", + "Counter([topic[\"topic\"] for topic in topic_list])" ] }, { @@ -202,20 +203,20 @@ } ], "source": [ - "\n", "@data_factory(max_concurrency=5)\n", "async def generate_facts(user_instructions: str, topic: str):\n", " print(f\"Generating facts for {topic}...\")\n", " facts_generator = StructuredLLM(\n", - " model_name=\"openai/gpt-4o-mini\",\n", - " prompt=\"\"\"{{user_instructions}} on {{topic}}\"\"\" , \n", - " output_schema=[{'name': 'question', 'type': 'str'}, {'name': 'answer', 'type': 'str'}],\n", - " model_kwargs={\"temperature\": 0.7}\n", - " )\n", + " model_name=\"openai/gpt-4o-mini\",\n", + " prompt=\"\"\"{{user_instructions}} on {{topic}}\"\"\",\n", + " output_schema=[{\"name\": \"question\", \"type\": \"str\"}, {\"name\": \"answer\", \"type\": \"str\"}],\n", + " model_kwargs={\"temperature\": 0.7},\n", + " )\n", " response = await facts_generator.run(user_instructions=user_instructions, topic=topic)\n", - " return response.data \n", + " return response.data\n", "\n", - "facts_data = generate_facts.run(data = topic_list, user_instructions=user_instructions)" + "\n", + "facts_data = generate_facts.run(data=topic_list, user_instructions=user_instructions)" ] }, { @@ -270,7 +271,7 @@ } ], "source": [ - "generate_facts.re_run(master_job_id = '82f03acc-275b-40f1-b106-4656f4ab48fd')" + "generate_facts.re_run(master_job_id=\"82f03acc-275b-40f1-b106-4656f4ab48fd\")" ] }, { @@ -297,7 +298,6 @@ } ], "source": [ - "\n", "len(facts_data)" ] }, diff --git a/examples/llm_builder_workflow.py b/examples/llm_builder_workflow.py index 620143f..ce91f62 100644 --- a/examples/llm_builder_workflow.py +++ b/examples/llm_builder_workflow.py @@ -21,7 +21,7 @@ @data_factory(storage="local", batch_size=5) def generate_city_info(cities: List[str], num_facts: int) -> List[Dict[str, Any]]: - """LLM-powered city fact generation pipeline with validation""" + """LLM-powered city fact generation pipeline with validation.""" results = [] for city in cities: diff --git a/examples/simple_feedback_loop.py b/examples/simple_feedback_loop.py index d9758e8..24ad8d8 100644 --- a/examples/simple_feedback_loop.py +++ b/examples/simple_feedback_loop.py @@ -1,6 +1,5 @@ -import asyncio -from typing import Dict, Optional, List from datetime import datetime +from typing import Dict, Optional from starfish import StructuredLLM, data_factory from starfish.data_factory.constants import RECORD_STATUS @@ -10,16 +9,16 @@ city_facts_llm = StructuredLLM( model_name="openai/gpt-4o-mini", prompt=""" - Generate comprehensive and interesting facts about {{city_name}}. + Generate comprehensive and interesting facts about {{city_name}}. Include historical information, famous landmarks, cultural significance, and notable events. - + {% if feedback %} Previous attempt received a score of {{score}}/10 with the following feedback: {{feedback}} - + Please address this feedback and improve your response. {% endif %} - + Make your response detailed, accurate, and engaging. """, output_schema=[ @@ -27,9 +26,9 @@ {"name": "landmarks", "type": "str"}, {"name": "cultural_significance", "type": "str"}, {"name": "notable_events", "type": "str"}, - {"name": "additional_info", "type": "str"} + {"name": "additional_info", "type": "str"}, ], - model_kwargs={"temperature": 0.7} + model_kwargs={"temperature": 0.7}, ) # Create a StructuredLLM instance for scoring and feedback @@ -37,45 +36,39 @@ model_name="openai/gpt-4o-mini", prompt=""" Evaluate the following city information about {{city_name}}: - + Historical Information: {{historical_info}} Famous Landmarks: {{landmarks}} Cultural Significance: {{cultural_significance}} Notable Events: {{notable_events}} Additional Information: {{additional_info}} - + Score this information from 1 to 10 based on: 1. Comprehensiveness (does it cover all important aspects?) 2. Accuracy (is the information correct?) 3. Engagement (is it interesting and well-written?) 4. Uniqueness (does it provide insights not commonly known?) - + Provide a score and detailed feedback for improvement if the score is less than 10. """, - output_schema=[ - {"name": "score", "type": "int"}, - {"name": "feedback", "type": "str"}, - {"name": "explanation", "type": "str"} - ], - model_kwargs={"temperature": 0.3} + output_schema=[{"name": "score", "type": "int"}, {"name": "feedback", "type": "str"}, {"name": "explanation", "type": "str"}], + model_kwargs={"temperature": 0.3}, ) + async def generate_city_facts(city_name: str, feedback: Optional[str] = None, score: Optional[int] = None) -> Dict: """Generate facts about a city.""" try: - response = await city_facts_llm.run( - city_name=city_name, - feedback=feedback, - score=score - ) - + response = await city_facts_llm.run(city_name=city_name, feedback=feedback, score=score) + if not response.data: raise ValueError("No data returned from LLM") - + return response.data[0] except Exception as e: raise Exception(f"Error generating city facts: {str(e)}") + async def evaluate_city_facts(city_name: str, facts: Dict) -> Dict: """Evaluate the generated facts.""" try: @@ -85,99 +78,84 @@ async def evaluate_city_facts(city_name: str, facts: Dict) -> Dict: landmarks=facts.get("landmarks", ""), cultural_significance=facts.get("cultural_significance", ""), notable_events=facts.get("notable_events", ""), - additional_info=facts.get("additional_info", "") + additional_info=facts.get("additional_info", ""), ) - + if not response.data: raise ValueError("No data returned from LLM") - + return response.data[0] except Exception as e: raise Exception(f"Error evaluating city facts: {str(e)}") + @data_factory(max_concurrency=5) async def process_city_facts(city_name: str, max_attempts: int = 5): """Process city facts with feedback loop.""" print(f"\nProcessing facts for {city_name}...") - + facts = None score = 0 feedback = None attempts = 0 - + while attempts < max_attempts: attempts += 1 print(f"\nAttempt {attempts}/{max_attempts}") - + try: # Generate facts facts = await generate_city_facts(city_name, feedback, score) print("\nGenerated Facts:") for key, value in facts.items(): print(f"{key}: {value}") - + # Evaluate facts evaluation = await evaluate_city_facts(city_name, facts) score = evaluation.get("score", 0) feedback = evaluation.get("feedback", "") explanation = evaluation.get("explanation", "") - - print(f"\nEvaluation:") + + print("\nEvaluation:") print(f"Score: {score}/10") print(f"Feedback: {feedback}") print(f"Explanation: {explanation}") - + # If we got a perfect score, we're done - if score >8: + if score > 8: print("\nReally good score!") break - + except Exception as e: print(f"Error: {str(e)}") - return { - RECORD_STATUS: RecordStatus.FAILED, - "error": str(e) - } - + return {RECORD_STATUS: RecordStatus.FAILED, "error": str(e)} + if not facts: - return { - RECORD_STATUS: RecordStatus.FAILED, - "error": "Failed to generate city facts" - } - - return { - RECORD_STATUS: RecordStatus.COMPLETED, - "output_ref": facts, - "final_score": score, - "attempts": attempts - } + return {RECORD_STATUS: RecordStatus.FAILED, "error": "Failed to generate city facts"} + + return {RECORD_STATUS: RecordStatus.COMPLETED, "output_ref": facts, "final_score": score, "attempts": attempts} + # Test the function if __name__ == "__main__": print(f"{datetime.now().strftime('%H:%M:%S.%f')[:-3]} - Starting simple feedback loop test") - + # List of cities to process - cities = [ - "San Francisco", - "New York", - "Tokyo", - "Paris", - "Sydney" - ] - + cities = ["San Francisco", "New York", "Tokyo", "Paris", "Sydney"] + results = process_city_facts.run(city_name=cities) - + # Print results for each city for city, result in zip(cities, results): print(f"\nResults for {city}:") print(f"Status: {result.get('status')}") - if result.get('status') == RecordStatus.COMPLETED: + if result.get("status") == RecordStatus.COMPLETED: print("\nCity Facts:") - for key, value in result['output_ref'].items(): + for key, value in result["output_ref"].items(): print(f"{key}: {value}") print(f"\nFinal Score: {result.get('final_score')}") print(f"Attempts: {result.get('attempts')}") else: print(f"Error: {result.get('error')}") - - print(f"\n{datetime.now().strftime('%H:%M:%S.%f')[:-3]} - Finished simple feedback loop test") \ No newline at end of file + + print(f"\n{datetime.now().strftime('%H:%M:%S.%f')[:-3]} - Finished simple feedback loop test") diff --git a/examples/structured_llm.ipynb b/examples/structured_llm.ipynb index 2df7f8d..566e470 100644 --- a/examples/structured_llm.ipynb +++ b/examples/structured_llm.ipynb @@ -14,6 +14,7 @@ "outputs": [], "source": [ "import nest_asyncio\n", + "\n", "nest_asyncio.apply()" ] }, @@ -162,9 +163,8 @@ "first_llm = StructuredLLM(\n", " model_name=\"hyperbolic/deepseek-ai/DeepSeek-V3-0324\",\n", " prompt=\"Facts about city {{city_name}}.\",\n", - " output_schema=[{'name': 'question', 'type': 'str'}, \n", - " {'name': 'answer', 'type': 'str'}],\n", - " model_kwargs={\"temperature\": 0.7}\n", + " output_schema=[{\"name\": \"question\", \"type\": \"str\"}, {\"name\": \"answer\", \"type\": \"str\"}],\n", + " model_kwargs={\"temperature\": 0.7},\n", ")\n", "\n", "first_response = await first_llm.run(city_name=\"New York\", num_records=5)\n", @@ -247,6 +247,7 @@ "source": [ "### Clean it up\n", "from starfish.core.llm.backend.ollama_adapter import stop_ollama_server\n", + "\n", "await stop_ollama_server()" ] }, @@ -304,8 +305,9 @@ } ], "source": [ + "from starfish import StructuredLLM\n", "from starfish.llm.utils import merge_structured_outputs\n", - "from starfish import StructuredLLM \n", + "\n", "first_llm = StructuredLLM(\n", " model_name=\"openai/gpt-4o-mini\",\n", " prompt=\"Facts about city {{city_name}}.\",\n", @@ -317,19 +319,24 @@ "\n", "second_llm = StructuredLLM(\n", " model_name=\"openai/gpt-4o-mini\",\n", - " prompt=\"\"\"You will be given a list of question and answer pairs, \n", - "please rate each individually about its accuracy, funny and conciseness. \n", - "rating are from 1 to 10, 1 being the worst and 10 being the best. \n", + " prompt=\"\"\"You will be given a list of question and answer pairs,\n", + "please rate each individually about its accuracy, funny and conciseness.\n", + "rating are from 1 to 10, 1 being the worst and 10 being the best.\n", "lets also rank them among themself so from 1 being the best.\n", "Here is question and answer pairs: {{QnA_pairs}}\"\"\",\n", - " output_schema=[{\"name\": \"accuracy\", \"type\": \"int\"}, {\"name\": \"funny\", \"type\": \"int\"}, {\"name\": \"conciseness\", \"type\": \"int\"}, {\"name\": \"rank\", \"type\": \"int\"}],\n", + " output_schema=[\n", + " {\"name\": \"accuracy\", \"type\": \"int\"},\n", + " {\"name\": \"funny\", \"type\": \"int\"},\n", + " {\"name\": \"conciseness\", \"type\": \"int\"},\n", + " {\"name\": \"rank\", \"type\": \"int\"},\n", + " ],\n", " model_kwargs={\"temperature\": 1},\n", ")\n", "\n", "second_response = await second_llm.run(QnA_pairs=first_response.data)\n", "\n", "### Merge result:\n", - "merge_structured_outputs(first_response.data, second_response.data)\n" + "merge_structured_outputs(first_response.data, second_response.data)" ] }, { @@ -494,7 +501,9 @@ } ], "source": [ - "from starfish.core.llm.utils import merge_structured_outputs \n", + "from starfish.core.llm.utils import merge_structured_outputs\n", + "\n", + "\n", "@data_factory(max_concurrency=50)\n", "async def workflow(city_name, num_records_per_city):\n", " print(f\"Processing city: {city_name}!\")\n", @@ -506,12 +515,11 @@ "\n", " first_response = await first_llm.run(city_name=city_name, num_records=num_records_per_city)\n", "\n", - "\n", " second_llm = StructuredLLM(\n", " model_name=\"openai/gpt-4o-mini\",\n", - " prompt=\"\"\"You will be given a question and answer pair, \n", - " please rate each individually about accuracy, funny and conciseness. \n", - " rating are from 1 to 10, 1 being the worst and 10 being the best. \n", + " prompt=\"\"\"You will be given a question and answer pair,\n", + " please rate each individually about accuracy, funny and conciseness.\n", + " rating are from 1 to 10, 1 being the worst and 10 being the best.\n", " Here is question and answer pair: {{QnA_pairs}}\"\"\",\n", " output_schema=[{\"name\": \"accuracy\", \"type\": \"int\"}, {\"name\": \"funny\", \"type\": \"int\"}, {\"name\": \"conciseness\", \"type\": \"int\"}],\n", " model_kwargs={\"temperature\": 1},\n", @@ -525,13 +533,10 @@ " return final_output\n", "\n", "\n", - "final_output = workflow.run(data = [\n", - " {'city_name': 'New York'},\n", - " {'city_name': 'Los Angeles'},\n", - " {'city_name': 'Chicago'},\n", - " {'city_name': 'Houston'},\n", - " {'city_name': 'Miami'}\n", - " ], num_records_per_city=5)" + "final_output = workflow.run(\n", + " data=[{\"city_name\": \"New York\"}, {\"city_name\": \"Los Angeles\"}, {\"city_name\": \"Chicago\"}, {\"city_name\": \"Houston\"}, {\"city_name\": \"Miami\"}],\n", + " num_records_per_city=5,\n", + ")" ] }, { @@ -594,13 +599,10 @@ "\n", "\n", "# Execute with batch processing\n", - "results = get_city_info_wf.run(\n", - " cities=[\"Paris\", \"Tokyo\", \"New York\", \"London\"],\n", - " num_facts=3\n", - ")\n", + "results = get_city_info_wf.run(cities=[\"Paris\", \"Tokyo\", \"New York\", \"London\"], num_facts=3)\n", "\n", "results = get_city_info_wf.run(\n", - " city_name= [\"Berlin\", \"Rome\"],\n", + " city_name=[\"Berlin\", \"Rome\"],\n", " region_code=[\"DE\", \"IT\"],\n", ")\n", "\n", @@ -610,7 +612,7 @@ " region_code=[\"DE\", \"IT\"],\n", " city_name=\"Beijing\", ### Overwrite the data key\n", " # num_records_per_city = 3\n", - ")\n" + ")" ] }, { diff --git a/examples/test.py b/examples/test.py index cb164e4..5cc31a4 100644 --- a/examples/test.py +++ b/examples/test.py @@ -1,14 +1,16 @@ -from starfish import StructuredLLM, data_factory -from starfish.common.env_loader import load_env_file -from starfish.llm.utils import merge_structured_outputs from datetime import datetime +from starfish import data_factory +from starfish.common.env_loader import load_env_file + load_env_file() +import asyncio + ### Mock LLM call import random -import asyncio + async def mock_llm_call(city_name, num_records_per_city, fail_rate=0.05, sleep_time=0.01): # Simulate a slight delay (optional, feels more async-realistic) @@ -16,29 +18,27 @@ async def mock_llm_call(city_name, num_records_per_city, fail_rate=0.05, sleep_t await asyncio.sleep(sleep_time) - print(f"{city_name}: Successfully processed!") ## For debugging + print(f"{city_name}: Successfully processed!") ## For debugging # 5% chance of failure if random.random() < fail_rate: - print(f" {city_name}: Failed!") ## For debugging + print(f" {city_name}: Failed!") ## For debugging raise ValueError(f"Mock LLM failed to process city: {city_name}") - + result = [f"{city_name}_{random.randint(1, 5)}" for _ in range(num_records_per_city)] return result @data_factory(max_concurrency=5) -async def test1(city_name, num_records_per_city, fail_rate = 0.05, sleep_time = 1): - return await mock_llm_call(city_name, num_records_per_city, fail_rate = fail_rate, sleep_time = sleep_time) +async def test1(city_name, num_records_per_city, fail_rate=0.05, sleep_time=1): + return await mock_llm_call(city_name, num_records_per_city, fail_rate=fail_rate, sleep_time=sleep_time) + print(f"{datetime.now().strftime('%H:%M:%S.%f')[:-3]} - Starting test with max_concurrency=5") -data = test1.run(data = [ - {'city_name': '1. New York'}, - {'city_name': '2. Los Angeles'}, - {'city_name': '3. Chicago'}, - {'city_name': '4. Houston'}, - {'city_name': '5. Miami'} -], num_records_per_city=5) +data = test1.run( + data=[{"city_name": "1. New York"}, {"city_name": "2. Los Angeles"}, {"city_name": "3. Chicago"}, {"city_name": "4. Houston"}, {"city_name": "5. Miami"}], + num_records_per_city=5, +) print(data) diff --git a/examples/test_langgraph.py b/examples/test_langgraph.py index 521f725..8af529d 100644 --- a/examples/test_langgraph.py +++ b/examples/test_langgraph.py @@ -1,15 +1,15 @@ -import asyncio -from typing import Dict, List, TypedDict, Annotated, Sequence from datetime import datetime +from typing import Annotated, Sequence, TypedDict -from langgraph.graph import Graph, StateGraph -from langchain_core.messages import HumanMessage, AIMessage +from langchain_core.messages import AIMessage, HumanMessage from langchain_core.tools import tool +from langgraph.graph import StateGraph -from starfish import StructuredLLM, data_factory +from starfish import data_factory from starfish.data_factory.constants import RECORD_STATUS from starfish.data_factory.utils.enums import RecordStatus + # Define a simple tool @tool def get_weather(location: str) -> str: @@ -17,40 +17,37 @@ def get_weather(location: str) -> str: # Mock implementation return f"The weather in {location} is sunny with a temperature of 72°F." + # Define a simple tool @tool def get_population(city: str) -> str: """Get the population of a city.""" # Mock implementation - populations = { - "New York": "8.8 million", - "Los Angeles": "4 million", - "Chicago": "2.7 million", - "Houston": "2.3 million", - "Miami": "450,000" - } + populations = {"New York": "8.8 million", "Los Angeles": "4 million", "Chicago": "2.7 million", "Houston": "2.3 million", "Miami": "450,000"} return f"The population of {city} is {populations.get(city, 'unknown')}." + # Define the state type class AgentState(TypedDict): messages: Annotated[Sequence[HumanMessage | AIMessage], "The messages in the conversation"] next: Annotated[str, "The next node to run"] + # Create a simple agent that uses tools def create_agent(): # Define the agent function def agent(state: AgentState) -> AgentState: messages = state["messages"] - + # Get the last message last_message = messages[-1] - + # If it's a human message, process it if isinstance(last_message, HumanMessage): # Extract the query type and city name content = last_message.content.lower() city_name = None - + if "weather" in content: # Extract the city name city_name = content.split("weather in ")[-1].rstrip("?").strip() @@ -60,7 +57,7 @@ def agent(state: AgentState) -> AgentState: messages.append(AIMessage(content=weather)) # Set the next node to output return {"messages": messages, "next": "output"} - + elif "population" in content: # Extract the city name city_name = content.split("population in ")[-1].rstrip("?").strip() @@ -70,7 +67,7 @@ def agent(state: AgentState) -> AgentState: messages.append(AIMessage(content=population)) # Set the next node to output return {"messages": messages, "next": "output"} - + # If no specific request, use the LLM to generate a response else: # Create a simple LLM response @@ -79,74 +76,76 @@ def agent(state: AgentState) -> AgentState: messages.append(AIMessage(content=response)) # Set the next node to output return {"messages": messages, "next": "output"} - + # If it's not a human message, just end return {"messages": messages, "next": "output"} - + # Define the output function def output(state: AgentState) -> AgentState: # Just return the state as is return state - + # Create the graph workflow = StateGraph(AgentState) - + # Add the nodes workflow.add_node("agent", agent) workflow.add_node("output", output) - + # Set the entry point workflow.set_entry_point("agent") - + # Add edges workflow.add_edge("agent", "output") - + # Set the output node workflow.set_finish_point("output") - + # Compile the graph app = workflow.compile() - + return app + # Create a function that uses the LangGraph app async def process_city_query(city_name: str, query_type: str = "weather"): # Create the agent agent = create_agent() - + # Create the initial state - initial_state = { - "messages": [HumanMessage(content=f"What's the {query_type} in {city_name}?")], - "next": "agent" - } - + initial_state = {"messages": [HumanMessage(content=f"What's the {query_type} in {city_name}?")], "next": "agent"} + # Run the agent result = agent.invoke(initial_state) - + # Return the last message content return {RECORD_STATUS: RecordStatus.COMPLETED, "output_ref": result["messages"][-1].content} + # Wrap the function with the data factory decorator @data_factory(max_concurrency=5) async def process_cities(city_name, query_type="weather"): print(f"{datetime.now().strftime('%H:%M:%S.%f')[:-3]} - Processing {city_name} for {query_type}") return await process_city_query(city_name, query_type) + # Test the function if __name__ == "__main__": print(f"{datetime.now().strftime('%H:%M:%S.%f')[:-3]} - Starting test with max_concurrency=5") - + # Run the function with multiple cities - results = process_cities.run(data=[ - {'city_name': 'New York', 'query_type': 'weather'}, - {'city_name': 'Los Angeles', 'query_type': 'weather'}, - {'city_name': 'Chicago', 'query_type': 'population'}, - {'city_name': 'Houston', 'query_type': 'population'}, - {'city_name': 'Miami', 'query_type': 'weather'} - ]) - + results = process_cities.run( + data=[ + {"city_name": "New York", "query_type": "weather"}, + {"city_name": "Los Angeles", "query_type": "weather"}, + {"city_name": "Chicago", "query_type": "population"}, + {"city_name": "Houston", "query_type": "population"}, + {"city_name": "Miami", "query_type": "weather"}, + ] + ) + # Print the results for i, result in enumerate(results): print(f"Result {i+1}: {result}") - - print(f"{datetime.now().strftime('%H:%M:%S.%f')[:-3]} - Finished test") \ No newline at end of file + + print(f"{datetime.now().strftime('%H:%M:%S.%f')[:-3]} - Finished test") diff --git a/examples/test_langgraph_structured_llm.py b/examples/test_langgraph_structured_llm.py index 83c269f..6d8bbc3 100644 --- a/examples/test_langgraph_structured_llm.py +++ b/examples/test_langgraph_structured_llm.py @@ -1,15 +1,15 @@ -import asyncio -from typing import Dict, List, TypedDict, Annotated, Sequence from datetime import datetime +from typing import Annotated, Dict, Sequence, TypedDict -from langgraph.graph import Graph, StateGraph -from langchain_core.messages import HumanMessage, AIMessage +from langchain_core.messages import AIMessage, HumanMessage from langchain_core.tools import tool +from langgraph.graph import StateGraph from starfish import StructuredLLM, data_factory from starfish.data_factory.constants import RECORD_STATUS from starfish.data_factory.utils.enums import RecordStatus + # Define a simple tool @tool def get_weather(location: str) -> str: @@ -17,54 +17,48 @@ def get_weather(location: str) -> str: # Mock implementation return f"The weather in {location} is sunny with a temperature of 72°F." + # Define a simple tool @tool def get_population(city: str) -> str: """Get the population of a city.""" # Mock implementation - populations = { - "New York": "8.8 million", - "Los Angeles": "4 million", - "Chicago": "2.7 million", - "Houston": "2.3 million", - "Miami": "450,000" - } + populations = {"New York": "8.8 million", "Los Angeles": "4 million", "Chicago": "2.7 million", "Houston": "2.3 million", "Miami": "450,000"} return f"The population of {city} is {populations.get(city, 'unknown')}." + # Define the state type class AgentState(TypedDict): messages: Annotated[Sequence[HumanMessage | AIMessage], "The messages in the conversation"] next: Annotated[str, "The next node to run"] city_info: Annotated[Dict, "Information about the city"] + # Create a StructuredLLM instance for city information city_info_llm = StructuredLLM( model_name="openai/gpt-4o-mini", prompt="Generate interesting facts about {{city_name}}. Include historical information, famous landmarks, and cultural significance.", - output_schema=[ - {"name": "historical_info", "type": "str"}, - {"name": "landmarks", "type": "str"}, - {"name": "cultural_significance", "type": "str"} - ], - model_kwargs={"temperature": 0.7} + output_schema=[{"name": "historical_info", "type": "str"}, {"name": "landmarks", "type": "str"}, {"name": "cultural_significance", "type": "str"}], + model_kwargs={"temperature": 0.7}, ) + # Create a simple agent that uses tools and StructuredLLM def create_agent(): # Define the agent function async def agent(state: AgentState) -> AgentState: messages = state["messages"] city_info = state.get("city_info", {}) - + # Get the last message last_message = messages[-1] - + # If it's a human message, process it if isinstance(last_message, HumanMessage): # Extract the query type and city name content = last_message.content.lower() city_name = None - + if "weather" in content: # Extract the city name city_name = content.split("weather in ")[-1].rstrip("?").strip() @@ -74,7 +68,7 @@ async def agent(state: AgentState) -> AgentState: messages.append(AIMessage(content=weather)) # Set the next node to output return {"messages": messages, "next": "output", "city_info": city_info} - + elif "population" in content: # Extract the city name city_name = content.split("population in ")[-1].rstrip("?").strip() @@ -84,7 +78,7 @@ async def agent(state: AgentState) -> AgentState: messages.append(AIMessage(content=population)) # Set the next node to output return {"messages": messages, "next": "output", "city_info": city_info} - + elif "facts" in content: # Extract the city name city_name = content.split("facts in ")[-1].rstrip("?").strip() @@ -95,25 +89,25 @@ async def agent(state: AgentState) -> AgentState: try: # Use the StructuredLLM to get city info city_info_response = await city_info_llm.run(city_name=city_name) - + # Handle the response based on its type response_data = city_info_response.data[0] - + city_info = { "city_name": city_name, "historical_info": response_data.get("historical_info", "No historical information available."), "landmarks": response_data.get("landmarks", "No landmarks information available."), - "cultural_significance": response_data.get("cultural_significance", "No cultural significance information available.") + "cultural_significance": response_data.get("cultural_significance", "No cultural significance information available."), } - except Exception as e: + except Exception: # If there's an error, provide default information city_info = { "city_name": city_name, "historical_info": "Error retrieving historical information.", "landmarks": "Error retrieving landmarks information.", - "cultural_significance": "Error retrieving cultural significance information." + "cultural_significance": "Error retrieving cultural significance information.", } - + # Create a response with the city info response = f"Here are some interesting facts about {city_name}:\n\n" response += f"Historical Information: {city_info.get('historical_info', '')}\n\n" @@ -123,7 +117,7 @@ async def agent(state: AgentState) -> AgentState: messages.append(AIMessage(content=response)) # Set the next node to output return {"messages": messages, "next": "output", "city_info": city_info} - + # If no specific request, use the LLM to generate a response else: # Create a simple LLM response @@ -132,83 +126,84 @@ async def agent(state: AgentState) -> AgentState: messages.append(AIMessage(content=response)) # Set the next node to output return {"messages": messages, "next": "output", "city_info": city_info} - + # If it's not a human message, just end return {"messages": messages, "next": "output", "city_info": city_info} - + # Define the output function def output(state: AgentState) -> AgentState: # Just return the state as is return state - + # Create the graph workflow = StateGraph(AgentState) - + # Add the nodes workflow.add_node("agent", agent) workflow.add_node("output", output) - + # Set the entry point workflow.set_entry_point("agent") - + # Add edges workflow.add_edge("agent", "output") - + # Set the output node workflow.set_finish_point("output") - + # Compile the graph app = workflow.compile() - + return app + # Create a function that uses the LangGraph app async def process_city_query(city_name: str, query_type: str = "weather"): # Create the agent agent = create_agent() - + # Create the initial state - initial_state = { - "messages": [HumanMessage(content=f"What's the {query_type} in {city_name}?")], - "next": "agent", - "city_info": {} - } - + initial_state = {"messages": [HumanMessage(content=f"What's the {query_type} in {city_name}?")], "next": "agent", "city_info": {}} + try: # Run the agent result = await agent.ainvoke(initial_state) - + # Get the last message content if isinstance(result, dict) and "messages" in result and result["messages"]: last_message = result["messages"][-1] if hasattr(last_message, "content"): return {RECORD_STATUS: RecordStatus.COMPLETED, "output_ref": last_message.content} - + return {RECORD_STATUS: RecordStatus.FAILED, "error": "Invalid response format from agent"} except Exception as e: return {RECORD_STATUS: RecordStatus.FAILED, "error": str(e)} + # Wrap the function with the data factory decorator @data_factory(max_concurrency=5) async def process_cities(city_name, query_type="weather"): print(f"{datetime.now().strftime('%H:%M:%S.%f')[:-3]} - Processing {city_name} for {query_type}") return await process_city_query(city_name, query_type) + # Test the function if __name__ == "__main__": print(f"{datetime.now().strftime('%H:%M:%S.%f')[:-3]} - Starting test with max_concurrency=5") - + # Run the function with multiple cities - results = process_cities.run(data=[ - {'city_name': 'New York', 'query_type': 'weather'}, - {'city_name': 'Los Angeles', 'query_type': 'facts'}, - {'city_name': 'Chicago', 'query_type': 'population'}, - {'city_name': 'Houston', 'query_type': 'facts'}, - {'city_name': 'Miami', 'query_type': 'weather'} - ]) - + results = process_cities.run( + data=[ + {"city_name": "New York", "query_type": "weather"}, + {"city_name": "Los Angeles", "query_type": "facts"}, + {"city_name": "Chicago", "query_type": "population"}, + {"city_name": "Houston", "query_type": "facts"}, + {"city_name": "Miami", "query_type": "weather"}, + ] + ) + # Print the results for i, result in enumerate(results): print(f"Result {i+1}: {result}") - - print(f"{datetime.now().strftime('%H:%M:%S.%f')[:-3]} - Finished test") \ No newline at end of file + + print(f"{datetime.now().strftime('%H:%M:%S.%f')[:-3]} - Finished test") diff --git a/examples/trial_llm.py b/examples/trial_llm.py index b916908..c4c8ce4 100644 --- a/examples/trial_llm.py +++ b/examples/trial_llm.py @@ -1,42 +1,53 @@ import random -import asyncio -from typing import Any, Dict -from starfish.llm.structured_llm import StructuredLLM +from typing import Any + +from starfish.common.logger import get_logger +from starfish.data_factory.constants import ( + STATUS_COMPLETED, + STATUS_DUPLICATE, + STATUS_FAILED, + STORAGE_TYPE_LOCAL, +) from starfish.data_factory.factory import data_factory -from starfish.data_factory.constants import STATUS_COMPLETED, STATUS_DUPLICATE, STATUS_FILTERED, STATUS_FAILED, STORAGE_TYPE_IN_MEMORY, STORAGE_TYPE_LOCAL from starfish.data_factory.state import MutableSharedState -from starfish.common.logger import get_logger from starfish.data_factory.utils.mock import mock_llm_call + logger = get_logger(__name__) + + # Add callback for error handling # todo state is a class with thread safe dict def handle_error(data: Any, state: MutableSharedState): logger.error(f"Error occurred: {data}") return STATUS_FAILED + def handle_record_complete(data: Any, state: MutableSharedState): - #print(f"Record complete: {data}") + # print(f"Record complete: {data}") - state.set("completed_count", 1) - state_data = state.data + state.set("completed_count", 1) state.update({"completed_count": 2}) return STATUS_COMPLETED + def handle_duplicate_record(data: Any, state: MutableSharedState): logger.debug(f"Record duplicated: {data}") - state.set("completed_count", 1) - state_data = state.data + state.set("completed_count", 1) state.update({"completed_count": 2}) - #return STATUS_DUPLICATE + # return STATUS_DUPLICATE if random.random() < 0.9: return STATUS_COMPLETED return STATUS_DUPLICATE - @data_factory( - storage=STORAGE_TYPE_LOCAL, max_concurrency=50, initial_state_values={}, on_record_complete=[handle_record_complete, handle_duplicate_record], - on_record_error=[handle_error],show_progress=True, task_runner_timeout=10 + storage=STORAGE_TYPE_LOCAL, + max_concurrency=50, + initial_state_values={}, + on_record_complete=[handle_record_complete, handle_duplicate_record], + on_record_error=[handle_error], + show_progress=True, + task_runner_timeout=10, ) async def get_city_info_wf(city_name, region_code): # structured_llm = StructuredLLM( @@ -62,7 +73,7 @@ async def get_city_info_wf(city_name, region_code): # ) # output = await validation_llm.run(data=output.data) - #return output.data + # return output.data return await mock_llm_call(city_name, num_records_per_city=3, fail_rate=0.01, sleep_time=1) @@ -76,24 +87,23 @@ async def get_city_info_wf(city_name, region_code): user_case = "run" if user_case == "run": results = get_city_info_wf.run( - #data=[{"city_name": "Berlin"}, {"city_name": "Rome"}], - #[{"city_name": "Berlin"}, {"city_name": "Rome"}], - city_name=["San Francisco", "New York", "Los Angeles"]*50, - region_code=["DE", "IT", "US"]*50, + # data=[{"city_name": "Berlin"}, {"city_name": "Rome"}], + # [{"city_name": "Berlin"}, {"city_name": "Rome"}], + city_name=["San Francisco", "New York", "Los Angeles"] * 50, + region_code=["DE", "IT", "US"] * 50, # city_name="Beijing", ### Overwrite the data key # num_records_per_city = 3 ) -elif user_case == "dry_run": +elif user_case == "dry_run": results = get_city_info_wf.dry_run( - #data=[{"city_name": "Berlin"}, {"city_name": "Rome"}], - #[{"city_name": "Berlin"}, {"city_name": "Rome"}], - city_name=["San Francisco", "New York", "Los Angeles"]*10, - region_code=["DE", "IT", "US"]*10, - # city_name="Beijing", ### Overwrite the data key - # num_records_per_city = 3 - ) + # data=[{"city_name": "Berlin"}, {"city_name": "Rome"}], + # [{"city_name": "Berlin"}, {"city_name": "Rome"}], + city_name=["San Francisco", "New York", "Los Angeles"] * 10, + region_code=["DE", "IT", "US"] * 10, + # city_name="Beijing", ### Overwrite the data key + # num_records_per_city = 3 + ) elif user_case == "re_run": - results = get_city_info_wf.re_run( master_job_id="05668e16-6f47-4ccf-9f25-4ff7b7030bdb") + results = get_city_info_wf.re_run(master_job_id="05668e16-6f47-4ccf-9f25-4ff7b7030bdb") -#logger.info(f"Results: {results}") - +# logger.info(f"Results: {results}") diff --git a/pyproject.toml b/pyproject.toml index cf0c11a..2690955 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,10 @@ starfish = "starfish.api.cli:main" [tool.ruff] line-length = 160 +# Auto-fix settings +fix = true +unsafe-fixes = true + [tool.ruff.lint] select = [ "E", # pycodestyle errors diff --git a/src/starfish/__init__.py b/src/starfish/__init__.py index b5c3891..ec6d4f4 100644 --- a/src/starfish/__init__.py +++ b/src/starfish/__init__.py @@ -1,7 +1,6 @@ # Expose core directly from easy access -from .llm.structured_llm import StructuredLLM from .data_factory.factory import data_factory - +from .llm.structured_llm import StructuredLLM # Define what 'from starfish import *' imports (good practice) __all__ = [ @@ -12,11 +11,12 @@ # You might also include the package version here # This is often automatically managed by build tools like setuptools_scm try: - from importlib.metadata import version, PackageNotFoundError + from importlib.metadata import PackageNotFoundError, version + try: - __version__ = version("starfish-core") # Updated to match our package name + __version__ = version("starfish-core") # Updated to match our package name except PackageNotFoundError: - # package is not installed + # package is not installed __version__ = "unknown" except ImportError: # Fallback for older Python versions diff --git a/src/starfish/common/exceptions.py b/src/starfish/common/exceptions.py index fe32097..c5ed058 100644 --- a/src/starfish/common/exceptions.py +++ b/src/starfish/common/exceptions.py @@ -18,8 +18,10 @@ # HTTP Status Codes ############################################# + class HTTPStatus: - """Standard HTTP status codes""" + """Standard HTTP status codes.""" + OK = 200 BAD_REQUEST = 400 UNAUTHORIZED = 401 @@ -28,13 +30,14 @@ class HTTPStatus: UNPROCESSABLE_ENTITY = 422 INTERNAL_SERVER_ERROR = 500 + ############################################# # Error Response Model ############################################# class ErrorResponse(BaseModel): - """Standardized error response format for API errors""" + """Standardized error response format for API errors.""" status: str = "error" error_id: str = Field(..., description="Unique identifier for this error occurrence") @@ -49,7 +52,7 @@ class ErrorResponse(BaseModel): class StarfishException(Exception): - """Base exception for all Starfish exceptions""" + """Base exception for all Starfish exceptions.""" status_code: int = HTTPStatus.INTERNAL_SERVER_ERROR default_message: str = "An unexpected error occurred" @@ -67,7 +70,7 @@ def __str__(self): class ValidationError(StarfishException): - """Exception raised for validation errors""" + """Exception raised for validation errors.""" status_code = HTTPStatus.UNPROCESSABLE_ENTITY default_message = "Validation error" diff --git a/src/starfish/components/prepare_topic.py b/src/starfish/components/prepare_topic.py index bed344a..354057c 100644 --- a/src/starfish/components/prepare_topic.py +++ b/src/starfish/components/prepare_topic.py @@ -1,57 +1,58 @@ -from typing import List, Dict, Union, Optional, Any, Tuple -from starfish import StructuredLLM import asyncio import math +from typing import Any, Dict, List, Optional, Union + +from starfish import StructuredLLM + async def generate_topics( user_instructions: str, num_topics: int, model_name: str = "openai/gpt-4o-mini", - model_kwargs: Optional[Dict[str, Any]] = {}, - existing_topics: Optional[List[str]] = None + model_kwargs: Optional[Dict[str, Any]] = None, + existing_topics: Optional[List[str]] = None, ) -> List[str]: """Generate unique topics based on user instructions using a StructuredLLM model.""" - if 'temperature' not in model_kwargs: - model_kwargs['temperature'] = 1 + if model_kwargs is None: + model_kwargs = {} + if "temperature" not in model_kwargs: + model_kwargs["temperature"] = 1 existing_topics = existing_topics or [] - + if num_topics <= 0: return [] - + # Calculate batches needed (5 topics per batch) llm_batch_size = 5 num_batches = math.ceil(num_topics / llm_batch_size) generated_topics = [] - + for _ in range(num_batches): topic_generator = StructuredLLM( model_name=model_name, prompt="""Can you generate a list of topics about {{user_instructions}} {% if existing_topics_str %} - Please do not generate topics that are already in the list: {{existing_topics_str}} + Please do not generate topics that are already in the list: {{existing_topics_str}} Make sure the topics are unique and vary from each other {% endif %} """, - output_schema=[{'name': 'topic', 'type': 'str'}], - model_kwargs=model_kwargs + output_schema=[{"name": "topic", "type": "str"}], + model_kwargs=model_kwargs, ) - + all_existing = existing_topics + generated_topics - input_params = { - 'user_instructions': user_instructions, - 'num_records': min(llm_batch_size, num_topics - len(generated_topics)) - } - + input_params = {"user_instructions": user_instructions, "num_records": min(llm_batch_size, num_topics - len(generated_topics))} + if all_existing: - input_params['existing_topics_str'] = ",".join(all_existing) + input_params["existing_topics_str"] = ",".join(all_existing) topic_response = await topic_generator.run(**input_params) - topic_data = [item.get('topic') for item in topic_response.data] + topic_data = [item.get("topic") for item in topic_response.data] generated_topics.extend(topic_data) - + if len(generated_topics) >= num_topics: break - + return generated_topics @@ -61,17 +62,16 @@ async def prepare_topic( records_per_topic: int = 20, user_instructions: Optional[str] = None, model_name: str = "openai/gpt-4o-mini", - model_kwargs: Optional[Dict[str, Any]] = {} + model_kwargs: Optional[Dict[str, Any]] = None, ) -> List[Dict[str, str]]: - """ - Split records into topics, generating topics if none are provided or if needed. - + """Split records into topics, generating topics if none are provided or if needed. + Supported input formats: 1. String list: ['topic1', 'topic2'] - Topics with equal or calculated distribution 2. Dict list: [{'topic1': 20}, {'topic2': 30}] - Topics with specific counts 3. Mixed: ['topic1', {'topic2': 30}] - Combination of both formats 4. None: No topics provided, will generate based on user_instructions - + Args: topics: Optional list of topics, either strings or {topic: count} dicts num_records: Total number of records to split (required for dict topics or None topics) @@ -79,12 +79,14 @@ async def prepare_topic( user_instructions: Topic generation instructions (required if topics is None) model_name: Model name for topic generation model_kwargs: Model kwargs for topic generation - + Returns: List of {'topic': topic_name} dictionaries, with one entry per record """ - if 'temperature' not in model_kwargs: - model_kwargs['temperature'] = 1 + if model_kwargs is None: + model_kwargs = {} + if "temperature" not in model_kwargs: + model_kwargs["temperature"] = 1 # --- STEP 1: Input validation and normalization --- if topics is None: # Must have num_records and user_instructions if no topics provided @@ -97,13 +99,13 @@ async def prepare_topic( # Validate topics is a non-empty list if not isinstance(topics, list) or not topics: raise ValueError("topics must be a non-empty list") - + # Convert all topic inputs to a standardized [(topic_name, count)] list # For string topics: count will be None (to be calculated later) # For dict topics: use the specified count topic_assignments = [] seen_topics = set() - + for topic in topics: if isinstance(topic, str): if topic not in seen_topics: @@ -112,99 +114,95 @@ async def prepare_topic( elif isinstance(topic, dict) and len(topic) == 1: topic_name = next(iter(topic)) count = topic[topic_name] - + if not isinstance(count, int) or count < 0: raise ValueError(f"Topic '{topic_name}' has invalid count {count}") - + if topic_name not in seen_topics: topic_assignments.append((topic_name, count)) seen_topics.add(topic_name) else: raise ValueError("Topics must be strings or single-key dictionaries") - + # --- STEP 2: Calculate or validate counts for provided topics --- result = [] assigned_count = 0 topic_names = [] # Track all assigned topic names - + if topic_assignments: # Handle string topics with no count (None) - assign counts based on input string_topics = [(name, count) for name, count in topic_assignments if count is None] dict_topics = [(name, count) for name, count in topic_assignments if count is not None] - + # Case: String topics with no num_records - assign records_per_topic to each if string_topics and num_records is None: for name, _ in string_topics: result.append({name: records_per_topic}) topic_names.append(name) assigned_count += records_per_topic - + # Case: String topics with num_records - distribute evenly elif string_topics and num_records is not None: remaining = num_records - sum(count for _, count in dict_topics if count is not None) if remaining < 0: raise ValueError("Dict topic counts exceed num_records") - + # Distribute remaining records among string topics if string_topics and remaining > 0: base = remaining // len(string_topics) extra = remaining % len(string_topics) - + for i, (name, _) in enumerate(string_topics): count = base + (1 if i < extra else 0) if count > 0: result.append({name: count}) topic_names.append(name) assigned_count += count - + # Add dictionary topics with predefined counts for name, count in dict_topics: if count > 0: result.append({name: count}) topic_names.append(name) assigned_count += count - + # Validate total count for dictionary topics if dict_topics and num_records is None: raise ValueError("num_records required when using dictionary topics") - + if num_records is not None and assigned_count > num_records: raise ValueError(f"Total assigned count ({assigned_count}) exceeds num_records ({num_records})") - + # --- STEP 3: Generate topics for remaining records if needed --- remaining_records = 0 if num_records is None else num_records - assigned_count - + if remaining_records > 0: if records_per_topic <= 0: raise ValueError("records_per_topic must be positive when generating topics") - + # Generate topics with LLM if instructions provided if user_instructions: topics_needed = math.ceil(remaining_records / records_per_topic) - + generated = await generate_topics( - user_instructions=user_instructions, - num_topics=topics_needed, - model_name=model_name, - model_kwargs=model_kwargs, - existing_topics=topic_names + user_instructions=user_instructions, num_topics=topics_needed, model_name=model_name, model_kwargs=model_kwargs, existing_topics=topic_names ) - + # Assign counts to generated topics for topic in generated: if topic in topic_names: # Skip if duplicate (shouldn't happen with proper LLM) print(f"Skipping duplicate generated topic: {topic}") continue - + count = min(records_per_topic, remaining_records) if count <= 0: break - + result.append({topic: count}) topic_names.append(topic) remaining_records -= count assigned_count += count - + # Generate auto-topics for any still-remaining records auto_index = 1 while remaining_records > 0: @@ -213,22 +211,22 @@ async def prepare_topic( while auto_name in topic_names: auto_index += 1 auto_name = f"auto_topic{auto_index}" - + count = min(records_per_topic, remaining_records) result.append({auto_name: count}) topic_names.append(auto_name) remaining_records -= count assigned_count += count auto_index += 1 - + # Final validation if num_records is not None and assigned_count != num_records: print(f"Warning: Assigned {assigned_count} records, expected {num_records}") - + flatten_topic_list = [] for item in result: for key, count in item.items(): - flatten_topic_list.extend([{'topic': key}] * count) + flatten_topic_list.extend([{"topic": key}] * count) return flatten_topic_list @@ -238,55 +236,40 @@ async def prepare_topic( # Example 1: Dictionary topics with additional generation print("\nExample 1: Dictionary topics + generation") - topics1 = [{'topic1': 20}, {'topic2': 30}] - result1 = asyncio.run(prepare_topic( - topics=topics1, - num_records=100, - records_per_topic=25, - user_instructions="some context" - )) + topics1 = [{"topic1": 20}, {"topic2": 30}] + result1 = asyncio.run(prepare_topic(topics=topics1, num_records=100, records_per_topic=25, user_instructions="some context")) print(f"Result: {result1}") print(f"Total: {len(result1)}") # Example 2: String topics with even distribution print("\nExample 2: String topics with distribution") - topics2 = ['topicA', 'topicB', 'topicC'] - result2 = asyncio.run(prepare_topic( - topics=topics2, - num_records=10 - )) + topics2 = ["topicA", "topicB", "topicC"] + result2 = asyncio.run(prepare_topic(topics=topics2, num_records=10)) print(f"Result: {result2}") print(f"Total: {len(result2)}") # Example 3: Mixed string and dict topics print("\nExample 3: Mixed string/dict topics") - topics3 = ['topicX', {'topicY': 10}] - result3 = asyncio.run(prepare_topic( - topics=topics3, - num_records=30, - user_instructions="mixed topics" - )) + topics3 = ["topicX", {"topicY": 10}] + result3 = asyncio.run(prepare_topic(topics=topics3, num_records=30, user_instructions="mixed topics")) print(f"Result: {result3}") print(f"Total: {len(result3)}") # Example 4: String topics with fixed count print("\nExample 4: String topics with fixed count") - topics4 = ['apple', 'banana', 'cherry'] + topics4 = ["apple", "banana", "cherry"] result4 = asyncio.run(prepare_topic(topics=topics4, records_per_topic=15)) print(f"Result: {result4}") print(f"Total: {len(result4)}") # Example 5: No topics, generate all print("\nExample 5: No topics, generate all") + async def run_example5(): - result = await prepare_topic( - topics=None, - num_records=10, - records_per_topic=5, - user_instructions="cloud computing" - ) + result = await prepare_topic(topics=None, num_records=10, records_per_topic=5, user_instructions="cloud computing") print(f"Result: {result}") print(f"Total: {len(result)}") + asyncio.run(run_example5()) - print("\n--- Examples Finished ---") \ No newline at end of file + print("\n--- Examples Finished ---") diff --git a/src/starfish/data_factory/config.py b/src/starfish/data_factory/config.py index 8fe1148..962157b 100644 --- a/src/starfish/data_factory/config.py +++ b/src/starfish/data_factory/config.py @@ -1,2 +1,2 @@ PROGRESS_LOG_INTERVAL = 3 -TASK_RUNNER_TIMEOUT = 30 \ No newline at end of file +TASK_RUNNER_TIMEOUT = 30 diff --git a/src/starfish/data_factory/constants.py b/src/starfish/data_factory/constants.py index ed0285d..f52e4fe 100644 --- a/src/starfish/data_factory/constants.py +++ b/src/starfish/data_factory/constants.py @@ -26,28 +26,27 @@ STORAGE_TYPE_IN_MEMORY = "in_memory" - # Define the function directly in constants to avoid circular imports def get_app_data_dir(): - """Returns a platform-specific directory for application data storage. - + r"""Returns a platform-specific directory for application data storage. + Following platform conventions: - Linux: ~/.local/share/starfish - macOS: ~/Library/Application Support/starfish - Windows: %LOCALAPPDATA%\starfish - + Environment variable STARFISH_LOCAL_STORAGE_DIR can override this location. """ # Allow override through environment variable env_dir = os.environ.get("STARFISH_LOCAL_STORAGE_DIR") if env_dir: return env_dir - + app_name = "starfish" - + # Get user's home directory home = Path.home() - + # Platform-specific paths if sys.platform == "win32": # Windows: Use %LOCALAPPDATA% if available, otherwise construct from home @@ -64,9 +63,10 @@ def get_app_data_dir(): if not xdg_data_home: xdg_data_home = os.path.join(home, ".local", "share") base_dir = os.path.join(xdg_data_home, app_name) - + return base_dir + # Get application database directory APP_DATA_DIR = get_app_data_dir() LOCAL_STORAGE_PATH = os.path.join(APP_DATA_DIR, "db") diff --git a/src/starfish/data_factory/event_loop.py b/src/starfish/data_factory/event_loop.py index 904c5c9..b66a031 100644 --- a/src/starfish/data_factory/event_loop.py +++ b/src/starfish/data_factory/event_loop.py @@ -4,7 +4,6 @@ def run_in_event_loop(coroutine): - try: # This call will raise an RuntimError if there is no event loop running. asyncio.get_running_loop() diff --git a/src/starfish/data_factory/factory.py b/src/starfish/data_factory/factory.py index 813529a..e4a061a 100644 --- a/src/starfish/data_factory/factory.py +++ b/src/starfish/data_factory/factory.py @@ -1,25 +1,35 @@ -import datetime -import uuid import asyncio -import json +import datetime import hashlib +import json +import uuid from functools import wraps +from inspect import Parameter, signature from queue import Queue from typing import Any, Callable, Dict, List -from inspect import signature, Parameter + from rich.progress import Progress, TextColumn -from starfish.data_factory.job_manager import JobManager -from starfish.data_factory.constants import RECORD_STATUS, RUN_MODE_DRY_RUN,LOCAL_STORAGE_URI, STATUS_COMPLETED, STATUS_DUPLICATE, STATUS_FILTERED, STATUS_FAILED, RUN_MODE, RUN_MODE_RE_RUN + +from starfish.common.logger import get_logger from starfish.data_factory.config import PROGRESS_LOG_INTERVAL, TASK_RUNNER_TIMEOUT -from starfish.data_factory.storage.local.local_storage import LocalStorage -from starfish.data_factory.storage.in_memory.in_memory_storage import InMemoryStorage -from starfish.data_factory.state import MutableSharedState -from starfish.data_factory.storage.models import ( - GenerationMasterJob, - Project, +from starfish.data_factory.constants import ( + LOCAL_STORAGE_URI, + RECORD_STATUS, + RUN_MODE, + RUN_MODE_DRY_RUN, + RUN_MODE_RE_RUN, + STATUS_COMPLETED, + STATUS_DUPLICATE, + STATUS_FAILED, + STATUS_FILTERED, ) +from starfish.data_factory.job_manager import JobManager +from starfish.data_factory.state import MutableSharedState +from starfish.data_factory.storage.in_memory.in_memory_storage import InMemoryStorage +from starfish.data_factory.storage.local.local_storage import LocalStorage +from starfish.data_factory.storage.models import GenerationMasterJob, Project from starfish.data_factory.utils.decorator import async_wrapper -from starfish.common.logger import get_logger + logger = get_logger(__name__) @@ -45,7 +55,7 @@ def __init__( "max_concurrency": max_concurrency, "target_count": target_count, "state": state, - "on_record_complete": on_record_complete, + "on_record_complete": on_record_complete, "on_record_error": on_record_error, "show_progress": show_progress, "task_runner_timeout": task_runner_timeout, @@ -64,12 +74,11 @@ def __call__(self, func: Callable): def wrapper(*args, **kwargs): run_mode = self.job_config.get(RUN_MODE) try: - # Check for master_job_id in kwargs and assign if present if run_mode == RUN_MODE_RE_RUN: self._setup_storage_and_job_manager() self._set_input_data_from_master_job() - + self._update_job_config() elif run_mode == RUN_MODE_DRY_RUN: # dry run mode @@ -88,16 +97,16 @@ def wrapper(*args, **kwargs): # self.project_id = "8de05a58-c8a4-4c10-8c23-568679c88e65" self.master_job_id = str(uuid.uuid4()) self._setup_storage_and_job_manager() - self._save_project() + self._save_project() self._save_request_config() self._log_master_job_start() # Start progress bar before any operations # Process batches and keep progress bar alive self._update_job_config() self._update_master_job_status() - + self._process_batches() - + result = self._process_output() if len(result) == 0: raise ValueError("No records generated") @@ -115,13 +124,15 @@ def wrapper(*args, **kwargs): else: if run_mode != RUN_MODE_DRY_RUN: self.show_job_progress_status() + # Add run method to the wrapped function def run(*args, **kwargs): - if 'master_job_id' in kwargs: + if "master_job_id" in kwargs: # re_run mode - self.master_job_id = kwargs['master_job_id'] + self.master_job_id = kwargs["master_job_id"] self.job_config[RUN_MODE] = RUN_MODE_RE_RUN return wrapper(*args, **kwargs) + def dry_run(*args, **kwargs): self.job_config[RUN_MODE] = RUN_MODE_DRY_RUN return wrapper(*args, **kwargs) @@ -132,7 +143,7 @@ def dry_run(*args, **kwargs): wrapper.state = self.job_config.get("state") self.func = func return wrapper - + def _process_output(self) -> List[Any]: result = [] output = self.job_manager.job_output.queue @@ -140,16 +151,16 @@ def _process_output(self) -> List[Any]: if v.get(RECORD_STATUS) == STATUS_COMPLETED: result.extend(v.get("output")) return result - + def _check_parameter_match(self): - """Check if the parameters of the function match the parameters of the batches""" + """Check if the parameters of the function match the parameters of the batches.""" # Get the parameters of the function # func_params = inspect.signature(func).parameters # # Get the parameters of the batches # batches_params = inspect.signature(batches).parameters - #from inspect import signature, Parameter + # from inspect import signature, Parameter func_sig = signature(self.func) - + # Validate batch items against function parameters batch_item = self.input_data.queue[0] for param_name, param in func_sig.parameters.items(): @@ -158,26 +169,19 @@ def _check_parameter_match(self): continue # Check if required parameter is missing in batch if param_name not in batch_item: - raise TypeError( - f"Batch item is missing required parameter '{param_name}' " - f"for function {self.func.__name__}" - ) + raise TypeError(f"Batch item is missing required parameter '{param_name}' " f"for function {self.func.__name__}") # Check 2: Ensure all batch parameters exist in function signature for batch_param in batch_item.keys(): if batch_param not in func_sig.parameters: - raise TypeError( - f"Batch items contains unexpected parameter '{batch_param}' " - f"not found in function {self.func.__name__}" - ) - - + raise TypeError(f"Batch items contains unexpected parameter '{batch_param}' " f"not found in function {self.func.__name__}") + def _setup_storage_and_job_manager(self): self.storage_setup() - + self.job_manager = JobManager(job_config=self.job_config, storage=self.factory_storage) def _process_batches(self) -> List[Any]: - """Process batches with asyncio""" + """Process batches with asyncio.""" logger.info( f"[JOB PROGRESS] " f"\033[1mJob Started:\033[0m " @@ -185,7 +189,6 @@ def _process_batches(self) -> List[Any]: f"\033[33mLogging progress every {PROGRESS_LOG_INTERVAL} seconds\033[0m" ) return self.job_manager.run_orchestration() - @async_wrapper() async def _save_project(self): @@ -199,7 +202,6 @@ async def _save_request_config(self): config_data = {"generator": "test_generator", "parameters": {"num_records": 10, "complexity": "medium"}, "input_data": list(self.input_data.queue)} self.config_ref = await self.factory_storage.save_request_config(self.master_job_id, config_data) logger.debug(f" - Saved request config to: {self.config_ref}") - @async_wrapper() async def _set_input_data_from_master_job(self): @@ -207,8 +209,8 @@ async def _set_input_data_from_master_job(self): if master_job: master_job_config_data = await self.factory_storage.get_request_config(master_job.request_config_ref) # Convert list to dict with count tracking using hash values - input_data = master_job_config_data.get("input_data") - + input_data = master_job_config_data.get("input_data") + input_dict = {} for item in input_data: self.input_data.put(item) @@ -220,15 +222,10 @@ async def _set_input_data_from_master_job(self): if input_data_hash in input_dict: input_dict[input_data_hash]["count"] += 1 else: - input_dict[input_data_hash] = { - "data": item, - "data_str": input_data_str, - "count": 1 - } + input_dict[input_data_hash] = {"data": item, "data_str": input_data_str, "count": 1} self.job_config["input_dict"] = input_dict self.job_config[RUN_MODE] = RUN_MODE_RE_RUN - @async_wrapper() async def _log_master_job_start(self): # Now create the master job @@ -244,7 +241,6 @@ async def _log_master_job_start(self): ) await self.factory_storage.log_master_job_start(master_job) logger.debug(f" - Created master job: {master_job.name} ({self.master_job_id})") - @async_wrapper() async def _update_master_job_status(self): @@ -252,8 +248,6 @@ async def _update_master_job_status(self): await self.factory_storage.update_master_job_status(self.master_job_id, "running", now) logger.debug(" - Updated master job status to: running") - - @async_wrapper() async def _complete_master_job(self): # Complete the master job @@ -263,20 +257,20 @@ async def _complete_master_job(self): if self.err: summary = {} else: - summary = {STATUS_COMPLETED: self.job_manager.completed_count, - STATUS_FILTERED: self.job_manager.filtered_count, - STATUS_DUPLICATE: self.job_manager.duplicate_count, - STATUS_FAILED: self.job_manager.failed_count} + summary = { + STATUS_COMPLETED: self.job_manager.completed_count, + STATUS_FILTERED: self.job_manager.filtered_count, + STATUS_DUPLICATE: self.job_manager.duplicate_count, + STATUS_FAILED: self.job_manager.failed_count, + } if self.factory_storage: await self.factory_storage.log_master_job_end(self.master_job_id, status, summary, now, now) logger.info(f"Master Job {self.master_job_id} has been ended") - @async_wrapper() - async def _close_storage(self): + async def _close_storage(self): if self.factory_storage: - await self.factory_storage.close() - + await self.factory_storage.close() def storage_setup(self): if self.storage == "local": @@ -286,8 +280,6 @@ def storage_setup(self): self.factory_storage = InMemoryStorage() asyncio.run(self.factory_storage.setup()) - - def _update_job_config(self): target_acount = self.job_config.get("target_count") new_target_count = self.input_data.qsize() if target_acount == 0 else target_acount @@ -296,23 +288,26 @@ def _update_job_config(self): { "master_job_id": self.master_job_id, "user_func": self.func, - "job_input_queue": self.input_data, + "job_input_queue": self.input_data, "target_count": new_target_count, RUN_MODE: self.job_config.get(RUN_MODE), "input_dict": self.job_config.get("input_dict"), } ) - def _init_progress_bar(self): - self.progress = Progress( - TextColumn("[bold blue]{task.description}"), - # BarColumn(), - # TaskProgressColumn(), - TextColumn("•"), - TextColumn("{task.fields[status]}"), - # TimeRemainingColumn(), - ) if self.job_config.get("show_progress") else None + self.progress = ( + Progress( + TextColumn("[bold blue]{task.description}"), + # BarColumn(), + # TaskProgressColumn(), + TextColumn("•"), + TextColumn("{task.fields[status]}"), + # TimeRemainingColumn(), + ) + if self.job_config.get("show_progress") + else None + ) # Separate task IDs for each counter self.progress_tasks = { STATUS_COMPLETED: None, @@ -321,28 +316,19 @@ def _init_progress_bar(self): STATUS_DUPLICATE: None, } if self.job_config.get("show_progress"): - #self.progress.start() - #with self.progress_lock: + # self.progress.start() + # with self.progress_lock: target_count = self.job_config.get("target_count") - self.progress_tasks[STATUS_COMPLETED] = self.progress.add_task( - "[green]Completed", total=target_count, status="✅ 0" - ) - self.progress_tasks[STATUS_FAILED] = self.progress.add_task( - "[red]Failed", total=target_count, status="❌ 0" - ) - self.progress_tasks[STATUS_FILTERED] = self.progress.add_task( - "[yellow]Filtered", total=target_count, status="🚫 0" - ) - self.progress_tasks[STATUS_DUPLICATE] = self.progress.add_task( - "[cyan]Duplicated", total=target_count, status="🔁 0" - ) + self.progress_tasks[STATUS_COMPLETED] = self.progress.add_task("[green]Completed", total=target_count, status="✅ 0") + self.progress_tasks[STATUS_FAILED] = self.progress.add_task("[red]Failed", total=target_count, status="❌ 0") + self.progress_tasks[STATUS_FILTERED] = self.progress.add_task("[yellow]Filtered", total=target_count, status="🚫 0") + self.progress_tasks[STATUS_DUPLICATE] = self.progress.add_task("[cyan]Duplicated", total=target_count, status="🔁 0") # self.progress_tasks[STATUS_TOTAL] = self.progress.add_task( # "[white]Attempted", total=target_count, status="📊 0" # ) # self.job_config["progress"] = self.progress # self.job_config["progress_tasks"] = self.progress_tasks - - + def show_job_progress_status(self): target_count = self.job_config.get("target_count") logger.info( @@ -356,7 +342,7 @@ def show_job_progress_status(self): ) # if self.job_config.get("show_progress"): # self.progress.start() - + # for counter_type, task_id in self.progress_tasks.items(): # count = getattr(self.job_manager, f"{counter_type}_count") # emoji = STATUS_MOJO_MAP[counter_type] @@ -364,15 +350,17 @@ def show_job_progress_status(self): # if counter_type != STATUS_COMPLETED: # target_count = self.job_manager.total_count # self.progress.update( - # task_id, - # completed=count, + # task_id, + # completed=count, # status=f"{emoji} {count}/{target_count} ({percentage}%)" # ) # self.progress.stop() -def default_input_converter(data : List[Dict[str, Any]]=[], **kwargs) -> Queue[Dict[str, Any]]: +def default_input_converter(data: List[Dict[str, Any]] = None, **kwargs) -> Queue[Dict[str, Any]]: # Determine parallel sources + if data is None: + data = [] parallel_sources = {} if isinstance(data, list) and len(data) > 0: parallel_sources["data"] = data @@ -418,13 +406,30 @@ def data_factory( batch_size: int = 1, target_count: int = 0, max_concurrency: int = 50, - initial_state_values: Dict[str, Any] = {}, - on_record_complete: List[Callable] = [], - on_record_error: List[Callable] = [], + initial_state_values: Dict[str, Any] = None, + on_record_complete: List[Callable] = None, + on_record_error: List[Callable] = None, show_progress: bool = True, input_converter=default_input_converter, task_runner_timeout: int = TASK_RUNNER_TIMEOUT, ): - state = MutableSharedState(initial_state_values) - - return DataFactory(storage, batch_size, max_concurrency, target_count, state, on_record_complete, on_record_error, input_converter=input_converter, show_progress=show_progress, task_runner_timeout=task_runner_timeout) + if on_record_error is None: + on_record_error = [] + if on_record_complete is None: + on_record_complete = [] + if initial_state_values is None: + initial_state_values = {} + state = MutableSharedState(initial_state_values) + + return DataFactory( + storage, + batch_size, + max_concurrency, + target_count, + state, + on_record_complete, + on_record_error, + input_converter=input_converter, + show_progress=show_progress, + task_runner_timeout=task_runner_timeout, + ) diff --git a/src/starfish/data_factory/job_manager.py b/src/starfish/data_factory/job_manager.py index 17ada9c..29f184f 100644 --- a/src/starfish/data_factory/job_manager.py +++ b/src/starfish/data_factory/job_manager.py @@ -2,25 +2,30 @@ import datetime import hashlib import json -from typing import Any, Dict, List import uuid from queue import Queue -from starfish.data_factory.utils.errors import DuplicateRecordError, FilterRecordError, RecordError -from starfish.data_factory.constants import RECORD_STATUS, RUN_MODE, RUN_MODE_RE_RUN, RUN_MODE_DRY_RUN +from typing import Any, Dict, List + +from starfish.common.logger import get_logger from starfish.data_factory.config import PROGRESS_LOG_INTERVAL +from starfish.data_factory.constants import ( + RECORD_STATUS, + RUN_MODE, + RUN_MODE_DRY_RUN, + RUN_MODE_RE_RUN, + STATUS_COMPLETED, + STATUS_DUPLICATE, + STATUS_FAILED, + STATUS_FILTERED, +) from starfish.data_factory.event_loop import run_in_event_loop +from starfish.data_factory.storage.base import Storage +from starfish.data_factory.storage.models import GenerationJob, Record from starfish.data_factory.task_runner import TaskRunner -from starfish.data_factory.constants import STATUS_COMPLETED, STATUS_DUPLICATE, STATUS_FILTERED, STATUS_FAILED -from starfish.data_factory.storage.models import ( - GenerationJob, - GenerationMasterJob, - Project, - Record, -) -from starfish.data_factory.storage.base import Storage -from starfish.common.logger import get_logger + logger = get_logger(__name__) + class JobManager: def __init__(self, job_config: Dict[str, Any], storage: Storage): self.job_config = job_config @@ -43,21 +48,23 @@ def __init__(self, job_config: Dict[str, Any], storage: Storage): async def create_execution_job(self, job_uuid: str, input_data: Dict[str, Any]): logger.debug("\n3. Creating execution job...") input_data_str = json.dumps(input_data) - self.job = GenerationJob(job_id=job_uuid, master_job_id=self.job_config["master_job_id"], - status="running", worker_id="test-worker-1", - run_config=input_data_str, - run_config_hash=hashlib.sha256(input_data_str.encode()).hexdigest()) + self.job = GenerationJob( + job_id=job_uuid, + master_job_id=self.job_config["master_job_id"], + status="running", + worker_id="test-worker-1", + run_config=input_data_str, + run_config_hash=hashlib.sha256(input_data_str.encode()).hexdigest(), + ) await self.storage.log_execution_job_start(self.job) logger.debug(f" - Created execution job: {job_uuid}") - - async def job_save_record_data(self, records, task_status:str, input_data: Dict[str, Any]) -> List[str]: - + async def job_save_record_data(self, records, task_status: str, input_data: Dict[str, Any]) -> List[str]: output_ref_list = [] if self.job_config.get(RUN_MODE) == RUN_MODE_DRY_RUN: return output_ref_list logger.debug("\n5. Saving record data...") - storage_class_name = self.storage.__class__.__name__ + storage_class_name = self.storage.__class__.__name__ if storage_class_name == "LocalStorage": job_uuid = str(uuid.uuid4()) await self.create_execution_job(job_uuid, input_data) @@ -72,35 +79,33 @@ async def job_save_record_data(self, records, task_status:str, input_data: Dict[ record_model = Record(**record) await self.storage.log_record_metadata(record_model) logger.debug(f" - Saved data for record {i}: {output_ref}") - output_ref_list.append(output_ref) + output_ref_list.append(output_ref) await self.complete_execution_job(job_uuid) return output_ref_list - async def complete_execution_job(self,job_uuid: str): + async def complete_execution_job(self, job_uuid: str): logger.debug("\n6. Completing execution job...") now = datetime.datetime.now(datetime.timezone.utc) - counts = {STATUS_COMPLETED: self.completed_count, - STATUS_FILTERED: self.filtered_count, - STATUS_DUPLICATE: self.duplicate_count, - STATUS_FAILED: self.failed_count} + counts = { + STATUS_COMPLETED: self.completed_count, + STATUS_FILTERED: self.filtered_count, + STATUS_DUPLICATE: self.duplicate_count, + STATUS_FAILED: self.failed_count, + } await self.storage.log_execution_job_end(job_uuid, STATUS_COMPLETED, counts, now, now) logger.debug(" - Marked execution job as completed") - + def is_job_to_stop(self) -> bool: - items_check = list(self.job_output.queue)[-1*self.job_run_stop_threshold:] - consecutive_not_completed = len(items_check) == self.job_run_stop_threshold and all( - item[RECORD_STATUS] != STATUS_COMPLETED for item in items_check) - #consecutive_not_completed and - completed_tasks_reach_target = (self.completed_count >= self.target_count) - #total_tasks_reach_target = (self.total_count >= self.target_count) + items_check = list(self.job_output.queue)[-1 * self.job_run_stop_threshold :] + consecutive_not_completed = len(items_check) == self.job_run_stop_threshold and all(item[RECORD_STATUS] != STATUS_COMPLETED for item in items_check) + # consecutive_not_completed and + completed_tasks_reach_target = self.completed_count >= self.target_count + # total_tasks_reach_target = (self.total_count >= self.target_count) return consecutive_not_completed or completed_tasks_reach_target - #return completed_tasks_reach_target or (total_tasks_reach_target and consecutive_not_completed) - - - - + # return completed_tasks_reach_target or (total_tasks_reach_target and consecutive_not_completed) + def run_orchestration(self): - """Process batches with asyncio""" + """Process batches with asyncio.""" self.job_input_queue = self.job_config["job_input_queue"] self.target_count = self.job_config.get("target_count") run_mode = self.job_config.get(RUN_MODE) @@ -112,31 +117,31 @@ def run_orchestration(self): else: run_in_event_loop(self._async_run_orchestration()) - async def _async_run_orchestration_re_run(self): - input_dict = self.job_config.get("input_dict", {}) - for input_data_hash, input_data in input_dict.items(): - - runned_tasks = await self.storage.list_execution_jobs_by_master_id_and_config_hash(self.job_config["master_job_id"], - input_data_hash, STATUS_COMPLETED) - logger.debug(f"Task already runned, returning output from storage") - # put the runned tasks output to the job output - for task in runned_tasks: - records_metadata = await self.storage.list_record_metadata(self.job_config["master_job_id"], task.job_id) - for record in records_metadata: - record_data = await self.storage.get_record_data(record.output_ref) - output_tmp = {RECORD_STATUS: STATUS_COMPLETED, "output": record_data} - self.job_output.put(output_tmp) - self.total_count += 1 - self.completed_count += 1 - # run the rest of the tasks - logger.debug(f"Task not runned, running task") - for _ in range(input_data["count"] - len(runned_tasks)): - self.job_input_queue.put(input_data["data"]) - #self._create_single_task(input_data["data"]) - await self._async_run_orchestration() + async def _async_run_orchestration_re_run(self): + input_dict = self.job_config.get("input_dict", {}) + for input_data_hash, input_data in input_dict.items(): + runned_tasks = await self.storage.list_execution_jobs_by_master_id_and_config_hash( + self.job_config["master_job_id"], input_data_hash, STATUS_COMPLETED + ) + logger.debug("Task already runned, returning output from storage") + # put the runned tasks output to the job output + for task in runned_tasks: + records_metadata = await self.storage.list_record_metadata(self.job_config["master_job_id"], task.job_id) + for record in records_metadata: + record_data = await self.storage.get_record_data(record.output_ref) + output_tmp = {RECORD_STATUS: STATUS_COMPLETED, "output": record_data} + self.job_output.put(output_tmp) + self.total_count += 1 + self.completed_count += 1 + # run the rest of the tasks + logger.debug("Task not runned, running task") + for _ in range(input_data["count"] - len(runned_tasks)): + self.job_input_queue.put(input_data["data"]) + # self._create_single_task(input_data["data"]) + await self._async_run_orchestration() async def _progress_ticker(self): - """Log a message every 5 seconds""" + """Log a message every 5 seconds.""" while not self.is_job_to_stop(): logger.info( f"[JOB PROGRESS] " @@ -151,19 +156,19 @@ async def _progress_ticker(self): await asyncio.sleep(PROGRESS_LOG_INTERVAL) async def _async_run_orchestration(self): - """Main orchestration loop for the job""" + """Main orchestration loop for the job.""" # Start the ticker task _progress_ticker_task = asyncio.create_task(self._progress_ticker()) # Store all running tasks running_tasks = set() - + try: while not self.is_job_to_stop(): - logger.debug(f"Job is not to stop, checking job input queue") + logger.debug("Job is not to stop, checking job input queue") if not self.job_input_queue.empty(): - logger.debug(f"Job input queue is not empty, acquiring semaphore") + logger.debug("Job input queue is not empty, acquiring semaphore") await self.semaphore.acquire() - logger.debug(f"Semaphore acquired, waiting for task to complete") + logger.debug("Semaphore acquired, waiting for task to complete") input_data = self.job_input_queue.get() task = self._create_single_task(input_data) running_tasks.add(task) @@ -177,7 +182,7 @@ async def _async_run_orchestration(self): await _progress_ticker_task except asyncio.CancelledError: pass - + # Cancel all running tasks # todo whether openai call will close for task in running_tasks: @@ -188,16 +193,15 @@ async def _async_run_orchestration(self): def _create_single_task(self, input_data) -> asyncio.Task: task = asyncio.create_task(self._run_single_task(input_data)) asyncio.create_task(self._handle_task_completion(task)) - logger.debug(f"Task created, waiting for task to complete") + logger.debug("Task created, waiting for task to complete") return task async def _run_single_task(self, input_data) -> List[Dict[str, Any]]: - """Run a single task with error handling and storage""" + """Run a single task with error handling and storage.""" output = [] output_ref = [] task_status = STATUS_COMPLETED try: - output = await self.task_runner.run_task(self.job_config["user_func"], input_data) hooks_output = [] # class based hooks. use semaphore to ensure thread safe @@ -207,7 +211,7 @@ async def _run_single_task(self, input_data) -> List[Dict[str, Any]]: # duplicate filtered need retry task_status = STATUS_DUPLICATE elif hooks_output.count(STATUS_FILTERED) > 0: - task_status = STATUS_FILTERED + task_status = STATUS_FILTERED output_ref = await self.job_save_record_data(output.copy(), task_status, input_data) except Exception as e: logger.error(f"Error running task: {e}") @@ -219,10 +223,10 @@ async def _run_single_task(self, input_data) -> List[Dict[str, Any]]: if task_status != STATUS_COMPLETED: logger.debug(f"Task is not completed as {task_status}, putting input data back to the job input queue") self.job_input_queue.put(input_data) - return {RECORD_STATUS: task_status, "output_ref": output_ref,"output":output} + return {RECORD_STATUS: task_status, "output_ref": output_ref, "output": output} async def _handle_task_completion(self, task): - """Handle task completion and update counters""" + """Handle task completion and update counters.""" result = await task async with self.lock: self.job_output.put(result) @@ -237,16 +241,14 @@ async def _handle_task_completion(self, task): self.filtered_count += 1 else: self.failed_count += 1 - #await self._update_progress(task_status, STATUS_MOJO_MAP[task_status]) + # await self._update_progress(task_status, STATUS_MOJO_MAP[task_status]) self.semaphore.release() - def _prepare_next_input(self): - """Prepare input data for the next task""" + """Prepare input data for the next task.""" # Implementation depends on specific use case pass def update_job_config(self, job_config: Dict[str, Any]): - """Update job config by merging new values with existing config""" + """Update job config by merging new values with existing config.""" self.job_config = {**self.job_config, **job_config} - diff --git a/src/starfish/data_factory/state.py b/src/starfish/data_factory/state.py index fcc64e0..4ea8c0c 100644 --- a/src/starfish/data_factory/state.py +++ b/src/starfish/data_factory/state.py @@ -1,10 +1,14 @@ import threading +from typing import Any, Dict, Optional + from pydantic import BaseModel -from typing import Dict, Any, Optional -#from starfish.data_factory.utils.decorator import async_to_sync_event_loop + + +# from starfish.data_factory.utils.decorator import async_to_sync_event_loop class MutableSharedState(BaseModel): _data: Dict[str, Any] = {} - #If you want each MutableSharedState instance to have its own independent + + # If you want each MutableSharedState instance to have its own independent # synchronization, you should move the lock initialization into __init__. def __init__(self, initial_data: Optional[Dict[str, Any]] = None): super().__init__() @@ -22,26 +26,25 @@ def data(self, value: Dict[str, Any]) -> None: with self._lock: self._data = value.copy() - def get(self, key: str) -> Any: with self._lock: return self._data.get(key) - + def set(self, key: str, value: Any) -> None: with self._lock: self._data[key] = value - def update(self, updates: Dict[str, Any]) -> None: with self._lock: self._data.update(updates) # Use to_dict when you want to emphasize you're converting/serializing the state - + def to_dict(self) -> Dict[str, Any]: with self._lock: return self._data.copy() + # # Set the entire state # state.data = {"key": "value"} diff --git a/src/starfish/data_factory/storage/base.py b/src/starfish/data_factory/storage/base.py index c897da8..ff12559 100644 --- a/src/starfish/data_factory/storage/base.py +++ b/src/starfish/data_factory/storage/base.py @@ -157,7 +157,7 @@ async def get_records_for_master_job( async def count_records_for_master_job(self, master_job_id: str, status_filter: Optional[List[StatusRecord]] = None) -> Dict[str, int]: """Efficiently get counts of records grouped by status for a master job.""" pass - + @abstractmethod async def list_record_metadata(self, master_job_uuid: str, job_uuid: str) -> List[Record]: """Retrieve metadata for records belonging to a master job.""" diff --git a/src/starfish/data_factory/storage/in_memory/in_memory_storage.py b/src/starfish/data_factory/storage/in_memory/in_memory_storage.py index bda3257..2b2550f 100644 --- a/src/starfish/data_factory/storage/in_memory/in_memory_storage.py +++ b/src/starfish/data_factory/storage/in_memory/in_memory_storage.py @@ -23,7 +23,7 @@ class InMemoryStorage(Storage): capabilities: Set[str] = {} def __init__(self): - logger.info(f"Initializing InMemoryStorage ") + logger.info("Initializing InMemoryStorage ") self._is_setup = False async def setup(self) -> None: @@ -124,7 +124,7 @@ async def count_records_for_master_job(self, master_job_id: str, status_filter: async def list_record_metadata(self, master_job_uuid: str, job_uuid: str) -> List[Record]: """Retrieve metadata for records belonging to a master job.""" pass - + async def list_execution_jobs_by_master_id_and_config_hash(self, master_job_id: str, config_hash: str, job_status: str) -> Optional[GenerationJob]: """Retrieve execution job details by master job id and config hash.""" - pass \ No newline at end of file + pass diff --git a/src/starfish/data_factory/storage/local/data_handler.py b/src/starfish/data_factory/storage/local/data_handler.py index df74942..ce9e835 100644 --- a/src/starfish/data_factory/storage/local/data_handler.py +++ b/src/starfish/data_factory/storage/local/data_handler.py @@ -93,5 +93,3 @@ async def save_record_data_impl(self, record_uid: str, data: Dict[str, Any]) -> async def get_record_data_impl(self, output_ref: str) -> Dict[str, Any]: return await self._read_json_file(output_ref) # Assumes ref is absolute path - - \ No newline at end of file diff --git a/src/starfish/data_factory/storage/local/local_storage.py b/src/starfish/data_factory/storage/local/local_storage.py index 7fd56f9..c67c169 100644 --- a/src/starfish/data_factory/storage/local/local_storage.py +++ b/src/starfish/data_factory/storage/local/local_storage.py @@ -2,34 +2,34 @@ import datetime import logging import os -from typing import Dict, Any, Optional, List, Set +from typing import Any, Dict, List, Optional, Set from starfish.data_factory.storage.base import Storage, register_storage -from starfish.data_factory.storage.models import ( # Import Pydantic models - Project, - GenerationMasterJob, +from starfish.data_factory.storage.local.data_handler import FileSystemDataHandler +from starfish.data_factory.storage.local.metadata_handler import SQLiteMetadataHandler +from starfish.data_factory.storage.local.utils import parse_uri_to_path +from starfish.data_factory.storage.models import ( # Import Pydantic models GenerationJob, + GenerationMasterJob, + Project, Record, - StatusRecord + StatusRecord, ) -from starfish.data_factory.storage.local.metadata_handler import SQLiteMetadataHandler -from starfish.data_factory.storage.local.data_handler import FileSystemDataHandler -from starfish.data_factory.storage.local.utils import parse_uri_to_path - logger = logging.getLogger(__name__) + class LocalStorage(Storage): - """ - Hybrid Local Storage Backend using SQLite for metadata and local JSON files + """Hybrid Local Storage Backend using SQLite for metadata and local JSON files for data artifacts and large configurations. Facade over internal handlers. """ + capabilities: Set[str] = {"QUERY_METADATA", "FILTER_STATUS", "STORE_LARGE_CONFIG"} def __init__(self, storage_uri: str, data_storage_uri_override: Optional[str] = None): logger.info(f"Initializing LocalStorage with URI: {storage_uri}") self.base_path = parse_uri_to_path(storage_uri) - self.metadata_db_path = os.path.join(self.base_path, "metadata.db") # Consistent name + self.metadata_db_path = os.path.join(self.base_path, "metadata.db") # Consistent name if data_storage_uri_override: if not data_storage_uri_override.startswith("file://"): @@ -48,7 +48,8 @@ def __init__(self, storage_uri: str, data_storage_uri_override: Optional[str] = async def setup(self) -> None: """Initializes both metadata DB schema and base file directories.""" - if self._is_setup: return + if self._is_setup: + return logger.info("Setting up LocalStorage...") await self._metadata_handler.initialize_schema() await self._data_handler.ensure_base_dirs() @@ -92,28 +93,42 @@ async def list_projects(self, limit: Optional[int] = None, offset: Optional[int] async def log_master_job_start(self, job_data: GenerationMasterJob) -> None: await self._metadata_handler.log_master_job_start_impl(job_data) - async def log_master_job_end(self, master_job_id: str, final_status: str, summary: Optional[Dict[str, Any]], end_time: datetime.datetime, update_time: datetime.datetime) -> None: + async def log_master_job_end( + self, master_job_id: str, final_status: str, summary: Optional[Dict[str, Any]], end_time: datetime.datetime, update_time: datetime.datetime + ) -> None: await self._metadata_handler.log_master_job_end_impl(master_job_id, final_status, summary, end_time, update_time) async def update_master_job_status(self, master_job_id: str, status: str, update_time: datetime.datetime) -> None: - await self._metadata_handler.update_master_job_status_impl(master_job_id, status, update_time) + await self._metadata_handler.update_master_job_status_impl(master_job_id, status, update_time) async def get_master_job(self, master_job_id: str) -> Optional[GenerationMasterJob]: return await self._metadata_handler.get_master_job_impl(master_job_id) - async def list_master_jobs(self, project_id: Optional[str] = None, status_filter: Optional[List[str]] = None, limit: Optional[int] = None, offset: Optional[int] = None) -> List[GenerationMasterJob]: - return await self._metadata_handler.list_master_jobs_impl(project_id, status_filter, limit, offset) + async def list_master_jobs( + self, project_id: Optional[str] = None, status_filter: Optional[List[str]] = None, limit: Optional[int] = None, offset: Optional[int] = None + ) -> List[GenerationMasterJob]: + return await self._metadata_handler.list_master_jobs_impl(project_id, status_filter, limit, offset) async def log_execution_job_start(self, job_data: GenerationJob) -> None: await self._metadata_handler.log_execution_job_start_impl(job_data) - async def log_execution_job_end(self, job_id: str, final_status: str, counts: Dict[str, int], end_time: datetime.datetime, update_time: datetime.datetime, error_message: Optional[str] = None) -> None: + async def log_execution_job_end( + self, + job_id: str, + final_status: str, + counts: Dict[str, int], + end_time: datetime.datetime, + update_time: datetime.datetime, + error_message: Optional[str] = None, + ) -> None: await self._metadata_handler.log_execution_job_end_impl(job_id, final_status, counts, end_time, update_time, error_message) async def get_execution_job(self, job_id: str) -> Optional[GenerationJob]: return await self._metadata_handler.get_execution_job_impl(job_id) - async def list_execution_jobs(self, master_job_id: str, status_filter: Optional[List[str]] = None, limit: Optional[int] = None, offset: Optional[int] = None) -> List[GenerationJob]: + async def list_execution_jobs( + self, master_job_id: str, status_filter: Optional[List[str]] = None, limit: Optional[int] = None, offset: Optional[int] = None + ) -> List[GenerationJob]: return await self._metadata_handler.list_execution_jobs_impl(master_job_id, status_filter, limit, offset) async def log_record_metadata(self, record_data: Record) -> None: @@ -122,18 +137,21 @@ async def log_record_metadata(self, record_data: Record) -> None: async def get_record_metadata(self, record_uid: str) -> Optional[Record]: return await self._metadata_handler.get_record_metadata_impl(record_uid) - async def get_records_for_master_job( self, master_job_id: str, status_filter: Optional[List[StatusRecord]] = None, limit: Optional[int] = None, offset: Optional[int] = None) -> List[Record]: + async def get_records_for_master_job( + self, master_job_id: str, status_filter: Optional[List[StatusRecord]] = None, limit: Optional[int] = None, offset: Optional[int] = None + ) -> List[Record]: return await self._metadata_handler.get_records_for_master_job_impl(master_job_id, status_filter, limit, offset) async def count_records_for_master_job(self, master_job_id: str, status_filter: Optional[List[StatusRecord]] = None) -> Dict[str, int]: - return await self._metadata_handler.count_records_for_master_job_impl(master_job_id, status_filter) - + return await self._metadata_handler.count_records_for_master_job_impl(master_job_id, status_filter) + async def list_execution_jobs_by_master_id_and_config_hash(self, master_job_id: str, config_hash: str, job_status: str) -> List[GenerationJob]: return await self._metadata_handler.list_execution_jobs_by_master_id_and_config_hash_impl(master_job_id, config_hash, job_status) - + async def list_record_metadata(self, master_job_uuid: str, job_uuid: str) -> List[Record]: return await self._metadata_handler.list_record_metadata_impl(master_job_uuid, job_uuid) + @register_storage("local") def create_local_storage(storage_uri: str, data_storage_uri_override: Optional[str] = None) -> LocalStorage: - return LocalStorage(storage_uri, data_storage_uri_override) \ No newline at end of file + return LocalStorage(storage_uri, data_storage_uri_override) diff --git a/src/starfish/data_factory/storage/local/metadata_handler.py b/src/starfish/data_factory/storage/local/metadata_handler.py index 47f29a5..7b6517e 100644 --- a/src/starfish/data_factory/storage/local/metadata_handler.py +++ b/src/starfish/data_factory/storage/local/metadata_handler.py @@ -6,10 +6,14 @@ import os from typing import Any, Dict, List, Optional, Tuple -from starfish.data_factory.constants import STATUS_COMPLETED, STATUS_FAILED, STATUS_FILTERED, STATUS_DUPLICATE - import aiosqlite +from starfish.data_factory.constants import ( + STATUS_COMPLETED, + STATUS_DUPLICATE, + STATUS_FAILED, + STATUS_FILTERED, +) from starfish.data_factory.storage.local.setup import ( initialize_db_schema, # Import setup function ) @@ -451,8 +455,8 @@ async def list_execution_jobs_by_master_id_and_config_hash_impl(self, master_job sql = "SELECT * FROM GenerationJob WHERE master_job_id = ? AND run_config_hash = ? AND status = ?" rows = await self._fetchall_sql(sql, (master_job_id, config_hash, job_status)) return [_row_to_pydantic(GenerationJob, row) for row in rows] if rows else [] - + async def list_record_metadata_impl(self, master_job_uuid: str, job_uuid: str) -> List[Record]: sql = "SELECT * FROM Records WHERE master_job_id = ? AND job_id = ?" rows = await self._fetchall_sql(sql, (master_job_uuid, job_uuid)) - return [_row_to_pydantic(Record, row) for row in rows] \ No newline at end of file + return [_row_to_pydantic(Record, row) for row in rows] diff --git a/src/starfish/data_factory/storage/models.py b/src/starfish/data_factory/storage/models.py index 205720f..f0ce061 100644 --- a/src/starfish/data_factory/storage/models.py +++ b/src/starfish/data_factory/storage/models.py @@ -55,7 +55,7 @@ class GenerationMasterJob(BaseModel): last_update_time: datetime.datetime = Field(default_factory=utc_now, description="Last modification time.") @field_validator("output_schema", mode="before") - def _parse_json_string(cls, value): + def _parse_json_string(self, value): if isinstance(value, str): try: return json.loads(value) @@ -88,7 +88,7 @@ class GenerationJob(BaseModel): error_message: Optional[str] = Field(None, description="Error if the whole execution run failed.") @field_validator("run_config", mode="before") - def _parse_json_string(cls, value): + def _parse_json_string(self, value): # Same validator as above if stored as JSON string in DB if isinstance(value, str): try: diff --git a/src/starfish/data_factory/task_runner.py b/src/starfish/data_factory/task_runner.py index f956f2f..ae2e812 100644 --- a/src/starfish/data_factory/task_runner.py +++ b/src/starfish/data_factory/task_runner.py @@ -1,19 +1,22 @@ import asyncio import time from typing import Any, Callable, Dict, List -from starfish.data_factory.config import TASK_RUNNER_TIMEOUT + from starfish.common.logger import get_logger +from starfish.data_factory.config import TASK_RUNNER_TIMEOUT + logger = get_logger(__name__) -#from starfish.common.logger_new import logger + + +# from starfish.common.logger_new import logger class TaskRunner: def __init__(self, max_retries: int = 1, timeout: int = TASK_RUNNER_TIMEOUT, master_job_id: str = None): self.max_retries = max_retries self.timeout = timeout self.master_job_id = master_job_id - async def run_task(self, func: Callable, input_data: Dict) -> List[Any]: - """Process a single task with asyncio""" + """Process a single task with asyncio.""" retries = 0 start_time = time.time() # maybe better to use retries in a single request instead in the batch level. @@ -28,6 +31,6 @@ async def run_task(self, func: Callable, input_data: Dict) -> List[Any]: except Exception as e: retries += 1 if retries > self.max_retries: - #logger.error(f"Task execution failed after {self.max_retries} retries") + # logger.error(f"Task execution failed after {self.max_retries} retries") raise e await asyncio.sleep(2**retries) # exponential backoff diff --git a/src/starfish/data_factory/utils/decorator.py b/src/starfish/data_factory/utils/decorator.py index 4cb3282..6473ea7 100644 --- a/src/starfish/data_factory/utils/decorator.py +++ b/src/starfish/data_factory/utils/decorator.py @@ -1,22 +1,31 @@ -from typing import Callable import asyncio -#from starfish.data_factory.constants import STORAGE_TYPE_LOCAL +from typing import Callable + +# from starfish.data_factory.constants import STORAGE_TYPE_LOCAL from starfish.data_factory.event_loop import run_in_event_loop + def async_wrapper(): - """Decorator to handle storage-specific async operations""" + """Decorator to handle storage-specific async operations.""" + # to be replaced by the registery pattern def decorator(func: Callable): def wrapper(self, *args, **kwargs): - #if self.storage == STORAGE_TYPE_LOCAL: + # if self.storage == STORAGE_TYPE_LOCAL: return asyncio.run(func(self, *args, **kwargs)) + return wrapper + return decorator + def async_to_sync_event_loop(): - """Decorator to handle storage-specific async operations""" + """Decorator to handle storage-specific async operations.""" + def decorator(func: Callable): def wrapper(self, *args, **kwargs): return run_in_event_loop(func(self, *args, **kwargs)) + return wrapper - return decorator \ No newline at end of file + + return decorator diff --git a/src/starfish/data_factory/utils/enums.py b/src/starfish/data_factory/utils/enums.py index 6e124d9..0cb4fd9 100644 --- a/src/starfish/data_factory/utils/enums.py +++ b/src/starfish/data_factory/utils/enums.py @@ -1,7 +1,8 @@ from enum import Enum + class RecordStatus(Enum): COMPLETED = "completed" DUPLICATE = "duplicate" FILTERED = "filtered" - FAILED = "failed" \ No newline at end of file + FAILED = "failed" diff --git a/src/starfish/data_factory/utils/errors.py b/src/starfish/data_factory/utils/errors.py index d089131..50ff430 100644 --- a/src/starfish/data_factory/utils/errors.py +++ b/src/starfish/data_factory/utils/errors.py @@ -1,12 +1,13 @@ class DuplicateRecordError(Exception): - """Raised when a record is identified as a duplicate""" + """Raised when a record is identified as a duplicate.""" def __init__(self, message="Duplicate record detected"): self.message = message super().__init__(self.message) + class RecordError(Exception): - """Raised when a record is not processed successfully""" + """Raised when a record is not processed successfully.""" def __init__(self, message="Record not processed successfully"): self.message = message @@ -14,7 +15,7 @@ def __init__(self, message="Record not processed successfully"): class FilterRecordError(Exception): - """Raised when a record is filtered out based on business rules""" + """Raised when a record is filtered out based on business rules.""" def __init__(self, message="Record filtered by business rules"): self.message = message diff --git a/src/starfish/data_factory/utils/mock.py b/src/starfish/data_factory/utils/mock.py index 4a7749a..bfb7fd3 100644 --- a/src/starfish/data_factory/utils/mock.py +++ b/src/starfish/data_factory/utils/mock.py @@ -1,16 +1,19 @@ -import random import asyncio +import random + from starfish.common.logger import get_logger + logger = get_logger(__name__) + async def mock_llm_call(city_name, num_records_per_city, fail_rate=0.01, sleep_time=0.5): await asyncio.sleep(sleep_time) if random.random() < fail_rate: logger.debug(f" {city_name}: Failed!") raise ValueError(f"Mock LLM failed to process city: {city_name}") - + logger.debug(f"{city_name}: Successfully processed!") result = [{"answer": f"{city_name}_{random.randint(1, 5)}"} for _ in range(num_records_per_city)] - return result \ No newline at end of file + return result diff --git a/src/starfish/llm/backend/ollama_adapter.py b/src/starfish/llm/backend/ollama_adapter.py index 0c09755..3b05c8c 100644 --- a/src/starfish/llm/backend/ollama_adapter.py +++ b/src/starfish/llm/backend/ollama_adapter.py @@ -1,4 +1,4 @@ -"""Ollama adapter""" +"""Ollama adapter.""" import asyncio import os @@ -20,25 +20,25 @@ class OllamaError(Exception): - """Base exception for Ollama-related errors""" + """Base exception for Ollama-related errors.""" pass class OllamaNotInstalledError(OllamaError): - """Error raised when Ollama is not installed""" + """Error raised when Ollama is not installed.""" pass class OllamaConnectionError(OllamaError): - """Error raised when connection to Ollama server fails""" + """Error raised when connection to Ollama server fails.""" pass async def is_ollama_running() -> bool: - """Check if Ollama server is running""" + """Check if Ollama server is running.""" try: async with aiohttp.ClientSession() as session: async with session.get(f"{OLLAMA_BASE_URL}/api/version", timeout=aiohttp.ClientTimeout(total=2)) as response: @@ -48,7 +48,7 @@ async def is_ollama_running() -> bool: async def start_ollama_server() -> bool: - """Start the Ollama server if it's not already running""" + """Start the Ollama server if it's not already running.""" # Check if already running if await is_ollama_running(): logger.info("Ollama server is already running") @@ -89,7 +89,7 @@ async def start_ollama_server() -> bool: async def list_models() -> List[Dict[str, Any]]: - """List available models in Ollama using the API""" + """List available models in Ollama using the API.""" try: async with aiohttp.ClientSession() as session: async with session.get(f"{OLLAMA_BASE_URL}/api/tags") as response: @@ -105,7 +105,7 @@ async def list_models() -> List[Dict[str, Any]]: async def is_model_available(model_name: str) -> bool: - """Check if model is available using the CLI command""" + """Check if model is available using the CLI command.""" try: # Use CLI for more reliable checking ollama_bin = shutil.which("ollama") @@ -243,7 +243,7 @@ async def ensure_model_ready(model_name: str) -> bool: async def stop_ollama_server() -> bool: - """Stop the Ollama server""" + """Stop the Ollama server.""" try: # Find the ollama executable (just to check if it's installed) ollama_bin = shutil.which("ollama") @@ -316,7 +316,7 @@ async def stop_ollama_server() -> bool: async def delete_model(model_name: str) -> bool: - """Delete a model from Ollama + """Delete a model from Ollama. Args: model_name: The name of the model to delete diff --git a/src/starfish/llm/model_hub/huggingface_adapter.py b/src/starfish/llm/model_hub/huggingface_adapter.py index 16710d2..fb894b9 100644 --- a/src/starfish/llm/model_hub/huggingface_adapter.py +++ b/src/starfish/llm/model_hub/huggingface_adapter.py @@ -1,5 +1,6 @@ """HuggingFace service for interacting with the HuggingFace API. -This service focuses on model discovery, search, and downloading from HuggingFace.""" +This service focuses on model discovery, search, and downloading from HuggingFace. +""" import asyncio import os @@ -9,10 +10,11 @@ import aiohttp +from starfish.common.logger import get_logger + ##TODO we will need to move the dependencies of ollma to a seperate file so we can support other model hosting providers like vllm. but for now it is fine from starfish.llm.backend.ollama_adapter import delete_model as delete_ollama_model from starfish.llm.backend.ollama_adapter import is_model_available -from starfish.common.logger import get_logger logger = get_logger(__name__) @@ -23,25 +25,25 @@ # HuggingFace Exception Types ############################################# class HuggingFaceError(Exception): - """Base exception for HuggingFace-related errors""" + """Base exception for HuggingFace-related errors.""" pass class HuggingFaceAuthError(HuggingFaceError): - """Error raised when authentication is required but missing""" + """Error raised when authentication is required but missing.""" pass class HuggingFaceModelNotFoundError(HuggingFaceError): - """Error raised when a model is not found""" + """Error raised when a model is not found.""" pass class HuggingFaceAPIError(HuggingFaceError): - """Error raised for general API errors""" + """Error raised for general API errors.""" pass @@ -50,12 +52,12 @@ class HuggingFaceAPIError(HuggingFaceError): # Core HuggingFace API Functions ############################################# def get_hf_token() -> Optional[str]: - """Get HuggingFace API token from environment variable""" + """Get HuggingFace API token from environment variable.""" return os.environ.get("HUGGING_FACE_HUB_TOKEN") async def _make_hf_request(url: str, params: Optional[Dict] = None, check_auth: bool = True) -> Tuple[bool, Any]: - """Make a request to HuggingFace API with proper error handling + """Make a request to HuggingFace API with proper error handling. Args: url: API URL to request @@ -85,7 +87,7 @@ async def _make_hf_request(url: str, params: Optional[Dict] = None, check_auth: async def list_hf_models(query: str = "", limit: int = 20) -> List[Dict[str, Any]]: - """List/search models on HuggingFace + """List/search models on HuggingFace. Args: query: Optional search query @@ -121,7 +123,7 @@ async def list_hf_models(query: str = "", limit: int = 20) -> List[Dict[str, Any async def get_imported_hf_models() -> List[str]: - """Get list of HuggingFace models that have been imported to Ollama + """Get list of HuggingFace models that have been imported to Ollama. Returns: List of model names in Ollama that originated from HuggingFace @@ -133,7 +135,7 @@ async def get_imported_hf_models() -> List[str]: async def check_model_exists(model_id: str) -> bool: - """Check if a model exists on HuggingFace + """Check if a model exists on HuggingFace. Args: model_id: HuggingFace model ID (e.g., "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B") @@ -146,7 +148,7 @@ async def check_model_exists(model_id: str) -> bool: async def find_gguf_files(model_id: str) -> List[Dict[str, Any]]: - """Find GGUF files available for a HuggingFace model + """Find GGUF files available for a HuggingFace model. Args: model_id: HuggingFace model ID @@ -193,7 +195,7 @@ async def find_gguf_files(model_id: str) -> List[Dict[str, Any]]: async def download_gguf_file(model_id: str, file_path: str, target_path: str) -> bool: - """Download a GGUF file from HuggingFace + """Download a GGUF file from HuggingFace. Args: model_id: HuggingFace model ID @@ -264,7 +266,7 @@ async def download_gguf_file(model_id: str, file_path: str, target_path: str) -> async def import_model_to_ollama(local_file_path: str, model_name: str) -> bool: - """Import a GGUF file into Ollama + """Import a GGUF file into Ollama. Args: local_file_path: Path to the downloaded GGUF file @@ -317,7 +319,7 @@ async def get_best_gguf_file(gguf_files: List[Dict[str, Any]]) -> Optional[Dict[ Prioritizes: 1. Smaller quantization (q4_K > q5_K > q8_0) 2. File size (prefers smaller files for same quantization level) - 3. Avoid huge files unless necessary + 3. Avoid huge files unless necessary. Args: gguf_files: List of GGUF file objects diff --git a/src/starfish/llm/parser/json_builder.py b/src/starfish/llm/parser/json_builder.py index b0ce028..44d29aa 100644 --- a/src/starfish/llm/parser/json_builder.py +++ b/src/starfish/llm/parser/json_builder.py @@ -15,7 +15,7 @@ class SimpleField(BaseModel): required: bool = Field(True, description="Whether the field is required") @field_validator("type") - def validate_field_type(cls, v): + def validate_field_type(self, v): valid_types = ["str", "int", "float", "bool", "list", "dict", "null"] if v not in valid_types: raise ValueError(f"Field type must be one of {valid_types}") @@ -50,7 +50,7 @@ class JsonSchemaBuilder: """ def __init__(self): - """Initialize an empty schema builder""" + """Initialize an empty schema builder.""" self.fields = [] def add_simple_field(self, name: str, field_type: str, description: str = "", required: bool = True) -> None: diff --git a/src/starfish/llm/prompt/prompt_loader.py b/src/starfish/llm/prompt/prompt_loader.py index 0121400..65138c4 100644 --- a/src/starfish/llm/prompt/prompt_loader.py +++ b/src/starfish/llm/prompt/prompt_loader.py @@ -24,7 +24,7 @@ class PromptManager: {{ schema_instruction }} {% else %} -You are asked to generate exactly {{ num_records }} records and please return the data in the following JSON format: +You are asked to generate exactly {{ num_records }} records and please return the data in the following JSON format: {{ schema_instruction }} {% endif %} """ @@ -177,12 +177,12 @@ def render_template(self, variables: Dict[str, Any]) -> str: # Create a copy of variables to avoid modifying the original render_vars = variables.copy() - + # Check for list inputs with priority given to required variables is_list_input = False list_input_variable = None input_list_length = None - + # Check variables in priority order (required first, then optional) for var in list(self.required_vars) + list(self.optional_vars): if var in render_vars and isinstance(render_vars[var], list): @@ -206,7 +206,7 @@ def render_template(self, variables: Dict[str, Any]) -> str: # Add default num_records (always use default value of 1 if not specified) render_vars["num_records"] = render_vars.get("num_records", 1) - + return self._template.render(**render_vars) def construct_messages(self, variables: Dict[str, Any]) -> List[Dict[str, str]]: diff --git a/src/starfish/llm/prompt/prompt_template.py b/src/starfish/llm/prompt/prompt_template.py index 468e823..083e9e4 100644 --- a/src/starfish/llm/prompt/prompt_template.py +++ b/src/starfish/llm/prompt/prompt_template.py @@ -1,17 +1,17 @@ # Complete prompts that need no additional template text COMPLETE_PROMPTS = { "data_gen": """ -You are a data generation expert. Your primary objective is to create +You are a data generation expert. Your primary objective is to create high-quality synthetic data that strictly adheres to the provided guidelines. -The user has provided specific instructions for data generation. +The user has provided specific instructions for data generation. - Carefully analyze the given instructions. - Ensure the generated data aligns with the specified requirements. - Maintain accuracy, coherence, and logical consistency. user_instruction: {{user_instruction}} {% if good_examples %} -The user has provided high-quality reference examples. +The user has provided high-quality reference examples. - Identify patterns, structures, and key characteristics from these examples. - Generate data that maintains a similar style, quality, and relevance. - Ensure variations while preserving meaningful consistency. @@ -19,26 +19,26 @@ {% endif %} {% if bad_examples %} -The following examples represent poor-quality data. +The following examples represent poor-quality data. - Avoid replicating errors, inconsistencies, or undesirable patterns. - Ensure generated data is free from the flaws present in these examples. bad_examples: {{bad_examples}} {% endif %} {% if duplicate_examples %} -The user has specified examples that should not be duplicated. +The user has specified examples that should not be duplicated. - Ensure the generated data remains unique and does not replicate these examples. - Introduce meaningful variations while maintaining quality and consistency. duplicate_examples: {{duplicate_examples}} {% endif %} {% if topic %} -The generated data should be contextually relevant to the given topic: '{{topic}}'. +The generated data should be contextually relevant to the given topic: '{{topic}}'. - Maintain thematic consistency. - Ensure factual accuracy where applicable. {% endif %} -Generate unique and high-quality data points. +Generate unique and high-quality data points. - Ensure diversity in the dataset while maintaining coherence. - Avoid redundant or repetitive entries. """, @@ -47,10 +47,10 @@ # Partial prompts that need to be combined with user-provided content PARTIAL_PROMPTS = { "data_gen": { - "header": """You are a data generation expert. Your primary objective is to create + "header": """You are a data generation expert. Your primary objective is to create high-quality synthetic data that strictly adheres to the provided guidelines.""", "footer": """ - Generate unique and high-quality data points. + Generate unique and high-quality data points. - Ensure diversity in the dataset while maintaining coherence. - Avoid redundant or repetitive entries. """, diff --git a/src/starfish/llm/proxy/litellm_adapter.py b/src/starfish/llm/proxy/litellm_adapter.py index 3ad679d..679e68a 100644 --- a/src/starfish/llm/proxy/litellm_adapter.py +++ b/src/starfish/llm/proxy/litellm_adapter.py @@ -2,11 +2,11 @@ import litellm +from starfish.common.logger import get_logger from starfish.llm.proxy.litellm_adapter_ext import ( OPENAI_COMPATIBLE_PROVIDERS_CONFIG, route_openai_compatible_request, ) -from starfish.common.logger import get_logger logger = get_logger(__name__) @@ -68,7 +68,7 @@ def build_chat_messages(user_instruction: str, system_prompt: Optional[str] = No async def route_ollama_request(model_name: str, messages: List[Dict[str, str]], model_kwargs: Dict[str, Any]) -> Any: - """Handle Ollama-specific model requests + """Handle Ollama-specific model requests. Args: model_name: The full model name (e.g., "ollama/llama3") @@ -114,7 +114,7 @@ async def route_ollama_request(model_name: str, messages: List[Dict[str, str]], async def route_huggingface_request(model_name: str, messages: List[Dict[str, str]], model_kwargs: Dict[str, Any]) -> Any: - """Handle HuggingFace model requests by importing into Ollama + """Handle HuggingFace model requests by importing into Ollama. Args: model_name: The full model name (e.g., "hf/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B") diff --git a/src/starfish/llm/proxy/litellm_adapter_ext.py b/src/starfish/llm/proxy/litellm_adapter_ext.py index 4080d4f..6c27ff0 100644 --- a/src/starfish/llm/proxy/litellm_adapter_ext.py +++ b/src/starfish/llm/proxy/litellm_adapter_ext.py @@ -32,7 +32,7 @@ def _resolve_config_value(value: Any, description: str) -> Any: """Resolves a configuration value based on the '$' convention. '$VAR_NAME' -> os.getenv('VAR_NAME') - Other -> literal value + Other -> literal value. """ if isinstance(value, str) and value.startswith("$"): # Environment Variable lookup: $VAR_NAME diff --git a/src/starfish/llm/structured_llm.py b/src/starfish/llm/structured_llm.py index 368cf98..843d772 100644 --- a/src/starfish/llm/structured_llm.py +++ b/src/starfish/llm/structured_llm.py @@ -1,11 +1,10 @@ -import json from typing import Any, Dict, Generic, List, Optional, TypeVar, Union from pydantic import BaseModel -from starfish.llm.proxy.litellm_adapter import call_chat_model from starfish.llm.parser import JSONParser, PydanticParser from starfish.llm.prompt import PromptManager, get_partial_prompt +from starfish.llm.proxy.litellm_adapter import call_chat_model from starfish.llm.utils import to_sync T = TypeVar("T") diff --git a/tests/__init__.py b/tests/__init__.py index 21f3a6b..aeab51e 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -2,10 +2,7 @@ import sys # Add workspace folders to PYTHONPATH -workspace_folders = [ - os.path.join(os.path.dirname(__file__), '..', 'src'), - os.path.dirname(os.path.dirname(__file__)) -] +workspace_folders = [os.path.join(os.path.dirname(__file__), "..", "src"), os.path.dirname(os.path.dirname(__file__))] for folder in workspace_folders: if folder not in sys.path: diff --git a/tests/data_factory/storage/local/test_basic_storage.py b/tests/data_factory/storage/local/test_basic_storage.py index e0e5bd1..5d7ec17 100644 --- a/tests/data_factory/storage/local/test_basic_storage.py +++ b/tests/data_factory/storage/local/test_basic_storage.py @@ -5,12 +5,12 @@ import asyncio import datetime +import hashlib import json import os import shutil import traceback import uuid -import hashlib from starfish.data_factory.storage.local.local_storage import LocalStorage from starfish.data_factory.storage.models import ( @@ -82,11 +82,11 @@ async def run_basic_test(): run_config = {"batch_size": 5} run_config_str = json.dumps(run_config) job = GenerationJob( - job_id=job_id, - master_job_id=master_job_id, - status="pending", + job_id=job_id, + master_job_id=master_job_id, + status="pending", run_config=run_config_str, - run_config_hash=hashlib.sha256(run_config_str.encode()).hexdigest() + run_config_hash=hashlib.sha256(run_config_str.encode()).hexdigest(), ) await storage.log_execution_job_start(job) print(f"✓ Execution job created: {job.job_id}") diff --git a/tests/data_factory/storage/local/test_local_storage.py b/tests/data_factory/storage/local/test_local_storage.py index 0be1aa5..9f9a5a2 100644 --- a/tests/data_factory/storage/local/test_local_storage.py +++ b/tests/data_factory/storage/local/test_local_storage.py @@ -1,9 +1,9 @@ import datetime +import hashlib +import json import os import shutil import uuid -import json -import hashlib import pytest import pytest_asyncio @@ -250,11 +250,11 @@ async def test_execution_job_lifecycle(storage, test_master_job): run_config = {"batch_size": 10} run_config_str = json.dumps(run_config) job = GenerationJob( - job_id=job_id, - master_job_id=test_master_job.master_job_id, - status="pending", + job_id=job_id, + master_job_id=test_master_job.master_job_id, + status="pending", run_config=run_config_str, - run_config_hash=hashlib.sha256(run_config_str.encode()).hexdigest() + run_config_hash=hashlib.sha256(run_config_str.encode()).hexdigest(), ) # Log job start @@ -291,11 +291,11 @@ async def test_list_execution_jobs(storage, test_master_job): run_config = {"batch": i} run_config_str = json.dumps(run_config) job = GenerationJob( - job_id=str(uuid.uuid4()), - master_job_id=test_master_job.master_job_id, - status=status, + job_id=str(uuid.uuid4()), + master_job_id=test_master_job.master_job_id, + status=status, run_config=run_config_str, - run_config_hash=hashlib.sha256(run_config_str.encode()).hexdigest() + run_config_hash=hashlib.sha256(run_config_str.encode()).hexdigest(), ) await storage.log_execution_job_start(job) @@ -329,11 +329,11 @@ async def test_record_storage(storage, test_master_job): run_config = {"test": "record_storage"} run_config_str = json.dumps(run_config) job = GenerationJob( - job_id=str(uuid.uuid4()), - master_job_id=test_master_job.master_job_id, + job_id=str(uuid.uuid4()), + master_job_id=test_master_job.master_job_id, status="running", run_config=run_config_str, - run_config_hash=hashlib.sha256(run_config_str.encode()).hexdigest() + run_config_hash=hashlib.sha256(run_config_str.encode()).hexdigest(), ) await storage.log_execution_job_start(job) @@ -377,11 +377,11 @@ async def test_get_records_for_master_job(storage, test_master_job): run_config = {"test": "get_records"} run_config_str = json.dumps(run_config) job = GenerationJob( - job_id=str(uuid.uuid4()), - master_job_id=test_master_job.master_job_id, + job_id=str(uuid.uuid4()), + master_job_id=test_master_job.master_job_id, status="running", run_config=run_config_str, - run_config_hash=hashlib.sha256(run_config_str.encode()).hexdigest() + run_config_hash=hashlib.sha256(run_config_str.encode()).hexdigest(), ) await storage.log_execution_job_start(job) @@ -482,11 +482,11 @@ async def test_complete_workflow(storage): run_config = {"batch_size": 10} run_config_str = json.dumps(run_config) job = GenerationJob( - job_id=job_id, - master_job_id=master_job_id, - status="pending", + job_id=job_id, + master_job_id=master_job_id, + status="pending", run_config=run_config_str, - run_config_hash=hashlib.sha256(run_config_str.encode()).hexdigest() + run_config_hash=hashlib.sha256(run_config_str.encode()).hexdigest(), ) await storage.log_execution_job_start(job) diff --git a/tests/data_factory/storage/local/test_performance.py b/tests/data_factory/storage/local/test_performance.py index 496234d..170c205 100644 --- a/tests/data_factory/storage/local/test_performance.py +++ b/tests/data_factory/storage/local/test_performance.py @@ -1,13 +1,13 @@ import argparse import asyncio import datetime +import hashlib +import json import os import random import shutil import time import uuid -import json -import hashlib from typing import Any, Dict import pytest @@ -240,7 +240,9 @@ async def test_realistic_workflow( status="pending", worker_id=f"worker-{job_idx % (concurrency * 2)}", # Simulate multiple workers run_config={"start_index": job_idx * batch_size, "count": actual_batch_size, "complexity": complexity}, - run_config_hash=hashlib.sha256(json.dumps({"start_index": job_idx * batch_size, "count": actual_batch_size, "complexity": complexity}).encode()).hexdigest() + run_config_hash=hashlib.sha256( + json.dumps({"start_index": job_idx * batch_size, "count": actual_batch_size, "complexity": complexity}).encode() + ).hexdigest(), ) execution_jobs.append(job) @@ -309,7 +311,6 @@ async def process_job(job_idx, job): # Create a list to hold data save tasks data_save_tasks = [] - record_update_tasks = [] # Process records for this job for i, record in enumerate(job_records): @@ -422,7 +423,6 @@ async def process_job(job_idx, job): # Paginated record retrieval (simulate UI pagination) page_size = 50 page_count = (total_records + page_size - 1) // page_size - page_time_total = 0 # Process page reads concurrently page_tasks = [] @@ -458,7 +458,7 @@ async def process_job(job_idx, job): # Wait for all read tasks to complete for i, record, task in record_read_tasks: read_start = time.time() - record_data = await task + await task read_time = time.time() - read_start data_read_times.append(read_time) print(f" - Record {i} data retrieval: {read_time:.4f}s") @@ -546,7 +546,6 @@ async def test_read_performance(): """ # Configuration total_records = 500 # Can be adjusted - batch_size = 10 complexity = "small" # Use small data for faster test setup print(f"\nRunning read performance test with {total_records} total records") @@ -612,12 +611,12 @@ async def test_read_performance(): run_config = {"test_type": "read_performance"} run_config_str = json.dumps(run_config) job = GenerationJob( - job_id=job_id, - master_job_id=master_job_id, - status="running", + job_id=job_id, + master_job_id=master_job_id, + status="running", worker_id="read-perf-worker", run_config=run_config_str, - run_config_hash=hashlib.sha256(run_config_str.encode()).hexdigest() + run_config_hash=hashlib.sha256(run_config_str.encode()).hexdigest(), ) await storage_instance.log_execution_job_start(job) @@ -707,7 +706,7 @@ async def test_read_performance(): data_read_tasks.append(storage_instance.get_record_data(record.output_ref)) # Read all data concurrently - all_data = await asyncio.gather(*data_read_tasks) + await asyncio.gather(*data_read_tasks) data_time = time.time() - data_start # Calculate timing for this iteration diff --git a/tests/data_factory/storage/test_storage_main.py b/tests/data_factory/storage/test_storage_main.py index 46d9aa0..ac492bd 100644 --- a/tests/data_factory/storage/test_storage_main.py +++ b/tests/data_factory/storage/test_storage_main.py @@ -6,12 +6,13 @@ import asyncio import datetime +import hashlib +import json import os import shutil import time import uuid -import json -import hashlib + import pytest # Import storage components @@ -97,12 +98,12 @@ async def test_basic_workflow(): run_config = {"test_param": "test_value"} run_config_str = json.dumps(run_config) job = GenerationJob( - job_id=job_id, - master_job_id=master_job_id, - status="pending", + job_id=job_id, + master_job_id=master_job_id, + status="pending", worker_id="test-worker-1", run_config=run_config_str, - run_config_hash=hashlib.sha256(run_config_str.encode()).hexdigest() + run_config_hash=hashlib.sha256(run_config_str.encode()).hexdigest(), ) await storage.log_execution_job_start(job) print(f" - Created execution job: {job.job_id}") @@ -254,12 +255,12 @@ async def small_performance_test(): run_config = {"job_idx": job_idx, "records_per_job": RECORDS_PER_JOB} run_config_str = json.dumps(run_config) job = GenerationJob( - job_id=job_id, - master_job_id=master_job_id, - status="running", + job_id=job_id, + master_job_id=master_job_id, + status="running", worker_id=f"worker-{job_idx}", run_config=run_config_str, - run_config_hash=hashlib.sha256(run_config_str.encode()).hexdigest() + run_config_hash=hashlib.sha256(run_config_str.encode()).hexdigest(), ) await storage.log_execution_job_start(job) diff --git a/tests/data_factory/test_data_factory.py b/tests/data_factory/test_data_factory.py index 701c883..df706ca 100644 --- a/tests/data_factory/test_data_factory.py +++ b/tests/data_factory/test_data_factory.py @@ -1,15 +1,13 @@ -import pytest import nest_asyncio -import asyncio -import random +import pytest from starfish.data_factory.constants import STATUS_COMPLETED nest_asyncio.apply() -from starfish.common.env_loader import load_env_file from starfish import data_factory -from starfish.data_factory.state import MutableSharedState +from starfish.common.env_loader import load_env_file from starfish.data_factory.utils.mock import mock_llm_call + load_env_file() ### Mock LLM call @@ -22,7 +20,7 @@ # if random.random() < fail_rate: # print(f" {city_name}: Failed!") ## For debugging # raise ValueError(f"Mock LLM failed to process city: {city_name}") - + # print(f"{city_name}: Successfully processed!") ## For debugging # result = [f"{city_name}_{random.randint(1, 5)}" for _ in range(num_records_per_city)] @@ -36,19 +34,24 @@ async def test_case_1(): - Broadcast: num_records_per_city - Expected: All cities processed successfully """ + @data_factory(max_concurrency=2) async def test1(city_name, num_records_per_city, fail_rate=0.5, sleep_time=1): return await mock_llm_call(city_name, num_records_per_city, fail_rate=fail_rate, sleep_time=sleep_time) - result = test1.run(data=[ - {'city_name': '1. New York'}, - {'city_name': '2. Los Angeles'}, - {'city_name': '3. Chicago'}, - {'city_name': '4. Houston'}, - {'city_name': '5. Miami'} - ], num_records_per_city=5) + result = test1.run( + data=[ + {"city_name": "1. New York"}, + {"city_name": "2. Los Angeles"}, + {"city_name": "3. Chicago"}, + {"city_name": "4. Houston"}, + {"city_name": "5. Miami"}, + ], + num_records_per_city=5, + ) assert len(result) == 25 + @pytest.mark.asyncio async def test_case_2(): """Test with kwargs list and broadcast variables @@ -56,13 +59,14 @@ async def test_case_2(): - Broadcast: num_records_per_city - Expected: TypeError due to incorrect input format """ + @data_factory(max_concurrency=2) async def test1(city_name, num_records_per_city, fail_rate=0.5, sleep_time=1): return await mock_llm_call(city_name, num_records_per_city, fail_rate=fail_rate, sleep_time=sleep_time) with pytest.raises(TypeError): - test1.run(city=["1. New York", "2. Los Angeles", "3. Chicago", "4. Houston", "5. Miami"], - num_records_per_city=5) + test1.run(city=["1. New York", "2. Los Angeles", "3. Chicago", "4. Houston", "5. Miami"], num_records_per_city=5) + @pytest.mark.asyncio async def test_case_3(): @@ -71,57 +75,56 @@ async def test_case_3(): - Parameters: fail_rate=1 (100% failure) - Expected: Exception due to all requests failing """ + @data_factory(max_concurrency=2) async def test1(city_name, num_records_per_city, fail_rate=1, sleep_time=0.05): return await mock_llm_call(city_name, num_records_per_city, fail_rate=fail_rate, sleep_time=sleep_time) with pytest.raises(Exception): - test1.run(data=[ - {'city_name': '1. New York'}, - {'city_name': '2. Los Angeles'}, - {'city_name': '3. Chicago'}, - {'city_name': '4. Houston'}, - {'city_name': '5. Miami'} - ], num_records_per_city=5) + test1.run( + data=[ + {"city_name": "1. New York"}, + {"city_name": "2. Los Angeles"}, + {"city_name": "3. Chicago"}, + {"city_name": "4. Houston"}, + {"city_name": "5. Miami"}, + ], + num_records_per_city=5, + ) + @pytest.mark.asyncio async def test_case_4(): """Test if broadcast variables can override kwargs with a single value""" + @data_factory(max_concurrency=2) async def test_func(city_name, num_records_per_city, fail_rate=0.1, sleep_time=0.05): return await mock_llm_call(city_name, num_records_per_city, fail_rate=fail_rate, sleep_time=sleep_time) - result = test_func.run( - data=[ - {'city_name': '1. New York'}, - {'city_name': '2. Los Angeles'} - ], - city_name='override_city_name', - num_records_per_city=1 - ) - + result = test_func.run(data=[{"city_name": "1. New York"}, {"city_name": "2. Los Angeles"}], city_name="override_city_name", num_records_per_city=1) + # Verify all results contain the override value for item in result: - assert ('override_city_name' in item['answer']) + assert "override_city_name" in item["answer"] + @pytest.mark.asyncio async def test_case_5(): """Test if broadcast variables can override kwargs with a list of values""" + @data_factory(max_concurrency=2) async def test_func(city_name, num_records_per_city, fail_rate=0.1, sleep_time=0.05): return await mock_llm_call(city_name, num_records_per_city, fail_rate=fail_rate, sleep_time=sleep_time) result = test_func.run( - data=[ - {'city_name': '1. New York'}, - {'city_name': '2. Los Angeles'} - ], - city_name=['1. override_city_name', '2. override_city_name'], - num_records_per_city=1 + data=[{"city_name": "1. New York"}, {"city_name": "2. Los Angeles"}], + city_name=["1. override_city_name", "2. override_city_name"], + num_records_per_city=1, ) - + # Verify each result contains the corresponding override value - assert any('1. override_city_name' in item["answer"] or '2. override_city_name' in item["answer"] for item in result) + assert any("1. override_city_name" in item["answer"] or "2. override_city_name" in item["answer"] for item in result) + @pytest.mark.asyncio async def test_case_6(): @@ -130,15 +133,20 @@ async def test_case_6(): - Missing: Required num_records_per_city parameter - Expected: TypeError due to missing required parameter """ + @data_factory(max_concurrency=2) async def test1(city_name, num_records_per_city, fail_rate=0.1, sleep_time=0.05): return await mock_llm_call(city_name, num_records_per_city, fail_rate=fail_rate, sleep_time=sleep_time) with pytest.raises(TypeError): - test1.run(data=[ - {'city_name': '1. New York'}, - {'city_name': '2. Los Angeles'}, - ], city_name='override_city_name') + test1.run( + data=[ + {"city_name": "1. New York"}, + {"city_name": "2. Los Angeles"}, + ], + city_name="override_city_name", + ) + @pytest.mark.asyncio async def test_case_7(): @@ -147,15 +155,21 @@ async def test_case_7(): - Extra: random_param not defined in workflow - Expected: TypeError due to unexpected parameter """ + @data_factory(max_concurrency=2) async def test1(city_name, num_records_per_city, fail_rate=0.1, sleep_time=0.05): return await mock_llm_call(city_name, num_records_per_city, fail_rate=fail_rate, sleep_time=sleep_time) with pytest.raises(TypeError): - test1.run(data=[ - {'city_name': '1. New York'}, - {'city_name': '2. Los Angeles'}, - ], num_records_per_city=1, random_param='random_param') + test1.run( + data=[ + {"city_name": "1. New York"}, + {"city_name": "2. Los Angeles"}, + ], + num_records_per_city=1, + random_param="random_param", + ) + @pytest.mark.asyncio async def test_case_8(): @@ -164,18 +178,21 @@ async def test_case_8(): - Hook: test_hook modifies state - Expected: State variable should be modified by hook """ + def test_hook(data, state): - state.update({"variable": f'changed_state - {data}'}) + state.update({"variable": f"changed_state - {data}"}) return STATUS_COMPLETED - - @data_factory(max_concurrency=2, on_record_complete=[test_hook], initial_state_values={'variable': 'initial_state'}) + @data_factory(max_concurrency=2, on_record_complete=[test_hook], initial_state_values={"variable": "initial_state"}) async def test1(city_name, num_records_per_city, fail_rate=0.1, sleep_time=0.05): return await mock_llm_call(city_name, num_records_per_city, fail_rate=fail_rate, sleep_time=sleep_time) - result = test1.run(data=[ - {'city_name': '1. New York'}, - {'city_name': '2. Los Angeles'}, - ], num_records_per_city=1) - state_value = test1.state.get('variable') - assert state_value.startswith('changed_state') \ No newline at end of file + test1.run( + data=[ + {"city_name": "1. New York"}, + {"city_name": "2. Los Angeles"}, + ], + num_records_per_city=1, + ) + state_value = test1.state.get("variable") + assert state_value.startswith("changed_state") diff --git a/tests/llm/prompt/test_prompt.py b/tests/llm/prompt/test_prompt.py index 5e32676..ec412f7 100644 --- a/tests/llm/prompt/test_prompt.py +++ b/tests/llm/prompt/test_prompt.py @@ -1,8 +1,7 @@ import pytest from starfish.llm.prompt import PromptManager -from starfish.llm.prompt.prompt_loader import get_prompt, get_partial_prompt - +from starfish.llm.prompt.prompt_loader import get_partial_prompt, get_prompt """Tests for the PromptManager class and related functionality in starfish.llm.prompt. @@ -14,9 +13,9 @@ 5. Utility methods like get_prompt and get_partial_prompt Note: Although the PromptManager appends a MANDATE_INSTRUCTION that contains -schema_instruction and other variables, it does not treat schema_instruction +schema_instruction and other variables, it does not treat schema_instruction as a required variable. The test cases follow the actual implementation behavior -which identifies only variables used outside of conditional blocks in the +which identifies only variables used outside of conditional blocks in the original user template as "required". """ @@ -24,19 +23,16 @@ # Utility functions for test setup def get_expected_mandate_vars(required=False): """Get the list of variables added by MANDATE_INSTRUCTION. - + Args: required: If True, return only variables that would be classified as required. If False (default), return only variables that would be classified as optional. - + Returns: Set of variable names. """ - all_mandate_vars = { - "is_list_input", "list_input_variable", "input_list_length", - "schema_instruction", "num_records" - } - + all_mandate_vars = {"is_list_input", "list_input_variable", "input_list_length", "schema_instruction", "num_records"} + # In practice, none of the mandate variables are treated as required if required: return set() @@ -50,19 +46,17 @@ def basic_template(): """Simple template with basic variables for testing.""" return "Hello, {{ name }}! Your age is {{ age }}." + @pytest.fixture def simple_prompt_manager(basic_template): """Basic PromptManager instance with a simple template.""" return PromptManager(basic_template) + @pytest.fixture def standard_variables(): """Common variables used in multiple tests.""" - return { - "name": "Alice", - "age": 30, - "schema_instruction": "Test schema" - } + return {"name": "Alice", "age": 30, "schema_instruction": "Test schema"} class TestPromptManager: @@ -80,7 +74,7 @@ def test_basic_required_variables(self, simple_prompt_manager, standard_variable # Basic variables from template expected_template_required = {"name", "age"} expected_template_optional = set() - + # Get expected variables including those from MANDATE_INSTRUCTION expected_required = expected_template_required.union(get_expected_mandate_vars(required=True)) expected_optional = expected_template_optional.union(get_expected_mandate_vars(required=False)) @@ -95,25 +89,27 @@ def test_basic_required_variables(self, simple_prompt_manager, standard_variable result = manager.render_template(standard_variables) assert "Hello, Alice! Your age is 30." in result # Check that part of the non-list mandate instruction is present - assert "You are asked to generate exactly 1 records" in result + assert "You are asked to generate exactly 1 records" in result assert "Test schema" in result - @pytest.mark.parametrize("template,template_required,template_optional,test_name", [ - # Basic conditional test - ( - """ + @pytest.mark.parametrize( + "template,template_required,template_optional,test_name", + [ + # Basic conditional test + ( + """ Hello, {{ name }}! {% if show_age %} Your age is {{ age }}. {% endif %} """, - {"name"}, - {"show_age", "age"}, - "basic_conditional" - ), - # Nested conditional test - ( - """ + {"name"}, + {"show_age", "age"}, + "basic_conditional", + ), + # Nested conditional test + ( + """ Hello, {{ name }}! {% if show_details %} {% if show_age %} @@ -124,55 +120,56 @@ def test_basic_required_variables(self, simple_prompt_manager, standard_variable {% endif %} {% endif %} """, - {"name"}, - {"show_details", "show_age", "age", "show_location", "location"}, - "nested_conditional" - ), - # Mixed conditional test - ( - """ + {"name"}, + {"show_details", "show_age", "age", "show_location", "location"}, + "nested_conditional", + ), + # Mixed conditional test + ( + """ Hello, {{ name }}! - + {% if show_details %} Your details: {{ details }} {% endif %} - + Always show: {{ details }} """, - {"name", "details"}, - {"show_details"}, - "mixed_variables" - ), - # Conditional in conditional test - ( - """ + {"name", "details"}, + {"show_details"}, + "mixed_variables", + ), + # Conditional in conditional test + ( + """ Hello, {{ name }}! - + {% if show_details %} {% if details %} Your details: {{ details }} {% endif %} {% endif %} """, - {"name"}, - {"show_details", "details"}, - "conditional_in_conditional" - ), - ]) + {"name"}, + {"show_details", "details"}, + "conditional_in_conditional", + ), + ], + ) def test_conditional_variable_analysis(self, template, template_required, template_optional, test_name): """Test variable analysis with various conditional structures.""" manager = PromptManager(template) - + # Get the variables all_vars = set(manager.get_all_variables()) required_vars = set(manager.get_required_variables()) optional_vars = set(manager.get_optional_variables()) - + # Get expected variables including those from MANDATE_INSTRUCTION expected_required = template_required.union(get_expected_mandate_vars(required=True)) expected_optional = template_optional.union(get_expected_mandate_vars(required=False)) expected_all = expected_required.union(expected_optional) - + # Check variable identification assert all_vars == expected_all, f"Failed for {test_name}: all variables" assert required_vars == expected_required, f"Failed for {test_name}: required variables" @@ -181,16 +178,12 @@ def test_conditional_variable_analysis(self, template, template_required, templa # Basic rendering test if test_name == "basic_conditional": # Test showing age - result = manager.render_template({ - "name": "Bob", "show_age": True, "age": 25, "schema_instruction": "Age schema" - }).strip() + result = manager.render_template({"name": "Bob", "show_age": True, "age": 25, "schema_instruction": "Age schema"}).strip() assert "Hello, Bob!" in result assert "Your age is 25." in result - + # Test hiding age - result = manager.render_template({ - "name": "Charlie", "show_age": False, "schema_instruction": "No age schema" - }).strip() + result = manager.render_template({"name": "Charlie", "show_age": False, "schema_instruction": "No age schema"}).strip() assert "Hello, Charlie!" in result assert "Your age is" not in result @@ -221,11 +214,8 @@ def test_complex_templates(self): # Template variables template_required = {"var1"} - template_optional = { - "condition1", "nested_condition", "var2", "var3", - "condition2", "var4", "var5", "standalone_condition", "var6" - } - + template_optional = {"condition1", "nested_condition", "var2", "var3", "condition2", "var4", "var5", "standalone_condition", "var6"} + # Get expected variables including those from MANDATE_INSTRUCTION expected_req = template_required.union(get_expected_mandate_vars(required=True)) expected_opt = template_optional.union(get_expected_mandate_vars(required=False)) @@ -241,9 +231,9 @@ def test_complex_templates(self): assert opt_vars == expected_opt # Add a basic render test for completeness - result = manager.render_template({ - "var1": "Value1", "schema_instruction": "Complex Schema", "condition1": False, "condition2": False, "standalone_condition": False - }).strip() + result = manager.render_template( + {"var1": "Value1", "schema_instruction": "Complex Schema", "condition1": False, "condition2": False, "standalone_condition": False} + ).strip() assert "Value1 is shown in else block" in result assert "Value1 appears outside all conditions" in result assert "is in a different conditional" not in result @@ -255,14 +245,15 @@ def test_complex_templates(self): # --------------------------------------------------------------------------- @pytest.mark.parametrize( - "missing_var,variables,error_message", [ + "missing_var,variables,error_message", + [ # Missing required variables ("color", {"name": "Helen", "schema_instruction": "Schema"}, "Required variable 'color' is missing"), ("name", {"color": "Blue", "schema_instruction": "Schema"}, "Required variable 'name' is missing"), # None values for required variables ("color", {"name": "Ivan", "color": None, "schema_instruction": "Schema"}, "Required variable 'color' cannot be None"), ("name", {"name": None, "color": "Green", "schema_instruction": "Schema"}, "Required variable 'name' cannot be None"), - ] + ], ) def test_error_handling(self, missing_var, variables, error_message): """Test error handling for missing or None required variables.""" @@ -271,7 +262,7 @@ def test_error_handling(self, missing_var, variables, error_message): with pytest.raises(ValueError) as exc_info: manager.render_template(variables) - + assert error_message in str(exc_info.value) # --------------------------------------------------------------------------- @@ -284,10 +275,7 @@ def test_list_input_rendering(self): manager = PromptManager(template) items = ["apple", "banana"] - result = manager.render_template({ - "items_to_process": items, - "schema_instruction": "List schema" - }) + result = manager.render_template({"items_to_process": items, "schema_instruction": "List schema"}) # Check original template part # Note: The list itself is replaced by a reference and JSON dump @@ -308,18 +296,11 @@ def test_num_records_rendering(self): manager = PromptManager(template) # Test default num_records = 1 - result_default = manager.render_template({ - "topic": "Weather", - "schema_instruction": "Weather schema" - }) + result_default = manager.render_template({"topic": "Weather", "schema_instruction": "Weather schema"}) assert "You are asked to generate exactly 1 records" in result_default # Test custom num_records = 5 - result_custom = manager.render_template({ - "topic": "Cities", - "schema_instruction": "City schema", - "num_records": 5 - }) + result_custom = manager.render_template({"topic": "Cities", "schema_instruction": "City schema", "num_records": 5}) assert "You are asked to generate exactly 5 records" in result_custom def test_header_footer_rendering(self): @@ -335,11 +316,11 @@ def test_header_footer_rendering(self): assert header in result assert template in result assert footer in result - + # Ensure correct order - header comes before template, template before footer assert result.index(header) < result.index(template) assert result.index(template) < result.index(footer) - + # Verify schema instruction is included assert "Header/Footer schema" in result @@ -366,15 +347,12 @@ def test_construct_messages(self): assert isinstance(messages[0], dict) assert messages[0]["role"] == "user" assert "User query: How does this work?" in messages[0]["content"] - assert "Query schema" in messages[0]["content"] # Mandate part + assert "Query schema" in messages[0]["content"] # Mandate part def test_get_printable_messages(self): """Test the get_printable_messages formatting.""" - manager = PromptManager("") # Empty template, just mandate - messages = [ - {"role": "user", "content": "Test content line 1\nTest content line 2"}, - {"role": "assistant", "content": "Assistant response"} - ] + manager = PromptManager("") # Empty template, just mandate + messages = [{"role": "user", "content": "Test content line 1\nTest content line 2"}, {"role": "assistant", "content": "Assistant response"}] formatted_string = manager.get_printable_messages(messages) assert "========" in formatted_string @@ -390,42 +368,43 @@ def test_get_printable_messages(self): def test_get_prompt(): """Test the get_prompt utility function.""" from starfish.llm.prompt.prompt_template import COMPLETE_PROMPTS - + # Get a key from COMPLETE_PROMPTS prompt_name = next(iter(COMPLETE_PROMPTS.keys())) - + # Test retrieving a valid prompt prompt = get_prompt(prompt_name) assert prompt is not None - + # Test cache works (call again) prompt_again = get_prompt(prompt_name) assert prompt is prompt_again # Should be the same object (cached) - + # Test invalid prompt name with pytest.raises(ValueError) as exc_info: get_prompt("nonexistent_prompt_name") - + assert "Unknown complete prompt" in str(exc_info.value) assert prompt_name in str(exc_info.value) # Should list available options + def test_get_partial_prompt(): """Test the get_partial_prompt utility function.""" from starfish.llm.prompt.prompt_template import PARTIAL_PROMPTS - + # Get a key from PARTIAL_PROMPTS prompt_name = next(iter(PARTIAL_PROMPTS.keys())) - + # Test retrieving a valid partial prompt template_str = "Custom template: {{ var }}" prompt_manager = get_partial_prompt(prompt_name, template_str) - + assert isinstance(prompt_manager, PromptManager) assert "var" in prompt_manager.get_all_variables() - + # Test invalid prompt name with pytest.raises(ValueError) as exc_info: get_partial_prompt("nonexistent_prompt_name", template_str) - + assert "Unknown partial prompt" in str(exc_info.value) assert prompt_name in str(exc_info.value) # Should list available options From 577325dfcc40c27243e7809f1f29bf8190ce1acd Mon Sep 17 00:00:00 2001 From: "johnwayne.jiang" Date: Tue, 15 Apr 2025 16:29:03 -0700 Subject: [PATCH 7/7] cicd : add pre-commit and cicd --- src/starfish/data_factory/storage/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/starfish/data_factory/storage/models.py b/src/starfish/data_factory/storage/models.py index f0ce061..205720f 100644 --- a/src/starfish/data_factory/storage/models.py +++ b/src/starfish/data_factory/storage/models.py @@ -55,7 +55,7 @@ class GenerationMasterJob(BaseModel): last_update_time: datetime.datetime = Field(default_factory=utc_now, description="Last modification time.") @field_validator("output_schema", mode="before") - def _parse_json_string(self, value): + def _parse_json_string(cls, value): if isinstance(value, str): try: return json.loads(value) @@ -88,7 +88,7 @@ class GenerationJob(BaseModel): error_message: Optional[str] = Field(None, description="Error if the whole execution run failed.") @field_validator("run_config", mode="before") - def _parse_json_string(self, value): + def _parse_json_string(cls, value): # Same validator as above if stored as JSON string in DB if isinstance(value, str): try: