diff --git a/packages/optimization/src/ldai_optimization/__init__.py b/packages/optimization/src/ldai_optimization/__init__.py index 61c10304..64a50220 100644 --- a/packages/optimization/src/ldai_optimization/__init__.py +++ b/packages/optimization/src/ldai_optimization/__init__.py @@ -10,6 +10,8 @@ AIJudgeCallConfig, GroundTruthOptimizationOptions, GroundTruthSample, + LLMCallConfig, + LLMCallContext, OptimizationContext, OptimizationFromConfigOptions, OptimizationJudge, @@ -28,6 +30,8 @@ 'GroundTruthOptimizationOptions', 'GroundTruthSample', 'LDApiError', + 'LLMCallConfig', + 'LLMCallContext', 'OptimizationClient', 'OptimizationContext', 'OptimizationFromConfigOptions', diff --git a/packages/optimization/src/ldai_optimization/client.py b/packages/optimization/src/ldai_optimization/client.py index 3f33dd04..3ee29736 100644 --- a/packages/optimization/src/ldai_optimization/client.py +++ b/packages/optimization/src/ldai_optimization/client.py @@ -30,8 +30,9 @@ ) from ldai_optimization.ld_api_client import ( AgentOptimizationConfig, + AgentOptimizationResultPatch, + AgentOptimizationResultPost, LDApiClient, - OptimizationResultPayload, ) from ldai_optimization.prompts import ( _acceptance_criteria_implies_duration_optimization, @@ -49,6 +50,26 @@ logger = logging.getLogger(__name__) +def _find_model_config( + model_name: str, configs: List[Dict[str, Any]] +) -> Optional[Dict[str, Any]]: + """Find the best matching model config for a given model name. + + When multiple configs share the same ``id``, the one marked ``global=True`` + is preferred over project-specific configs. Falls back to the first + non-global match if no global entry exists. + + :param model_name: The model id to look up. + :param configs: List of model config dicts from the LD API. + :return: Best-matching model config dict, or None if no match. + """ + matching = [mc for mc in configs if mc.get("id") == model_name] + if not matching: + return None + global_match = next((mc for mc in matching if mc.get("global") is True), None) + return global_match if global_match is not None else matching[0] + + def _strip_provider_prefix(model: str) -> str: """Strip the provider prefix from a model identifier returned by the LD API. 
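A minimal usage sketch of the new _find_model_config helper, mirroring the cases covered by TestFindModelConfig later in this diff (the config dicts are simplified stand-ins for the LD API response, not its exact shape):

    # Illustrative only: when several configs share an id, the global entry wins.
    configs = [
        {"id": "gpt-4o", "key": "project.gpt-4o", "global": False},
        {"id": "gpt-4o", "key": "global.gpt-4o", "global": True},
    ]
    assert _find_model_config("gpt-4o", configs)["key"] == "global.gpt-4o"
    assert _find_model_config("claude-3", configs) is None  # no id match -> None
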
@@ -96,7 +117,7 @@ def _compute_validation_count(pool_size: int) -> int: "generating": {"status": "RUNNING", "activity": "GENERATING"}, "evaluating": {"status": "RUNNING", "activity": "EVALUATING"}, "generating variation": {"status": "RUNNING", "activity": "GENERATING_VARIATION"}, - "validating": {"status": "RUNNING", "activity": "VALIDATING"}, + "validating": {"status": "RUNNING", "activity": "EVALUATING"}, "turn completed": {"status": "RUNNING", "activity": "COMPLETED"}, "success": {"status": "PASSED", "activity": "COMPLETED"}, "failure": {"status": "FAILED", "activity": "COMPLETED"}, @@ -116,6 +137,8 @@ def __init__(self, ldClient: LDAIClient) -> None: self._ldClient = ldClient self._last_run_succeeded: bool = False self._last_succeeded_context: Optional[OptimizationContext] = None + self._last_optimization_result_id: Optional[str] = None + self._initial_tool_keys: List[str] = [] if os.environ.get("LAUNCHDARKLY_API_KEY"): self._has_api_key = True @@ -588,7 +611,7 @@ async def _evaluate_config_judge( judge_ctx = OptimizationJudgeContext( user_input=judge_user_input, - variables=variables or {}, + current_variables=variables or {}, ) _judge_start = time.monotonic() @@ -749,7 +772,7 @@ async def _evaluate_acceptance_judge( judge_ctx = OptimizationJudgeContext( user_input=judge_user_input, - variables=resolved_variables, + current_variables=resolved_variables, ) _judge_start = time.monotonic() @@ -804,6 +827,13 @@ async def _get_agent_config( ) self._initial_instructions = raw_instructions + raw_tools = raw_variation.get("tools", []) + self._initial_tool_keys = [ + t["key"] + for t in raw_tools + if isinstance(t, dict) and "key" in t + ] + agent_config = dataclasses.replace( agent_config, instructions=raw_instructions ) @@ -915,6 +945,7 @@ async def _run_ground_truth_optimization( self._agent_config = agent_config self._last_run_succeeded = False self._last_succeeded_context = None + self._last_optimization_result_id = None self._initialize_class_members_from_config(agent_config) # Seed from the first model choice on the first iteration @@ -1342,16 +1373,22 @@ async def optimize_from_config( config = api_client.get_agent_optimization(options.project_key, optimization_config_key) self._agent_key = config["aiConfigKey"] - optimization_id: str = config["id"] + optimization_key: str = config["key"] run_id = str(uuid.uuid4()) + model_configs: List[Dict[str, Any]] = [] + try: + model_configs = api_client.get_model_configs(options.project_key) + except Exception as exc: + logger.debug("Could not pre-fetch model configs: %s", exc) + context = random.choice(options.context_choices) # _get_agent_config calls _initialize_class_members_from_config internally; # _run_optimization calls it again to reset history before the loop starts. 
agent_config = await self._get_agent_config(self._agent_key, context) optimization_options = self._build_options_from_config( - config, options, api_client, optimization_id, run_id + config, options, api_client, optimization_key, run_id, model_configs ) if isinstance(optimization_options, GroundTruthOptimizationOptions): result = await self._run_ground_truth_optimization(agent_config, optimization_options) @@ -1359,13 +1396,21 @@ async def optimize_from_config( result = await self._run_optimization(agent_config, optimization_options) if options.auto_commit and self._last_run_succeeded and self._last_succeeded_context: - self._commit_variation( + created_key = self._commit_variation( self._last_succeeded_context, project_key=options.project_key, ai_config_key=config["aiConfigKey"], output_key=options.output_key, api_client=api_client, + model_configs=model_configs, ) + if created_key and self._last_optimization_result_id: + api_client.patch_agent_optimization_result( + options.project_key, + optimization_key, + self._last_optimization_result_id, + {"createdVariationKey": created_key}, + ) return result def _build_options_from_config( @@ -1373,8 +1418,9 @@ def _build_options_from_config( config: AgentOptimizationConfig, options: OptimizationFromConfigOptions, api_client: LDApiClient, - optimization_id: str, + optimization_key: str, run_id: str, + model_configs: Optional[List[Dict[str, Any]]] = None, ) -> "Union[OptimizationOptions, GroundTruthOptimizationOptions]": """Map a fetched AgentOptimization config + user options into the appropriate options type. @@ -1391,8 +1437,9 @@ def _build_options_from_config( :param config: Validated AgentOptimizationConfig from the API. :param options: User-provided options from optimize_from_config. :param api_client: Initialised LDApiClient for result persistence. - :param optimization_id: UUID id of the parent agent_optimization record. + :param optimization_key: String key of the parent agent_optimization record. :param run_id: UUID that groups all result records for this run. + :param model_configs: Pre-fetched list of model config dicts for resolving modelConfigKey. :return: OptimizationOptions or GroundTruthOptimizationOptions. """ judges: Dict[str, OptimizationJudge] = {} @@ -1421,6 +1468,31 @@ def _build_options_from_config( project_key = options.project_key config_version: int = config["version"] + _cached_model_configs: List[Dict[str, Any]] = list(model_configs or []) + + # Maps logical iteration number → result record id. Each new main-loop + # iteration (plus the init iteration 0) POSTs a fresh record; subsequent + # status events for that same iteration PATCH the existing record. + _iteration_result_ids: Dict[int, str] = {} + + # Validation phase tracking. When a candidate passes initial checks the + # SDK fires validation sub-iterations (val_iter = main_iter + 1, +2, …). + # These are internal cross-checks and should NOT create separate records; + # instead they are folded back into the parent main-loop iteration's record. + _in_validation_phase: bool = False + _validation_parent_iteration: int = -1 + + # Tracks the most recently opened (POSTed) iteration so we can close it + # with a RUNNING:COMPLETED patch when the next iteration begins. Without + # this, iterations that don't naturally receive a terminal event (e.g. the + # init iteration 0, or non-final GT samples) are left in a stale state. 
+ _last_open_iteration: int = -1 + + def _resolve_model_config_key(model_name: str) -> str: + if not model_name: + return "" + match = _find_model_config(model_name, _cached_model_configs) + return match["key"] if match else model_name def _persist_and_forward( status: Literal[ @@ -1435,47 +1507,120 @@ def _persist_and_forward( ], ctx: OptimizationContext, ) -> None: + nonlocal _in_validation_phase, _validation_parent_iteration, _last_open_iteration # _safe_status_update (the caller) already wraps this entire function in # a try/except, so errors here are caught and logged without aborting the run. mapped = _OPTIMIZATION_STATUS_MAP.get( status, {"status": "RUNNING", "activity": "PENDING"} ) snapshot = ctx.copy_without_history() - payload: OptimizationResultPayload = { - "run_id": run_id, - "config_optimization_version": config_version, - "status": mapped["status"], - "activity": mapped["activity"], - "iteration": snapshot.iteration, - "instructions": snapshot.current_instructions, - "parameters": snapshot.current_parameters, - "completion_response": snapshot.completion_response, - "scores": {k: v.to_json() for k, v in snapshot.scores.items()}, - "user_input": snapshot.user_input, - } - if snapshot.duration_ms is not None: - payload["generation_latency"] = snapshot.duration_ms - if snapshot.usage is not None: - payload["generation_tokens"] = { - "total": snapshot.usage.total, - "input": snapshot.usage.input, - "output": snapshot.usage.output, + + # "validating" fires with the parent main-loop iteration's context, so + # we capture that number as the anchor for all subsequent validation events. + if status == "validating": + _in_validation_phase = True + _validation_parent_iteration = snapshot.iteration + + # Any event whose ctx.iteration differs from the validation anchor is a + # validation sub-iteration; fold it back to the parent's record. + if _in_validation_phase and snapshot.iteration != _validation_parent_iteration: + logical_iteration = _validation_parent_iteration + else: + logical_iteration = snapshot.iteration + + # When a new iteration begins (generating), close out whatever iteration + # was last open so it doesn't remain in a non-terminal state. This covers + # the init iteration (0 → 1) and GT batches where non-final samples never + # receive an explicit terminal event. + if ( + status == "generating" + and _last_open_iteration >= 0 + and logical_iteration != _last_open_iteration + ): + prev_result_id = _iteration_result_ids.get(_last_open_iteration) + if prev_result_id: + api_client.patch_agent_optimization_result( + project_key, + optimization_key, + prev_result_id, + {"status": "RUNNING", "activity": "COMPLETED"}, + ) + _last_open_iteration = -1 + + # Phase 1: POST to create the record on first encounter of each logical iteration. 
+ if logical_iteration not in _iteration_result_ids: + post_payload: AgentOptimizationResultPost = { + "runId": run_id, + "agentOptimizationVersion": config_version, + "iteration": logical_iteration, + "instructions": snapshot.current_instructions, } - eval_latencies = { - k: v.duration_ms - for k, v in snapshot.scores.items() - if v.duration_ms is not None - } - if eval_latencies: - payload["evaluation_latencies"] = eval_latencies - eval_tokens = { - k: {"total": v.usage.total, "input": v.usage.input, "output": v.usage.output} - for k, v in snapshot.scores.items() - if v.usage is not None - } - if eval_tokens: - payload["evaluation_tokens"] = eval_tokens - api_client.post_agent_optimization_result(project_key, optimization_id, payload) + if snapshot.current_parameters: + post_payload["parameters"] = snapshot.current_parameters + if snapshot.user_input: + post_payload["userInput"] = snapshot.user_input + result_id = api_client.post_agent_optimization_result( + project_key, optimization_key, post_payload + ) + if result_id: + _iteration_result_ids[logical_iteration] = result_id + self._last_optimization_result_id = result_id + _last_open_iteration = logical_iteration + + # Phase 2: PATCH the record with current status and available telemetry. + result_id = _iteration_result_ids.get(logical_iteration) + if result_id: + patch: AgentOptimizationResultPatch = { + "status": mapped["status"], + "activity": mapped["activity"], + } + if snapshot.completion_response: + patch["completionResponse"] = snapshot.completion_response + if snapshot.scores: + patch["scores"] = { + k: { + **v.to_json(), + **({"threshold": judges[k].threshold} if k in judges else {}), + } + for k, v in snapshot.scores.items() + } + if snapshot.duration_ms is not None: + patch["generationLatency"] = int(snapshot.duration_ms) + if snapshot.usage is not None: + patch["generationTokens"] = { + "total": snapshot.usage.total, + "input": snapshot.usage.input, + "output": snapshot.usage.output, + } + eval_latencies = { + k: v.duration_ms + for k, v in snapshot.scores.items() + if v.duration_ms is not None + } + if eval_latencies: + patch["evaluationLatencies"] = eval_latencies + eval_tokens = { + k: {"total": v.usage.total, "input": v.usage.input, "output": v.usage.output} + for k, v in snapshot.scores.items() + if v.usage is not None + } + if eval_tokens: + patch["evaluationTokens"] = eval_tokens + patch["variation"] = { + "instructions": snapshot.current_instructions, + "parameters": snapshot.current_parameters, + "modelConfigKey": _resolve_model_config_key(snapshot.current_model or ""), + } + api_client.patch_agent_optimization_result( + project_key, optimization_key, result_id, patch + ) + + # Reset tracking state after terminal events so the next main-loop + # attempt starts fresh. 
+ if status in ("turn completed", "success", "failure"): + _in_validation_phase = False + _validation_parent_iteration = -1 + _last_open_iteration = -1 if options.on_status_update: try: @@ -1590,7 +1735,6 @@ async def _execute_agent_turn( scores: Dict[str, JudgeResult] = {} if self._options.judges: - self._safe_status_update("evaluating", optimize_context, iteration) agent_tools = self._extract_agent_tools(optimize_context.current_parameters) scores = await self._call_judges( completion_response, @@ -1602,7 +1746,11 @@ async def _execute_agent_turn( agent_duration_ms=agent_duration_ms, ) - return dataclasses.replace( + # Build the fully-populated result context before firing the evaluating event so + # the PATCH includes scores, generationLatency, and completionResponse. This is + # particularly important for non-final GT samples which receive no further status + # events — without this, those fields would never be written to their API records. + result_ctx = dataclasses.replace( optimize_context, completion_response=completion_response, scores=scores, @@ -1610,6 +1758,11 @@ async def _execute_agent_turn( usage=agent_response.usage, ) + if self._options.judges: + self._safe_status_update("evaluating", result_ctx, iteration) + + return result_ctx + def _evaluate_response(self, optimize_context: OptimizationContext) -> bool: """ Determine whether the current iteration's scores meet all judge thresholds. @@ -1731,6 +1884,7 @@ def _commit_variation( output_key: Optional[str], api_client: Optional[LDApiClient] = None, base_url: Optional[str] = None, + model_configs: Optional[List[Dict[str, Any]]] = None, ) -> str: """Commit the winning optimization context as a new AI Config variation. @@ -1774,8 +1928,8 @@ def _commit_variation( model_name = optimize_context.current_model or "" model_config_key = model_name # fallback if lookup fails try: - model_configs = api_client.get_model_configs(project_key) - match = next((mc for mc in model_configs if mc.get("id") == model_name), None) + configs_to_search = model_configs if model_configs is not None else api_client.get_model_configs(project_key) + match = _find_model_config(model_name, configs_to_search) if match: model_config_key = match["key"] else: @@ -1792,6 +1946,8 @@ def _commit_variation( "instructions": optimize_context.current_instructions, "modelConfigKey": model_config_key, } + if self._initial_tool_keys: + payload["toolKeys"] = list(self._initial_tool_keys) last_exc: Optional[Exception] = None for attempt in range(1, 4): @@ -1971,6 +2127,7 @@ async def _run_optimization( self._agent_config = agent_config self._last_run_succeeded = False self._last_succeeded_context = None + self._last_optimization_result_id = None self._initialize_class_members_from_config(agent_config) # If the LD flag doesn't carry a model name, seed from the first model choice @@ -2052,8 +2209,12 @@ async def _run_optimization( optimize_context, iteration ) if all_valid: - return self._handle_success(last_ctx, iteration) - # Validation failed — treat as a normal failed attempt + return self._handle_success(optimize_context, iteration) + # Validation failed — treat as a normal failed attempt. + # Use optimize_context (the main iteration) for terminal API events so + # the persisted record's completionResponse and userInput stay aligned. + # last_ctx (the failing validation run) goes into history so the + # variation generator can see what went wrong. 
logger.info( "[Iteration %d] -> Validation failed — generating new variation (attempt %d/%d)", iteration, @@ -2061,7 +2222,7 @@ async def _run_optimization( self._options.max_attempts, ) if iteration >= self._options.max_attempts: - return self._handle_failure(last_ctx, iteration) + return self._handle_failure(optimize_context, iteration) self._history.append(last_ctx) try: await self._generate_new_variation( @@ -2071,8 +2232,8 @@ async def _run_optimization( logger.exception( "[Iteration %d] -> variation generation failed", iteration ) - return self._handle_failure(last_ctx, iteration) - self._safe_status_update("turn completed", last_ctx, iteration) + return self._handle_failure(optimize_context, iteration) + self._safe_status_update("turn completed", optimize_context, iteration) continue # Initial turn failed diff --git a/packages/optimization/src/ldai_optimization/dataclasses.py b/packages/optimization/src/ldai_optimization/dataclasses.py index edcbd8b2..f4d2f91c 100644 --- a/packages/optimization/src/ldai_optimization/dataclasses.py +++ b/packages/optimization/src/ldai_optimization/dataclasses.py @@ -14,6 +14,7 @@ Sequence, Union, ) +from typing_extensions import Protocol from ldai import AIAgentConfig from ldai.models import LDMessage, ModelConfig @@ -108,6 +109,45 @@ def from_dict(cls, data: Dict[str, Any]) -> "ToolDefinition": ) +class LLMCallConfig(Protocol): + """Structural protocol satisfied by both ``AIAgentConfig`` and ``AIJudgeCallConfig``. + + Use this as the config parameter type when you want a single handler function + that can be passed to both ``handle_agent_call`` and ``handle_judge_call``:: + + async def handle_llm_call( + key: str, + config: LLMCallConfig, + context: LLMCallContext, + ) -> OptimizationResponse: + model_name = config.model.name if config.model else "gpt-4o" + instructions = config.instructions or "" + tools = config.model.get_parameter("tools") if config.model else [] + ... + + OptimizationOptions( + handle_agent_call=handle_llm_call, + handle_judge_call=handle_llm_call, + ... + ) + """ + + key: str + model: Optional[ModelConfig] + instructions: Optional[str] + + +class LLMCallContext(Protocol): + """Structural protocol satisfied by both ``OptimizationContext`` and ``OptimizationJudgeContext``. + + Use alongside ``LLMCallConfig`` when writing a single handler for both + ``handle_agent_call`` and ``handle_judge_call``. + """ + + user_input: Optional[str] + current_variables: Dict[str, Any] + + @dataclass class AIJudgeCallConfig: """ @@ -229,20 +269,25 @@ class OptimizationJudgeContext: """Context for a single judge evaluation turn.""" user_input: str # the agent response being evaluated - variables: Dict[str, Any] = field(default_factory=dict) # variable set used during agent generation + current_variables: Dict[str, Any] = field(default_factory=dict) # variable set used during agent generation # Shared callback type aliases used by both OptimizationOptions and # OptimizationFromConfigOptions to avoid duplicating the full signatures. # Placed here so all referenced types (OptimizationContext, AIJudgeCallConfig, # OptimizationJudgeContext) are already defined above. +# +# Both aliases use the LLMCallConfig / LLMCallContext Protocols so callers can +# write a single handler for both agent and judge calls. Handlers typed with +# the concrete types (AIAgentConfig / AIJudgeCallConfig) continue to work +# because those types structurally satisfy the Protocols. 
HandleAgentCall = Union[ - Callable[[str, AIAgentConfig, OptimizationContext], OptimizationResponse], - Callable[[str, AIAgentConfig, OptimizationContext], Awaitable[OptimizationResponse]], + Callable[[str, LLMCallConfig, LLMCallContext], OptimizationResponse], + Callable[[str, LLMCallConfig, LLMCallContext], Awaitable[OptimizationResponse]], ] HandleJudgeCall = Union[ - Callable[[str, AIJudgeCallConfig, OptimizationJudgeContext], OptimizationResponse], - Callable[[str, AIJudgeCallConfig, OptimizationJudgeContext], Awaitable[OptimizationResponse]], + Callable[[str, LLMCallConfig, LLMCallContext], OptimizationResponse], + Callable[[str, LLMCallConfig, LLMCallContext], Awaitable[OptimizationResponse]], ] _StatusLiteral = Literal[ @@ -261,9 +306,7 @@ class OptimizationJudgeContext: class OptimizationOptions: """Options for agent optimization.""" - # Required - context_choices: List[Context] # choices of contexts to be used, 1 min required - # Configuration + # Configuration - Required max_attempts: int model_choices: List[str] # model ids the LLM can choose from, 1 min required judge_model: str # which model to use as judge; this should remain consistent @@ -283,6 +326,10 @@ class OptimizationOptions: on_turn: Optional[Callable[[OptimizationContext], bool]] = ( None # if you want manual control of pass/fail ) + # Context - Optional; defaults to a single anonymous context + context_choices: List[Context] = field( + default_factory=lambda: [Context.builder("anonymous").anonymous(True).build()] + ) # Auto-commit - Optional auto_commit: bool = False project_key: Optional[str] = None # required when auto_commit=True @@ -295,8 +342,6 @@ class OptimizationOptions: def __post_init__(self): """Validate required options.""" - if len(self.context_choices) < 1: - raise ValueError("context_choices must have at least 1 context") if len(self.model_choices) < 1: raise ValueError("model_choices must have at least 1 model") if self.judges is None and self.on_turn is None: @@ -351,7 +396,6 @@ class GroundTruthOptimizationOptions: :param on_status_update: Called on each status transition during the run. 
""" - context_choices: List[Context] ground_truth_responses: List[GroundTruthSample] max_attempts: int model_choices: List[str] @@ -372,6 +416,10 @@ class GroundTruthOptimizationOptions: None, ] ] = None + # Context - Optional; defaults to a single anonymous context + context_choices: List[Context] = field( + default_factory=lambda: [Context.builder("anonymous").anonymous(True).build()] + ) # Auto-commit - Optional auto_commit: bool = False project_key: Optional[str] = None # required when auto_commit=True @@ -380,8 +428,6 @@ class GroundTruthOptimizationOptions: def __post_init__(self): """Validate required options.""" - if len(self.context_choices) < 1: - raise ValueError("context_choices must have at least 1 context") if len(self.model_choices) < 1: raise ValueError("model_choices must have at least 1 model") if len(self.ground_truth_responses) < 1: @@ -414,7 +460,6 @@ class OptimizationFromConfigOptions: """ project_key: str - context_choices: List[Context] handle_agent_call: HandleAgentCall handle_judge_call: HandleJudgeCall on_turn: Optional[Callable[["OptimizationContext"], bool]] = None @@ -422,12 +467,11 @@ class OptimizationFromConfigOptions: on_passing_result: Optional[Callable[["OptimizationContext"], None]] = None on_failing_result: Optional[Callable[["OptimizationContext"], None]] = None on_status_update: Optional[Callable[[_StatusLiteral, "OptimizationContext"], None]] = None + # Context - Optional; defaults to a single anonymous context + context_choices: List[Context] = field( + default_factory=lambda: [Context.builder("anonymous").anonymous(True).build()] + ) base_url: Optional[str] = None # Auto-commit defaults to True for config-driven runs; set False to disable auto_commit: bool = True output_key: Optional[str] = None # variation key/name; auto-generated if omitted - - def __post_init__(self): - """Validate required options.""" - if len(self.context_choices) < 1: - raise ValueError("context_choices must have at least 1 context") diff --git a/packages/optimization/src/ldai_optimization/ld_api_client.py b/packages/optimization/src/ldai_optimization/ld_api_client.py index a8076298..6671bffc 100644 --- a/packages/optimization/src/ldai_optimization/ld_api_client.py +++ b/packages/optimization/src/ldai_optimization/ld_api_client.py @@ -80,37 +80,36 @@ class AgentOptimizationConfig(_AgentOptimizationConfigRequired, total=False): # --------------------------------------------------------------------------- -# POST payload shape +# Result payload shapes # --------------------------------------------------------------------------- -class _OptimizationResultPayloadRequired(TypedDict): - run_id: str - config_optimization_version: int - status: str - activity: str +class _AgentOptimizationResultPostRequired(TypedDict): + runId: str + agentOptimizationVersion: int iteration: int instructions: str - parameters: Dict[str, Any] - completion_response: str - scores: Dict[str, Any] -class OptimizationResultPayload(_OptimizationResultPayloadRequired, total=False): - """Typed payload for a single agent_optimization_result POST request. +class AgentOptimizationResultPost(_AgentOptimizationResultPostRequired, total=False): + """Payload for POST /agent-optimizations/{key}/results — creates a new result record.""" - Required fields are always sent. Optional fields are omitted when not - available. + userInput: str + parameters: Dict[str, Any] - created_variation_key is only present on the final result record of a - successful run, populated once a winning variation is committed to LD. 
- """ - user_input: Optional[str] - created_variation_key: str - generation_latency: float - generation_tokens: Dict[str, int] - evaluation_latencies: Dict[str, float] - evaluation_tokens: Dict[str, Dict[str, int]] +class AgentOptimizationResultPatch(TypedDict, total=False): + """Payload for PATCH /agent-optimizations/{key}/results/{id} — updates a result record.""" + + status: str + activity: str + completionResponse: str + scores: Dict[str, Any] + generationLatency: int + generationTokens: Dict[str, int] + evaluationLatencies: Dict[str, float] + evaluationTokens: Dict[str, Dict[str, int]] + variation: Dict[str, Any] + createdVariationKey: str # --------------------------------------------------------------------------- @@ -325,31 +324,64 @@ def get_agent_optimization( return _parse_agent_optimization(raw) def post_agent_optimization_result( - self, project_key: str, optimization_id: str, payload: OptimizationResultPayload - ) -> None: - """Persist an iteration result record for the given optimization run. + self, project_key: str, optimization_key: str, payload: AgentOptimizationResultPost + ) -> Optional[str]: + """Create an iteration result record for the given optimization run. Errors are caught and logged rather than raised so that persistence failures never abort an in-progress optimization run. :param project_key: LaunchDarkly project key. - :param optimization_id: UUID id of the parent agent_optimization record. - :param payload: Typed result payload for this iteration. + :param optimization_key: String key of the parent agent_optimization record. + :param payload: POST payload for this iteration. + :return: The ``id`` of the newly created result record, or None on failure. """ - path = f"/api/v2/projects/{project_key}/agent-optimizations/{optimization_id}/results" + path = f"/api/v2/projects/{project_key}/agent-optimizations/{optimization_key}/results" try: - self._request("POST", path, body=payload) + result = self._request("POST", path, body=payload) + return result.get("id") if isinstance(result, dict) else None except LDApiError as exc: logger.debug( - "Failed to persist optimization result (optimization_id=%s, iteration=%s): %s", - optimization_id, + "Failed to persist optimization result (optimization_key=%s, iteration=%s): %s", + optimization_key, payload.get("iteration"), exc, ) + return None except Exception as exc: logger.debug( - "Unexpected error persisting optimization result (optimization_id=%s, iteration=%s): %s", - optimization_id, + "Unexpected error persisting optimization result (optimization_key=%s, iteration=%s): %s", + optimization_key, payload.get("iteration"), exc, ) + return None + + def patch_agent_optimization_result( + self, project_key: str, optimization_key: str, result_id: str, payload: AgentOptimizationResultPatch + ) -> None: + """Update an existing iteration result record. + + Errors are caught and logged rather than raised so that persistence + failures never abort an in-progress optimization run. + + :param project_key: LaunchDarkly project key. + :param optimization_key: String key of the parent agent_optimization record. + :param result_id: ID of the result record to update. + :param payload: PATCH payload with fields to update. 
+ """ + path = f"/api/v2/projects/{project_key}/agent-optimizations/{optimization_key}/results/{result_id}" + try: + self._request("PATCH", path, body=payload) + except LDApiError as exc: + logger.debug( + "Failed to update optimization result (result_id=%s): %s", + result_id, + exc, + ) + except Exception as exc: + logger.debug( + "Unexpected error updating optimization result (result_id=%s): %s", + result_id, + exc, + ) diff --git a/packages/optimization/tests/test_client.py b/packages/optimization/tests/test_client.py index c88cef8b..39d75146 100644 --- a/packages/optimization/tests/test_client.py +++ b/packages/optimization/tests/test_client.py @@ -9,7 +9,7 @@ from ldai.models import LDMessage, ModelConfig from ldclient import Context -from ldai_optimization.client import OptimizationClient, _compute_validation_count +from ldai_optimization.client import OptimizationClient, _compute_validation_count, _find_model_config from ldai_optimization.dataclasses import ( AIJudgeCallConfig, GroundTruthOptimizationOptions, @@ -157,6 +157,59 @@ def test_result_is_valid_json_string(self): json.loads(result) +# --------------------------------------------------------------------------- +# _find_model_config +# --------------------------------------------------------------------------- + + +class TestFindModelConfig: + def test_returns_none_when_no_configs(self): + assert _find_model_config("gpt-4o", []) is None + + def test_returns_none_when_no_id_match(self): + configs = [{"id": "claude-3", "key": "Anthropic.claude-3", "global": True}] + assert _find_model_config("gpt-4o", configs) is None + + def test_returns_single_match(self): + configs = [{"id": "gpt-4o", "key": "OpenAI.gpt-4o", "global": False}] + result = _find_model_config("gpt-4o", configs) + assert result is not None + assert result["key"] == "OpenAI.gpt-4o" + + def test_prefers_global_match_over_non_global(self): + configs = [ + {"id": "gpt-4o", "key": "project.gpt-4o", "global": False}, + {"id": "gpt-4o", "key": "global.gpt-4o", "global": True}, + ] + result = _find_model_config("gpt-4o", configs) + assert result is not None + assert result["key"] == "global.gpt-4o" + + def test_prefers_global_match_regardless_of_list_order(self): + configs = [ + {"id": "gpt-4o", "key": "global.gpt-4o", "global": True}, + {"id": "gpt-4o", "key": "project.gpt-4o", "global": False}, + ] + result = _find_model_config("gpt-4o", configs) + assert result["key"] == "global.gpt-4o" + + def test_falls_back_to_non_global_when_no_global_exists(self): + configs = [ + {"id": "gpt-4o", "key": "project.gpt-4o", "global": False}, + ] + result = _find_model_config("gpt-4o", configs) + assert result is not None + assert result["key"] == "project.gpt-4o" + + def test_treats_missing_global_field_as_non_global(self): + configs = [ + {"id": "gpt-4o", "key": "no-global-field.gpt-4o"}, + {"id": "gpt-4o", "key": "global.gpt-4o", "global": True}, + ] + result = _find_model_config("gpt-4o", configs) + assert result["key"] == "global.gpt-4o" + + # --------------------------------------------------------------------------- # _extract_agent_tools # --------------------------------------------------------------------------- @@ -460,7 +513,7 @@ async def test_variables_in_context(self): ) call_args = self.handle_judge_call.call_args _, _, ctx = call_args.args - assert ctx.variables == variables + assert ctx.current_variables == variables async def test_duration_context_added_to_instructions_when_latency_keyword_present(self): """When acceptance statement has a latency keyword and 
agent_duration_ms is provided, @@ -1048,6 +1101,24 @@ async def test_on_turn_manual_path_success(self): result = await client.optimize_from_options("test-agent", options) assert result.completion_response == "Answer." + async def test_success_result_carries_main_iteration_context_not_validation_context(self): + # The main iteration returns "Main answer." but the validation run returns + # "Validation answer.". The result should reflect the main iteration so that + # completion_response and user_input are consistent with what was POSTed to the API. + agent_responses = [ + OptimizationResponse(output="Main answer."), # main iteration + OptimizationResponse(output="Validation answer."), # validation sample + ] + handle_agent_call = AsyncMock(side_effect=agent_responses) + handle_judge_call = AsyncMock(return_value=OptimizationResponse(output=JUDGE_PASS_RESPONSE)) + client = _make_client(self.mock_ldai) + options = _make_options( + handle_agent_call=handle_agent_call, + handle_judge_call=handle_judge_call, + ) + result = await client.optimize_from_options("test-agent", options) + assert result.completion_response == "Main answer." + async def test_status_update_callback_called_at_each_stage(self): statuses = [] handle_agent_call = AsyncMock(return_value=OptimizationResponse(output="Good answer.")) @@ -1308,6 +1379,69 @@ async def test_validating_status_emitted(self): await client.optimize_from_options("test-agent", opts) assert "validating" in statuses + async def test_turn_completed_after_validation_failure_uses_main_iteration_context(self): + """When validation fails, the 'turn completed' event must carry the MAIN iteration's + user_input and completion_response — not the failing validation sample's values. + + Regression test for the mismatch where a record stored userInput='hostel near paris' + but completionResponse described 'airbmbs near tahoe' (from a validation run with a + different user_input that was folded back onto the main iteration's API record). + """ + turn_calls = [0] + status_events: list = [] + + user_inputs = [f"query-{i}" for i in range(8)] + + def on_turn(ctx): + turn_calls[0] += 1 + # Call 1: main iteration passes. Call 2: first validation sample FAILS. + # Call 3+: everything passes (new attempt succeeds). + return turn_calls[0] != 2 + + def capture_status(status, ctx): + status_events.append((status, ctx.user_input, ctx.completion_response)) + + client = self._make_client() + opts = _make_multi_options( + on_turn=on_turn, + variable_count=8, + user_input_options=user_inputs, + handle_agent_call=AsyncMock(side_effect=[ + OptimizationResponse(output="main-response"), # main turn (passes) + OptimizationResponse(output="val-response"), # validation sample (fails) + OptimizationResponse(output=VARIATION_RESPONSE), # variation generation + OptimizationResponse(output="main-response-2"), # 2nd attempt main (passes) + OptimizationResponse(output="val-response-2"), # 2nd attempt validation (passes) + OptimizationResponse(output="val-response-3"), # 2nd attempt validation (passes) + ]), + max_attempts=3, + ) + opts.on_status_update = capture_status + await client.optimize_from_options("test-agent", opts) + + # The 'generating' event captures the main iteration's user_input. + # The validation run fires 'generating' as well, but with a different user_input. + # The first 'generating' is always the main iteration. 
+ generating_events = [(u, r) for s, u, r in status_events if s == "generating"] + main_user_input = generating_events[0][0] + + # Find the 'turn completed' event from the first attempt (after validation failure) + tc_events = [(u, r) for s, u, r in status_events if s == "turn completed"] + assert len(tc_events) >= 1, "Expected at least one 'turn completed' event" + + tc_user_input, tc_completion = tc_events[0] + # turn completed must use the MAIN iteration's data, not the validation sample's. + # If the bug is present, tc_completion would be "val-response" and tc_user_input + # would be the validation sample's query (different from main_user_input). + assert tc_completion == "main-response", ( + f"turn completed should carry the main iteration's completion_response " + f"('main-response'), not the validation run's (got: {tc_completion!r})" + ) + assert tc_user_input == main_user_input, ( + f"turn completed should carry the main iteration's user_input " + f"('{main_user_input}'), not the validation run's (got: {tc_user_input!r})" + ) + # --------------------------------------------------------------------------- # Variation prompt — acceptance criteria section @@ -1747,7 +1881,9 @@ def _make_from_config_options(**overrides: Any) -> OptimizationFromConfigOptions def _make_mock_api_client() -> MagicMock: mock = MagicMock() - mock.post_agent_optimization_result = MagicMock() + mock.post_agent_optimization_result = MagicMock(return_value="result-uuid-789") + mock.patch_agent_optimization_result = MagicMock() + mock.get_model_configs = MagicMock(return_value=[]) return mock @@ -1769,8 +1905,9 @@ def _build(self, config=None, options=None) -> OptimizationOptions: config or dict(_API_CONFIG), options or _make_from_config_options(), self.api_client, - optimization_id="opt-uuid-123", + optimization_key="opt-key-123", run_id="run-uuid-456", + model_configs=[], ) def test_acceptance_statements_mapped_to_judges(self): @@ -1904,7 +2041,7 @@ def test_persist_and_forward_posts_result_on_status_update(self): self.api_client.post_agent_optimization_result.assert_called_once() call_args = self.api_client.post_agent_optimization_result.call_args assert call_args[0][0] == "my-project" - assert call_args[0][1] == "opt-uuid-123" + assert call_args[0][1] == "opt-key-123" def test_persist_and_forward_payload_has_correct_field_names(self): result = self._build() @@ -1919,13 +2056,51 @@ def test_persist_and_forward_payload_has_correct_field_names(self): iteration=2, ) result.on_status_update("evaluating", ctx) - payload = self.api_client.post_agent_optimization_result.call_args[0][2] - assert payload["instructions"] == "Be helpful." - assert payload["parameters"] == {"temperature": 0.5} - assert payload["completion_response"] == "Paris." - assert payload["user_input"] == "Capital of France?" - assert payload["iteration"] == 2 - assert "j" in payload["scores"] + # POST payload contains the camelCase iteration-level fields + post_payload = self.api_client.post_agent_optimization_result.call_args[0][2] + assert post_payload["instructions"] == "Be helpful." + assert post_payload["parameters"] == {"temperature": 0.5} + assert post_payload["userInput"] == "Capital of France?" + assert post_payload["iteration"] == 2 + # Telemetry and scores are in the PATCH payload + patch_payload = self.api_client.patch_agent_optimization_result.call_args[0][3] + assert patch_payload["completionResponse"] == "Paris." 
+ assert "j" in patch_payload["scores"] + + def test_persist_and_forward_scores_include_threshold_for_known_judges(self): + # Build with a config that has a known acceptance-statement judge (threshold=0.9) + result = self._build() + ctx = OptimizationContext( + scores={"acceptance-statement-0": JudgeResult(score=0.85, rationale="Close.")}, + completion_response="An answer.", + current_instructions="Be helpful.", + current_parameters={}, + current_variables={}, + iteration=1, + ) + result.on_status_update("evaluating", ctx) + patch_payload = self.api_client.patch_agent_optimization_result.call_args[0][3] + score_entry = patch_payload["scores"]["acceptance-statement-0"] + assert score_entry["score"] == 0.85 + assert score_entry["rationale"] == "Close." + assert score_entry["threshold"] == 0.9 + + def test_persist_and_forward_scores_omit_threshold_for_unknown_judge_key(self): + # A score whose key doesn't match any configured judge should not include threshold + result = self._build() + ctx = OptimizationContext( + scores={"unknown-judge": JudgeResult(score=0.5, rationale="Unknown.")}, + completion_response="Answer.", + current_instructions="", + current_parameters={}, + current_variables={}, + iteration=1, + ) + result.on_status_update("evaluating", ctx) + patch_payload = self.api_client.patch_agent_optimization_result.call_args[0][3] + score_entry = patch_payload["scores"]["unknown-judge"] + assert score_entry["score"] == 0.5 + assert "threshold" not in score_entry def test_persist_and_forward_includes_run_id_and_version(self): result = self._build() @@ -1934,15 +2109,43 @@ def test_persist_and_forward_includes_run_id_and_version(self): current_parameters={}, current_variables={}, iteration=1, ) result.on_status_update("generating", ctx) - payload = self.api_client.post_agent_optimization_result.call_args[0][2] - assert payload["run_id"] == "run-uuid-456" - assert payload["config_optimization_version"] == 2 + post_payload = self.api_client.post_agent_optimization_result.call_args[0][2] + assert post_payload["runId"] == "run-uuid-456" + assert post_payload["agentOptimizationVersion"] == 2 + + def test_second_call_same_iteration_does_not_post_again(self): + result = self._build() + ctx = OptimizationContext( + scores={}, completion_response="", current_instructions="", + current_parameters={}, current_variables={}, iteration=1, + ) + result.on_status_update("generating", ctx) + result.on_status_update("evaluating", ctx) + # POST is called only once (first encounter of iteration 1) + assert self.api_client.post_agent_optimization_result.call_count == 1 + # PATCH is called twice + assert self.api_client.patch_agent_optimization_result.call_count == 2 + + def test_each_new_iteration_posts_a_new_record(self): + result = self._build() + ctx1 = OptimizationContext( + scores={}, completion_response="", current_instructions="", + current_parameters={}, current_variables={}, iteration=1, + ) + ctx2 = OptimizationContext( + scores={}, completion_response="", current_instructions="", + current_parameters={}, current_variables={}, iteration=2, + ) + result.on_status_update("generating", ctx1) + result.on_status_update("generating", ctx2) + assert self.api_client.post_agent_optimization_result.call_count == 2 @pytest.mark.parametrize("sdk_status,expected_status,expected_activity", [ ("init", "RUNNING", "PENDING"), ("generating", "RUNNING", "GENERATING"), ("evaluating", "RUNNING", "EVALUATING"), ("generating variation", "RUNNING", "GENERATING_VARIATION"), + ("validating", "RUNNING", "EVALUATING"), ("turn 
completed", "RUNNING", "COMPLETED"), ("success", "PASSED", "COMPLETED"), ("failure", "FAILED", "COMPLETED"), @@ -1954,14 +2157,18 @@ def test_status_mapping(self, sdk_status, expected_status, expected_activity): current_parameters={}, current_variables={}, iteration=1, ) result.on_status_update(sdk_status, ctx) - payload = self.api_client.post_agent_optimization_result.call_args[0][2] - assert payload["status"] == expected_status - assert payload["activity"] == expected_activity + # status and activity are in the PATCH payload, not the POST payload + patch_payload = self.api_client.patch_agent_optimization_result.call_args[0][3] + assert patch_payload["status"] == expected_status + assert patch_payload["activity"] == expected_activity - def test_user_on_status_update_chained_after_post(self): + def test_user_on_status_update_chained_after_post_and_patch(self): call_order = [] self.api_client.post_agent_optimization_result.side_effect = ( - lambda *a, **kw: call_order.append("post") + lambda *a, **kw: call_order.append("post") or "result-id" + ) + self.api_client.patch_agent_optimization_result.side_effect = ( + lambda *a, **kw: call_order.append("patch") ) user_cb = MagicMock(side_effect=lambda s, c: call_order.append("user")) options = _make_from_config_options(on_status_update=user_cb) @@ -1971,7 +2178,7 @@ def test_user_on_status_update_chained_after_post(self): current_parameters={}, current_variables={}, iteration=1, ) result.on_status_update("generating", ctx) - assert call_order == ["post", "user"] + assert call_order == ["post", "patch", "user"] def test_user_on_status_update_exception_does_not_propagate(self): options = _make_from_config_options( @@ -1984,15 +2191,267 @@ def test_user_on_status_update_exception_does_not_propagate(self): ) result.on_status_update("generating", ctx) # must not raise - def test_payload_history_not_included(self): + def test_post_payload_does_not_contain_history(self): + result = self._build() + ctx = OptimizationContext( + scores={}, completion_response="", current_instructions="", + current_parameters={}, current_variables={}, iteration=1, + ) + result.on_status_update("generating", ctx) + post_payload = self.api_client.post_agent_optimization_result.call_args[0][2] + assert "history" not in post_payload + + @pytest.mark.parametrize("status", [ + "init", "generating", "evaluating", "generating variation", + "validating", "turn completed", "success", "failure", + ]) + def test_variation_included_in_patch_for_all_statuses(self, status): + result = self._build() + ctx = OptimizationContext( + scores={}, + completion_response="answer", + current_instructions="Be concise.", + current_parameters={"temperature": 0.3}, + current_variables={}, + current_model="gpt-4o", + iteration=1, + ) + result.on_status_update(status, ctx) + patch_payload = self.api_client.patch_agent_optimization_result.call_args[0][3] + assert "variation" in patch_payload + assert patch_payload["variation"]["instructions"] == "Be concise." 
+ assert patch_payload["variation"]["parameters"] == {"temperature": 0.3} + + @pytest.mark.parametrize("status", ["generating", "evaluating", "success"]) + def test_model_config_key_prefers_global_in_variation(self, status): + model_configs = [ + {"id": "gpt-4o", "key": "project.gpt-4o", "global": False}, + {"id": "gpt-4o", "key": "global.gpt-4o", "global": True}, + ] + result = self.client._build_options_from_config( + dict(_API_CONFIG), + _make_from_config_options(), + self.api_client, + optimization_key="opt-key-123", + run_id="run-uuid-456", + model_configs=model_configs, + ) + ctx = OptimizationContext( + scores={}, completion_response="", current_instructions="instr", + current_parameters={}, current_variables={}, current_model="gpt-4o", + iteration=1, + ) + result.on_status_update(status, ctx) + patch_payload = self.api_client.patch_agent_optimization_result.call_args[0][3] + assert patch_payload["variation"]["modelConfigKey"] == "global.gpt-4o" + + @pytest.mark.parametrize("status", ["generating", "evaluating", "success"]) + def test_model_config_key_resolved_in_variation(self, status): + model_configs = [{"id": "gpt-4o", "key": "OpenAI.gpt-4o"}] + result = self.client._build_options_from_config( + dict(_API_CONFIG), + _make_from_config_options(), + self.api_client, + optimization_key="opt-key-123", + run_id="run-uuid-456", + model_configs=model_configs, + ) + ctx = OptimizationContext( + scores={}, completion_response="", current_instructions="instr", + current_parameters={}, current_variables={}, current_model="gpt-4o", + iteration=1, + ) + result.on_status_update(status, ctx) + patch_payload = self.api_client.patch_agent_optimization_result.call_args[0][3] + assert patch_payload["variation"]["modelConfigKey"] == "OpenAI.gpt-4o" + + def test_generation_latency_cast_to_int(self): + result = self._build() + ctx = OptimizationContext( + scores={}, completion_response="", current_instructions="", + current_parameters={}, current_variables={}, duration_ms=123.7, + iteration=1, + ) + result.on_status_update("generating", ctx) + patch_payload = self.api_client.patch_agent_optimization_result.call_args[0][3] + assert patch_payload["generationLatency"] == 123 + assert isinstance(patch_payload["generationLatency"], int) + + def test_last_optimization_result_id_updated_on_post(self): result = self._build() ctx = OptimizationContext( scores={}, completion_response="", current_instructions="", current_parameters={}, current_variables={}, iteration=1, ) result.on_status_update("generating", ctx) - payload = self.api_client.post_agent_optimization_result.call_args[0][2] - assert "history" not in payload + assert self.client._last_optimization_result_id == "result-uuid-789" + + def test_validation_sub_iterations_do_not_create_new_records(self): + """Validation sub-iterations should be folded into the parent iteration's record.""" + result = self._build() + ctx_main = OptimizationContext( + scores={}, completion_response="a", current_instructions="i", + current_parameters={}, current_variables={}, iteration=1, + ) + ctx_val1 = OptimizationContext( + scores={}, completion_response="b", current_instructions="i", + current_parameters={}, current_variables={}, iteration=2, + ) + ctx_val2 = OptimizationContext( + scores={}, completion_response="c", current_instructions="i", + current_parameters={}, current_variables={}, iteration=3, + ) + result.on_status_update("generating", ctx_main) # POST iter 1 + result.on_status_update("evaluating", ctx_main) # PATCH iter 1 + 
result.on_status_update("validating", ctx_main) # enter validation; PATCH iter 1 + result.on_status_update("generating", ctx_val1) # validation sub-iter → folded to iter 1 + result.on_status_update("evaluating", ctx_val1) # folded to iter 1 + result.on_status_update("generating", ctx_val2) # validation sub-iter → folded to iter 1 + result.on_status_update("evaluating", ctx_val2) # folded to iter 1 + result.on_status_update("success", ctx_val2) # folded to iter 1; reset validation + + # Only one POST for the single main iteration + assert self.api_client.post_agent_optimization_result.call_count == 1 + post_payload = self.api_client.post_agent_optimization_result.call_args[0][2] + assert post_payload["iteration"] == 1 + + def test_validation_success_patches_parent_iteration_record(self): + """success event during validation should PATCH the main iteration's record, not a new one.""" + result = self._build() + ctx_main = OptimizationContext( + scores={}, completion_response="", current_instructions="", + current_parameters={}, current_variables={}, iteration=2, + ) + ctx_val = OptimizationContext( + scores={}, completion_response="", current_instructions="", + current_parameters={}, current_variables={}, iteration=3, + ) + result.on_status_update("generating", ctx_main) + result.on_status_update("validating", ctx_main) + result.on_status_update("generating", ctx_val) + result.on_status_update("success", ctx_val) + + # PATCH for success should use the result_id of the parent (iter 2) record + patch_calls = self.api_client.patch_agent_optimization_result.call_args_list + success_patch = next( + c for c in patch_calls if c[0][3].get("status") == "PASSED" + ) + # Third positional arg is result_id — it should be the one returned from the POST for iter 2 + assert success_patch[0][2] == "result-uuid-789" + + def test_validation_phase_resets_after_turn_completed(self): + """After turn completed, subsequent main-loop iterations create their own records.""" + result = self._build() + ctx1 = OptimizationContext( + scores={}, completion_response="", current_instructions="", + current_parameters={}, current_variables={}, iteration=1, + ) + ctx_val = OptimizationContext( + scores={}, completion_response="", current_instructions="", + current_parameters={}, current_variables={}, iteration=2, + ) + ctx2 = OptimizationContext( + scores={}, completion_response="", current_instructions="", + current_parameters={}, current_variables={}, iteration=2, + ) + result.on_status_update("generating", ctx1) # POST iter 1 + result.on_status_update("validating", ctx1) # enter validation + result.on_status_update("generating", ctx_val) # folded to iter 1 + result.on_status_update("turn completed", ctx_val) # reset validation phase + result.on_status_update("generating", ctx2) # POST iter 2 (new main attempt) + + assert self.api_client.post_agent_optimization_result.call_count == 2 + + def test_init_iteration_closed_when_first_real_iteration_begins(self): + """The init record (iter 0) must receive a RUNNING:COMPLETED patch before iter 1 starts.""" + result = self._build() + ctx0 = OptimizationContext( + scores={}, completion_response="", current_instructions="", + current_parameters={}, current_variables={}, iteration=0, + ) + ctx1 = OptimizationContext( + scores={}, completion_response="", current_instructions="", + current_parameters={}, current_variables={}, iteration=1, + ) + result.on_status_update("init", ctx0) # POST iter 0, PATCH RUNNING:PENDING + result.on_status_update("generating", ctx1) # should close iter 0, 
then POST iter 1 + + # iter 0 POSTed + iter 1 POSTed + assert self.api_client.post_agent_optimization_result.call_count == 2 + patch_calls = self.api_client.patch_agent_optimization_result.call_args_list + # Patches: (1) init PENDING, (2) auto-close COMPLETED, (3) generating GENERATING + assert len(patch_calls) == 3 + payloads = [c[0][3] for c in patch_calls] + assert payloads[0]["status"] == "RUNNING" + assert payloads[0]["activity"] == "PENDING" + assert "variation" in payloads[0] + assert payloads[1] == {"status": "RUNNING", "activity": "COMPLETED"} # auto-close patch has no variation + assert payloads[2]["status"] == "RUNNING" + assert payloads[2]["activity"] == "GENERATING" + assert "variation" in payloads[2] + + def test_non_final_gt_sample_closed_when_next_sample_begins(self): + """In a GT batch, each sample except the last should receive a RUNNING:COMPLETED patch + when the next sample's generating event fires.""" + result = self._build() + ctx1 = OptimizationContext( + scores={}, completion_response="", current_instructions="", + current_parameters={}, current_variables={}, user_input="What is 2+2?", iteration=1, + ) + ctx2 = OptimizationContext( + scores={}, completion_response="", current_instructions="", + current_parameters={}, current_variables={}, user_input="What is 3+3?", iteration=2, + ) + ctx3 = OptimizationContext( + scores={}, completion_response="", current_instructions="", + current_parameters={}, current_variables={}, user_input="What is 4+4?", iteration=3, + ) + result.on_status_update("generating", ctx1) # POST iter 1 + result.on_status_update("evaluating", ctx1) # PATCH iter 1 (EVALUATING) + result.on_status_update("generating", ctx2) # should auto-close iter 1, then POST iter 2 + result.on_status_update("evaluating", ctx2) # PATCH iter 2 (EVALUATING) + result.on_status_update("generating", ctx3) # should auto-close iter 2, then POST iter 3 + + patch_calls = self.api_client.patch_agent_optimization_result.call_args_list + activities = [c[0][3].get("activity") for c in patch_calls] + # Expected sequence: GENERATING, EVALUATING, COMPLETED (auto-close 1), + # GENERATING, EVALUATING, COMPLETED (auto-close 2), GENERATING + assert activities.count("COMPLETED") >= 2, ( + f"Expected at least 2 COMPLETED patches, got: {activities}" + ) + # The auto-close patches must appear BEFORE the subsequent GENERATING patches + completed_indices = [i for i, a in enumerate(activities) if a == "COMPLETED"] + generating_indices = [i for i, a in enumerate(activities) if a == "GENERATING"] + # Each auto-close patch should precede the next generating patch + assert completed_indices[0] < generating_indices[1] + assert completed_indices[1] < generating_indices[2] + + def test_terminal_event_clears_open_iteration_so_next_generating_does_not_double_close(self): + """After a terminal event (turn completed), the next generating should not try to + close the already-closed iteration again.""" + result = self._build() + ctx1 = OptimizationContext( + scores={}, completion_response="answer", current_instructions="Be helpful.", + current_parameters={}, current_variables={}, iteration=1, + ) + ctx2 = OptimizationContext( + scores={}, completion_response="", current_instructions="", + current_parameters={}, current_variables={}, iteration=2, + ) + result.on_status_update("generating", ctx1) # open iter 1 + result.on_status_update("turn completed", ctx1) # close iter 1 explicitly + result.on_status_update("generating", ctx2) # new iter — should NOT re-close iter 1 + + patch_calls = 
+        # The only RUNNING:COMPLETED patch should be from "turn completed", not from the
+        # auto-close triggered by iter 2's generating event.
+        completed_patches = [
+            c for c in patch_calls
+            if c[0][3].get("status") == "RUNNING" and c[0][3].get("activity") == "COMPLETED"
+        ]
+        assert len(completed_patches) == 1, (
+            "Expected exactly one RUNNING:COMPLETED patch (from turn completed), not a duplicate"
+        )
 
 # ---------------------------------------------------------------------------
@@ -2140,10 +2599,6 @@ def test_valid_options_created(self):
         opts = self._make()
         assert len(opts.ground_truth_responses) == 1
 
-    def test_raises_empty_context_choices(self):
-        with pytest.raises(ValueError, match="context_choices"):
-            self._make(context_choices=[])
-
     def test_raises_empty_model_choices(self):
         with pytest.raises(ValueError, match="model_choices"):
             self._make(model_choices=[])
@@ -2489,8 +2944,9 @@ def _build(self, config=None, options=None):
             config or dict(_API_CONFIG_WITH_GT),
             options or _make_from_config_options(),
             self.api_client,
-            optimization_id="opt-gt-uuid",
+            optimization_key="opt-gt-key",
             run_id="run-uuid-789",
+            model_configs=[],
         )
 
     def test_returns_ground_truth_options_when_gt_present(self):
@@ -3149,6 +3605,21 @@ def test_model_config_key_falls_back_to_model_name_when_no_id_match(self):
         payload = api_client.create_ai_config_variation.call_args[0][2]
         assert payload["modelConfigKey"] == "gpt-4o"
+
+    def test_model_config_key_prefers_global_over_non_global(self):
+        client = self._make_client()
+        api_client = _make_api_client_for_commit(model_configs=[
+            {"id": "gpt-4o", "key": "project.gpt-4o", "global": False},
+            {"id": "gpt-4o", "key": "global.gpt-4o", "global": True},
+        ])
+
+        client._commit_variation(
+            _make_winning_context(model="gpt-4o"), project_key="my-project",
+            ai_config_key="my-agent", output_key="k", api_client=api_client,
+        )
+
+        payload = api_client.create_ai_config_variation.call_args[0][2]
+        assert payload["modelConfigKey"] == "global.gpt-4o"
+
+    def test_model_config_key_falls_back_when_get_model_configs_raises(self):
+        client = self._make_client()
+        api_client = _make_api_client_for_commit()
@@ -3234,6 +3705,83 @@ def test_reuses_provided_api_client_without_creating_new_one(self):
         MockLDApiClient.assert_not_called()
+
+    # --- tool key propagation ---
+
+    def test_toolkeys_included_in_payload_when_tools_present(self):
+        client = self._make_client()
+        client._initial_tool_keys = ["search-tool", "calculator"]
+        api_client = _make_api_client_for_commit()
+
+        client._commit_variation(
+            _make_winning_context(), project_key="my-project",
+            ai_config_key="my-agent", output_key="k", api_client=api_client,
+        )
+
+        payload = api_client.create_ai_config_variation.call_args[0][2]
+        assert payload["toolKeys"] == ["search-tool", "calculator"]
+
+    def test_toolkeys_not_in_payload_when_no_tools(self):
+        client = self._make_client()
+        client._initial_tool_keys = []
+        api_client = _make_api_client_for_commit()
+
+        client._commit_variation(
+            _make_winning_context(), project_key="my-project",
+            ai_config_key="my-agent", output_key="k", api_client=api_client,
+        )
+
+        payload = api_client.create_ai_config_variation.call_args[0][2]
+        assert "toolKeys" not in payload
+
+
+# ---------------------------------------------------------------------------
+# Tool key extraction from raw variation (_get_agent_config)
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+class TestGetAgentConfigToolKeyExtraction:
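+    """Tool keys found in the raw AI Config variation should be captured into _initial_tool_keys by _get_agent_config."""
+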
+    def _make_client_with_variation(self, raw_variation: dict) -> OptimizationClient:
+        mock_ldai = _make_ldai_client()
+        mock_ldai._client.variation.return_value = raw_variation
+        return _make_client(mock_ldai)
+
+    async def test_extracts_tool_keys_from_raw_variation(self):
+        raw = {
+            "instructions": AGENT_INSTRUCTIONS,
+            "tools": [
+                {"key": "search-tool", "version": 1},
+                {"key": "calculator", "version": 2},
+            ],
+        }
+        client = self._make_client_with_variation(raw)
+        await client._get_agent_config("test-agent", LD_CONTEXT)
+        assert client._initial_tool_keys == ["search-tool", "calculator"]
+
+    async def test_initial_tool_keys_empty_when_no_tools_in_variation(self):
+        raw = {"instructions": AGENT_INSTRUCTIONS}
+        client = self._make_client_with_variation(raw)
+        await client._get_agent_config("test-agent", LD_CONTEXT)
+        assert client._initial_tool_keys == []
+
+    async def test_initial_tool_keys_empty_when_tools_is_empty_list(self):
+        raw = {"instructions": AGENT_INSTRUCTIONS, "tools": []}
+        client = self._make_client_with_variation(raw)
+        await client._get_agent_config("test-agent", LD_CONTEXT)
+        assert client._initial_tool_keys == []
+
+    async def test_skips_tool_entries_without_key(self):
+        raw = {
+            "instructions": AGENT_INSTRUCTIONS,
+            "tools": [
+                {"key": "good-tool", "version": 1},
+                {"version": 2},  # missing key — should be skipped
+            ],
+        }
+        client = self._make_client_with_variation(raw)
+        await client._get_agent_config("test-agent", LD_CONTEXT)
+        assert client._initial_tool_keys == ["good-tool"]
+
 # ---------------------------------------------------------------------------
 # auto_commit in optimize_from_options
@@ -3284,6 +3832,16 @@ async def test_commit_not_called_when_run_fails(self):
         mock_commit.assert_not_called()
+
+    async def test_succeeds_without_api_key_when_auto_commit_false(self):
+        client = self._make_client_without_key()
+        options = _make_options()  # auto_commit defaults to False
+
+        with patch("ldai_optimization.client.LDApiClient") as mock_api_cls:
+            result = await client.optimize_from_options("test-agent", options)
+
+        mock_api_cls.assert_not_called()
+        assert result is not None
+
+    async def test_raises_when_auto_commit_true_and_no_api_key(self):
         client = self._make_client_without_key()
         options = _make_options(auto_commit=True, project_key="my-project")
@@ -3480,3 +4038,54 @@ async def test_output_key_forwarded_to_commit(self):
         )
         assert mock_commit.call_args[1]["output_key"] == "my-variation"
+
+    async def test_model_configs_forwarded_to_commit(self):
+        """Pre-fetched model configs are passed to _commit_variation to avoid extra API calls."""
+        client = self._make_client_with_key()
+        mock_api = _make_mock_api_client()
+        mock_api.get_agent_optimization = MagicMock(return_value=dict(_API_CONFIG))
+        mock_api.get_model_configs = MagicMock(return_value=[{"id": "gpt-4o", "key": "OpenAI.gpt-4o"}])
+
+        with patch("ldai_optimization.client.LDApiClient", return_value=mock_api):
+            with patch.object(client, "_commit_variation") as mock_commit:
+                await client.optimize_from_config("my-opt", _make_from_config_options())
+
+        assert mock_commit.call_args[1]["model_configs"] == [{"id": "gpt-4o", "key": "OpenAI.gpt-4o"}]
+
+    async def test_patches_created_variation_key_after_commit(self):
+        """After _commit_variation succeeds, the last result record is PATCHed with createdVariationKey."""
+        client = self._make_client_with_key()
+        mock_api = _make_mock_api_client()
+        mock_api.get_agent_optimization = MagicMock(return_value=dict(_API_CONFIG))
+
+        with patch("ldai_optimization.client.LDApiClient", return_value=mock_api):
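+            # _commit_variation is stubbed to report a successful commit of "my-new-variation"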
+            with patch.object(client, "_commit_variation", return_value="my-new-variation"):
+                client._last_optimization_result_id = "result-id-abc"
+                await client.optimize_from_config("my-opt", _make_from_config_options())
+
+        patch_calls = mock_api.patch_agent_optimization_result.call_args_list
+        variation_key_patch = next(
+            (c for c in patch_calls if c[0][3].get("createdVariationKey") == "my-new-variation"),
+            None,
+        )
+        assert variation_key_patch is not None, "Expected a PATCH with createdVariationKey"
+        # URL path uses the string key ("my-optimization"), not the UUID ("opt-uuid-123")
+        assert variation_key_patch[0][1] == "my-optimization"
+
+    async def test_optimization_key_in_post_url_uses_string_key_not_uuid(self):
+        """post_agent_optimization_result is called with config['key'], not config['id']."""
+        client = self._make_client_with_key()
+        mock_api = _make_mock_api_client()
+        mock_api.get_agent_optimization = MagicMock(return_value=dict(_API_CONFIG))
+
+        with patch("ldai_optimization.client.LDApiClient", return_value=mock_api):
+            await client.optimize_from_config("my-opt", _make_from_config_options())
+
+        post_call_args = mock_api.post_agent_optimization_result.call_args_list
+        assert len(post_call_args) >= 1
+        for call in post_call_args:
+            opt_key_arg = call[0][1]
+            # Must use the string key "my-optimization", never the UUID "opt-uuid-123"
+            assert opt_key_arg == "my-optimization", (
+                f"Expected string key 'my-optimization', got '{opt_key_arg}'"
+            )