From c613fb7306288ffb3f15034a99f7a83550680449 Mon Sep 17 00:00:00 2001 From: Bartolomej Kozorog Date: Fri, 27 Mar 2026 17:41:32 +0100 Subject: [PATCH 1/3] update api client --- .../fine_tuning/04_military_strikes.ipynb | 21 ++--- openapi/openapi.json | 68 ++++++++------- src/lightningrod/__init__.py | 2 + src/lightningrod/_display.py | 7 +- .../_generated/models/__init__.py | 2 + .../models/create_eval_job_request.py | 42 ++++------ .../_generated/models/eval_config.py | 44 ++++------ .../_generated/models/eval_model.py | 83 +++++++++++++++++++ .../training_job_model_id_by_step_type_0.py | 46 ++++++++++ .../models/validation_error_context.py | 46 ++++++++++ src/lightningrod/training/client.py | 2 + src/lightningrod/training/evals.py | 26 ++---- 12 files changed, 270 insertions(+), 119 deletions(-) create mode 100644 src/lightningrod/_generated/models/eval_model.py create mode 100644 src/lightningrod/_generated/models/training_job_model_id_by_step_type_0.py create mode 100644 src/lightningrod/_generated/models/validation_error_context.py diff --git a/notebooks/fine_tuning/04_military_strikes.ipynb b/notebooks/fine_tuning/04_military_strikes.ipynb index 493dca4..209f285 100644 --- a/notebooks/fine_tuning/04_military_strikes.ipynb +++ b/notebooks/fine_tuning/04_military_strikes.ipynb @@ -7,8 +7,7 @@ "# Military Strikes Forecasting\n", "\n", "Generate a forecasting dataset about global military strikes and attack operations using the LightningRod SDK. Fine-tune a model via RL that outperforms frontier LLMs on strike prediction." - ], - "outputs": [] + ] }, { "cell_type": "code", @@ -36,8 +35,7 @@ "## Set up the client\n", "\n", "Sign up at [dashboard.lightningrod.ai](https://dashboard.lightningrod.ai/sign-up?redirect=/api) to get your API key and **$50 of free credits**." 
- ], - "outputs": [] + ] }, { "cell_type": "code", @@ -59,8 +57,7 @@ "## Build the pipeline\n", "\n", "Configure the pipeline with domain-specific instructions and examples for military strike forecasting. Covers airstrikes, missile strikes, drone strikes, and naval strikes across state and non-state actors globally." - ], - "outputs": [] + ] }, { "cell_type": "code", @@ -178,8 +175,7 @@ "## Run the pipeline\n", "\n", "Collect news articles, generate questions, and label answers. Set `max_questions=10000` for a full production run; reduce for testing." - ], - "outputs": [] + ] }, { "cell_type": "code", @@ -201,8 +197,7 @@ "## Prepare the dataset\n", "\n", "Filter valid samples, deduplicate, and split into train/test sets using a temporal strategy." - ], - "outputs": [] + ] }, { "cell_type": "code", @@ -232,8 +227,7 @@ "## Train the model\n", "\n", "Fine-tune `openai/gpt-oss-120b` via RL using the training parameters from our golf and WWTD experiments." - ], - "outputs": [] + ] }, { "cell_type": "code", @@ -283,8 +277,7 @@ "## Evaluate\n", "\n", "Run the trained model against the test set, benchmarked against GPT-5.4." 
- ], - "outputs": [] + ] }, { "cell_type": "code", diff --git a/openapi/openapi.json b/openapi/openapi.json index ff05f24..251ffbb 100644 --- a/openapi/openapi.json +++ b/openapi/openapi.json @@ -2221,20 +2221,12 @@ "dataset": { "$ref": "#/components/schemas/SampleDatasetConfig" }, - "model_id": { - "type": "string", - "title": "Model Id" - }, - "benchmark_model_id": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Benchmark Model Id" + "models": { + "items": { + "$ref": "#/components/schemas/EvalModel" + }, + "type": "array", + "title": "Models" }, "temperature": { "type": "number", @@ -2245,7 +2237,7 @@ "type": "object", "required": [ "dataset", - "model_id" + "models" ], "title": "CreateEvalJobRequest" }, @@ -2833,24 +2825,16 @@ "type": "string", "title": "Organization Id" }, - "model_id": { - "type": "string", - "title": "Model Id" + "models": { + "items": { + "$ref": "#/components/schemas/EvalModel" + }, + "type": "array", + "title": "Models" }, "dataset": { "$ref": "#/components/schemas/SampleDatasetConfig" }, - "benchmark_model_id": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Benchmark Model Id" - }, "temperature": { "type": "number", "title": "Temperature", @@ -2870,7 +2854,7 @@ "type": "object", "required": [ "organization_id", - "model_id", + "models", "dataset" ], "title": "EvalConfig" @@ -3000,6 +2984,30 @@ ], "title": "EvalJobStatus" }, + "EvalModel": { + "properties": { + "model_id": { + "type": "string", + "title": "Model Id" + }, + "label": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Label" + } + }, + "type": "object", + "required": [ + "model_id" + ], + "title": "EvalModel" + }, "EventUsageSummary": { "properties": { "count": { diff --git a/src/lightningrod/__init__.py b/src/lightningrod/__init__.py index 0804f0e..2ff1335 100644 --- a/src/lightningrod/__init__.py +++ b/src/lightningrod/__init__.py @@ -53,6 +53,7 @@ 
MetadataFieldType, TopicTreeSeedGenerator, CsvSeedGenerator, + EvalModel, ) __version__ = "0.1.19" @@ -108,4 +109,5 @@ "WebSearchLabeler", "TopicTreeSeedGenerator", "CsvSeedGenerator", + "EvalModel", ] diff --git a/src/lightningrod/_display.py b/src/lightningrod/_display.py index ccd418c..bfcb4f5 100644 --- a/src/lightningrod/_display.py +++ b/src/lightningrod/_display.py @@ -172,7 +172,7 @@ def build_training_live_display(job: Any) -> RenderableType: renderables.append(Text("")) if job is not None: if _is_set(job.name) and job.name: - renderables.append(_safe_markup(f" [bold]Job:[/bold] {job.name}")) + renderables.append(_safe_markup(f" [bold]Job ID:[/bold] {job.id}")) renderables.append(Text("")) if job.status == TrainingJobStatus.RUNNING: @@ -258,7 +258,7 @@ def build_eval_live_display(job: EvalJob) -> RenderableType: renderables.append(_safe_markup(f"[bold {header_style}]{header}[/bold {header_style}]")) renderables.append(Text("")) if job is not None: - renderables.append(_safe_markup(f" [bold]Model:[/bold] {job.config.model_id}")) + renderables.append(_safe_markup(f" [bold]Job ID:[/bold] {job.id}")) renderables.append(_safe_markup(f" [bold]Dataset:[/bold] {job.config.dataset.id}")) renderables.append(Text("")) if job.status in (EvalJobStatus.RUNNING, EvalJobStatus.STARTING): @@ -301,8 +301,7 @@ def print_eval(job: EvalJob) -> None: renderables.append(_safe_markup(f"[bold {header_style}]{header}[/bold {header_style}]")) renderables.append(Text("")) if job is not None: - renderables.append(_safe_markup(f" [bold]ID:[/bold] {job.id}")) - renderables.append(_safe_markup(f" [bold]Model:[/bold] {job.config.model_id}")) + renderables.append(_safe_markup(f" [bold]Job ID:[/bold] {job.id}")) renderables.append(_safe_markup(f" [bold]Dataset:[/bold] {job.config.dataset.id}")) renderables.append(Text("")) if _is_set(job.metrics) and job.metrics and job.metrics.additional_properties: diff --git a/src/lightningrod/_generated/models/__init__.py 
b/src/lightningrod/_generated/models/__init__.py index 3c09af0..21eb37d 100644 --- a/src/lightningrod/_generated/models/__init__.py +++ b/src/lightningrod/_generated/models/__init__.py @@ -32,6 +32,7 @@ from .eval_job_list_response import EvalJobListResponse from .eval_job_metrics_type_0 import EvalJobMetricsType0 from .eval_job_status import EvalJobStatus +from .eval_model import EvalModel from .event_usage_summary import EventUsageSummary from .file_set import FileSet from .file_set_context_generator import FileSetContextGenerator @@ -150,6 +151,7 @@ "EvalJobListResponse", "EvalJobMetricsType0", "EvalJobStatus", + "EvalModel", "EventUsageSummary", "FileSet", "FileSetContextGenerator", diff --git a/src/lightningrod/_generated/models/create_eval_job_request.py b/src/lightningrod/_generated/models/create_eval_job_request.py index ed1330a..e655c89 100644 --- a/src/lightningrod/_generated/models/create_eval_job_request.py +++ b/src/lightningrod/_generated/models/create_eval_job_request.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, TypeVar, cast +from typing import TYPE_CHECKING, Any, TypeVar from attrs import define as _attrs_define from attrs import field as _attrs_field @@ -9,6 +9,7 @@ from ..types import UNSET, Unset if TYPE_CHECKING: + from ..models.eval_model import EvalModel from ..models.sample_dataset_config import SampleDatasetConfig @@ -20,27 +21,22 @@ class CreateEvalJobRequest: """ Attributes: dataset (SampleDatasetConfig): - model_id (str): - benchmark_model_id (None | str | Unset): + models (list[EvalModel]): temperature (float | Unset): Default: 0.0. 
""" dataset: SampleDatasetConfig - model_id: str - benchmark_model_id: None | str | Unset = UNSET + models: list[EvalModel] temperature: float | Unset = 0.0 additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) def to_dict(self) -> dict[str, Any]: dataset = self.dataset.to_dict() - model_id = self.model_id - - benchmark_model_id: None | str | Unset - if isinstance(self.benchmark_model_id, Unset): - benchmark_model_id = UNSET - else: - benchmark_model_id = self.benchmark_model_id + models = [] + for models_item_data in self.models: + models_item = models_item_data.to_dict() + models.append(models_item) temperature = self.temperature @@ -49,11 +45,9 @@ def to_dict(self) -> dict[str, Any]: field_dict.update( { "dataset": dataset, - "model_id": model_id, + "models": models, } ) - if benchmark_model_id is not UNSET: - field_dict["benchmark_model_id"] = benchmark_model_id if temperature is not UNSET: field_dict["temperature"] = temperature @@ -61,28 +55,24 @@ def to_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: + from ..models.eval_model import EvalModel from ..models.sample_dataset_config import SampleDatasetConfig d = dict(src_dict) dataset = SampleDatasetConfig.from_dict(d.pop("dataset")) - model_id = d.pop("model_id") - - def _parse_benchmark_model_id(data: object) -> None | str | Unset: - if data is None: - return data - if isinstance(data, Unset): - return data - return cast(None | str | Unset, data) + models = [] + _models = d.pop("models") + for models_item_data in _models: + models_item = EvalModel.from_dict(models_item_data) - benchmark_model_id = _parse_benchmark_model_id(d.pop("benchmark_model_id", UNSET)) + models.append(models_item) temperature = d.pop("temperature", UNSET) create_eval_job_request = cls( dataset=dataset, - model_id=model_id, - benchmark_model_id=benchmark_model_id, + models=models, temperature=temperature, ) diff --git 
a/src/lightningrod/_generated/models/eval_config.py b/src/lightningrod/_generated/models/eval_config.py index 5a732dd..cf0b767 100644 --- a/src/lightningrod/_generated/models/eval_config.py +++ b/src/lightningrod/_generated/models/eval_config.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, TypeVar, cast +from typing import TYPE_CHECKING, Any, TypeVar from attrs import define as _attrs_define from attrs import field as _attrs_field @@ -9,6 +9,7 @@ from ..types import UNSET, Unset if TYPE_CHECKING: + from ..models.eval_model import EvalModel from ..models.sample_dataset_config import SampleDatasetConfig @@ -20,18 +21,16 @@ class EvalConfig: """ Attributes: organization_id (str): - model_id (str): + models (list[EvalModel]): dataset (SampleDatasetConfig): - benchmark_model_id (None | str | Unset): temperature (float | Unset): Default: 0.0. max_tokens (int | Unset): Default: 8192. max_concurrent (int | Unset): Default: 50. 
""" organization_id: str - model_id: str + models: list[EvalModel] dataset: SampleDatasetConfig - benchmark_model_id: None | str | Unset = UNSET temperature: float | Unset = 0.0 max_tokens: int | Unset = 8192 max_concurrent: int | Unset = 50 @@ -40,16 +39,13 @@ class EvalConfig: def to_dict(self) -> dict[str, Any]: organization_id = self.organization_id - model_id = self.model_id + models = [] + for models_item_data in self.models: + models_item = models_item_data.to_dict() + models.append(models_item) dataset = self.dataset.to_dict() - benchmark_model_id: None | str | Unset - if isinstance(self.benchmark_model_id, Unset): - benchmark_model_id = UNSET - else: - benchmark_model_id = self.benchmark_model_id - temperature = self.temperature max_tokens = self.max_tokens @@ -61,12 +57,10 @@ def to_dict(self) -> dict[str, Any]: field_dict.update( { "organization_id": organization_id, - "model_id": model_id, + "models": models, "dataset": dataset, } ) - if benchmark_model_id is not UNSET: - field_dict["benchmark_model_id"] = benchmark_model_id if temperature is not UNSET: field_dict["temperature"] = temperature if max_tokens is not UNSET: @@ -78,23 +72,20 @@ def to_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: + from ..models.eval_model import EvalModel from ..models.sample_dataset_config import SampleDatasetConfig d = dict(src_dict) organization_id = d.pop("organization_id") - model_id = d.pop("model_id") - - dataset = SampleDatasetConfig.from_dict(d.pop("dataset")) + models = [] + _models = d.pop("models") + for models_item_data in _models: + models_item = EvalModel.from_dict(models_item_data) - def _parse_benchmark_model_id(data: object) -> None | str | Unset: - if data is None: - return data - if isinstance(data, Unset): - return data - return cast(None | str | Unset, data) + models.append(models_item) - benchmark_model_id = _parse_benchmark_model_id(d.pop("benchmark_model_id", UNSET)) + dataset = 
SampleDatasetConfig.from_dict(d.pop("dataset")) temperature = d.pop("temperature", UNSET) @@ -104,9 +95,8 @@ def _parse_benchmark_model_id(data: object) -> None | str | Unset: eval_config = cls( organization_id=organization_id, - model_id=model_id, + models=models, dataset=dataset, - benchmark_model_id=benchmark_model_id, temperature=temperature, max_tokens=max_tokens, max_concurrent=max_concurrent, diff --git a/src/lightningrod/_generated/models/eval_model.py b/src/lightningrod/_generated/models/eval_model.py new file mode 100644 index 0000000..5512124 --- /dev/null +++ b/src/lightningrod/_generated/models/eval_model.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any, TypeVar, cast + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +from ..types import UNSET, Unset + +T = TypeVar("T", bound="EvalModel") + + +@_attrs_define +class EvalModel: + """ + Attributes: + model_id (str): + label (None | str | Unset): + """ + + model_id: str + label: None | str | Unset = UNSET + additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> dict[str, Any]: + model_id = self.model_id + + label: None | str | Unset + if isinstance(self.label, Unset): + label = UNSET + else: + label = self.label + + field_dict: dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "model_id": model_id, + } + ) + if label is not UNSET: + field_dict["label"] = label + + return field_dict + + @classmethod + def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: + d = dict(src_dict) + model_id = d.pop("model_id") + + def _parse_label(data: object) -> None | str | Unset: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(None | str | Unset, data) + + label = _parse_label(d.pop("label", UNSET)) + + eval_model = cls( + model_id=model_id, + label=label, + ) + + 
eval_model.additional_properties = d + return eval_model + + @property + def additional_keys(self) -> list[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/src/lightningrod/_generated/models/training_job_model_id_by_step_type_0.py b/src/lightningrod/_generated/models/training_job_model_id_by_step_type_0.py new file mode 100644 index 0000000..28555b3 --- /dev/null +++ b/src/lightningrod/_generated/models/training_job_model_id_by_step_type_0.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any, TypeVar + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +T = TypeVar("T", bound="TrainingJobModelIdByStepType0") + + +@_attrs_define +class TrainingJobModelIdByStepType0: + """ """ + + additional_properties: dict[str, str] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> dict[str, Any]: + field_dict: dict[str, Any] = {} + field_dict.update(self.additional_properties) + + return field_dict + + @classmethod + def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: + d = dict(src_dict) + training_job_model_id_by_step_type_0 = cls() + + training_job_model_id_by_step_type_0.additional_properties = d + return training_job_model_id_by_step_type_0 + + @property + def additional_keys(self) -> list[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> str: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: str) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del 
self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/src/lightningrod/_generated/models/validation_error_context.py b/src/lightningrod/_generated/models/validation_error_context.py new file mode 100644 index 0000000..cfaf7b0 --- /dev/null +++ b/src/lightningrod/_generated/models/validation_error_context.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any, TypeVar + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +T = TypeVar("T", bound="ValidationErrorContext") + + +@_attrs_define +class ValidationErrorContext: + """ """ + + additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> dict[str, Any]: + field_dict: dict[str, Any] = {} + field_dict.update(self.additional_properties) + + return field_dict + + @classmethod + def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: + d = dict(src_dict) + validation_error_context = cls() + + validation_error_context.additional_properties = d + return validation_error_context + + @property + def additional_keys(self) -> list[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/src/lightningrod/training/client.py b/src/lightningrod/training/client.py index 521fb3e..e7b7c09 100644 --- a/src/lightningrod/training/client.py +++ b/src/lightningrod/training/client.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from attrs import define from lightningrod._display import _is_notebook, display_error, run_training_live_display 
from lightningrod._generated.api.training_jobs import ( diff --git a/src/lightningrod/training/evals.py b/src/lightningrod/training/evals.py index 6af24b8..2a7df7b 100644 --- a/src/lightningrod/training/evals.py +++ b/src/lightningrod/training/evals.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from lightningrod._display import _is_notebook, display_error, run_eval_live_display from lightningrod._generated.api.evaluations import ( create_eval_job_evaluations_post, @@ -5,7 +7,7 @@ list_eval_jobs_evaluations_get, ) from lightningrod._generated.client import AuthenticatedClient -from lightningrod._generated.models import SampleDatasetConfig +from lightningrod._generated.models import EvalModel, SampleDatasetConfig from lightningrod._generated.models.create_eval_job_request import CreateEvalJobRequest from lightningrod._generated.models.eval_job import EvalJob from lightningrod._generated.models.eval_job_list_response import EvalJobListResponse @@ -21,17 +23,13 @@ def __init__(self, client: AuthenticatedClient): def create( self, - model_id: str, dataset: "SampleDataset", - benchmark_model_id: str | None = None, - temperature: float = 0.0, + models: list[EvalModel], ) -> EvalJob: dataset_config = sample_dataset_to_config(dataset) body = CreateEvalJobRequest( - model_id=model_id, + models=models, dataset=dataset_config, - benchmark_model_id=benchmark_model_id if benchmark_model_id is not None else UNSET, - temperature=temperature, ) response = create_eval_job_evaluations_post.sync_detailed( client=self._client, @@ -61,18 +59,10 @@ def list( def run( self, - model_id: str, + models: list[EvalModel], dataset: "SampleDataset", - benchmark_model_id: str | None = None, - temperature: float = 0.0, - poll_interval: float = 15, ) -> EvalJob: - job = self.create( - model_id=model_id, - dataset=dataset, - benchmark_model_id=benchmark_model_id, - temperature=temperature, - ) + job = self.create(models=models, dataset=dataset) if job.status == TrainingJobStatus.FAILED: 
        error_msg = ( @@ -89,5 +79,5 @@ def poll() -> EvalJob: job = self.get(job.id) return job - run_eval_live_display(poll, poll_interval=poll_interval, initial_job=job) + run_eval_live_display(poll, initial_job=job) return job From b6f29754e0cec5b87dc6872e9e9d55628962c172 Mon Sep 17 00:00:00 2001 From: Bartolomej Kozorog Date: Fri, 27 Mar 2026 18:01:20 +0100 Subject: [PATCH 2/3] update military strikes notebook --- .../fine_tuning/04_military_strikes.ipynb | 659 +++++++++--------- 1 file changed, 331 insertions(+), 328 deletions(-) diff --git a/notebooks/fine_tuning/04_military_strikes.ipynb b/notebooks/fine_tuning/04_military_strikes.ipynb index 209f285..b70b0fd 100644 --- a/notebooks/fine_tuning/04_military_strikes.ipynb +++ b/notebooks/fine_tuning/04_military_strikes.ipynb @@ -1,331 +1,334 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Military Strikes Forecasting\n", - "\n", - "Generate a forecasting dataset about global military strikes and attack operations using the LightningRod SDK. Fine-tune a model via RL that outperforms frontier LLMs on strike prediction." - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Military Strikes Forecasting\n", + "\n", + "Generate a forecasting dataset about global military strikes and attack operations using the LightningRod SDK. Fine-tune a model via RL that outperforms frontier LLMs on strike prediction." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install lightningrod-ai python-dotenv pandas\n", + "\n", + "from IPython.display import clear_output\n", + "clear_output()\n", + "\n", + "from datetime import datetime\n", + "\n", + "import pandas as pd\n", + "from dotenv import load_dotenv\n", + "\n", + "load_dotenv()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set up the client\n", + "\n", + "Sign up at [dashboard.lightningrod.ai](https://dashboard.lightningrod.ai/sign-up?redirect=/api) to get your API key and **$50 of free credits**." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from lightningrod import LightningRod\n", + "from lightningrod.utils import config\n", + "\n", + "api_key = config.get_config_value(\"LIGHTNINGROD_API_KEY\")\n", + "lr = LightningRod(api_key=api_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Build the pipeline\n", + "\n", + "Configure the pipeline with domain-specific instructions and examples for military strike forecasting. Covers airstrikes, missile strikes, drone strikes, and naval strikes across state and non-state actors globally." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "instructions = \"\"\"\n", + "Generate binary forecasting questions specifically about military strikes and attack operations.\n", + "\n", + "Cover all strike types:\n", + "- Airstrikes (fighter jets, bombers, helicopter gunships)\n", + "- Missile strikes (ballistic, cruise, hypersonic)\n", + "- Drone strikes (kamikaze drones, armed UAVs, drone swarms)\n", + "- Naval strikes (ship-launched missiles, naval gunfire, submarine attacks)\n", + "\n", + "Cover both state and non-state actors. 
Use the natural language of news reporting:\n", + "- State actors: country names, leader names, named military units (IDF, IRGC, Pentagon)\n", + "- Non-state actors: group names (Hamas, Houthis, Hezbollah, ISIS, Wagner)\n", + "\n", + "Questions must:\n", + "- Name the specific actor conducting the strike\n", + "- Name the specific target (location, infrastructure, military asset, or group)\n", + "- Have a specific date or event milestone as resolution criteria\n", + "- Be objectively verifiable from open-source news\n", + "- Span the full probability spectrum \\u2014 some likely, some unlikely, some that won't happen\n", + "\"\"\"\n", + "\n", + "good_examples = [\n", + " \"Will the IDF conduct airstrikes on Hezbollah weapons depots in the Bekaa Valley before November 2024?\",\n", + " \"Will US Air Force B-52s conduct strikes on Houthi military infrastructure in Yemen before March 2024?\",\n", + " \"Will Russian Su-34s carry out airstrikes on Kharkiv civilian infrastructure before June 2024?\",\n", + " \"Will Iran launch a direct ballistic missile strike on Israeli territory before May 2024?\",\n", + " \"Will Houthi forces fire anti-ship missiles at US Navy destroyers in the Red Sea before February 2024?\",\n", + " \"Will Ukraine conduct drone strikes on Russian oil refineries inside Russian territory before April 2024?\",\n", + " \"Will North Korea launch an ICBM test over Japanese waters before January 2024?\",\n", + " \"Will Houthi forces attack a commercial vessel with drone boats in the Red Sea before March 2024?\",\n", + " \"Will the US Navy conduct ship-launched Tomahawk strikes on Houthi radar sites before February 2024?\",\n", + " \"Will Russian forces launch an armored offensive toward Chasiv Yar before April 2024?\",\n", + " \"Will Israel conduct airstrikes on Iranian nuclear facilities at Natanz before December 2024?\",\n", + " \"Will NATO aircraft conduct strikes inside Russian territory before January 2025?\",\n", + " \"Will China conduct missile 
strikes on Taiwanese military bases before the end of 2024?\",\n", + " \"Will Pakistan conduct airstrikes on Afghan Taliban positions before March 2024?\",\n", + "]\n", + "\n", + "bad_examples = [\n", + " \"Will there be an attack somewhere? (no specific actor, target, or location)\",\n", + " \"Will violence increase in the Middle East? (vague, not a specific strike event)\",\n", + " \"Will conflict continue in Ukraine? (trivially obvious, not a specific strike)\",\n", + " \"Will missiles be fired? (no actor, no target, no date)\",\n", + " \"Will tensions escalate? (not a verifiable strike event)\",\n", + " \"Will there be drone activity near the border? (too vague to verify)\",\n", + " \"Will the situation get worse? (subjective, not measurable)\",\n", + " \"Will airstrikes happen in 2025? (no actor, no target, too broad)\",\n", + " \"Will someone retaliate? (no specific actor or method)\",\n", + " \"Will the war end? (not a strike event, different question type)\",\n", + "]\n", + "\n", + "search_queries = [\n", + " \"military airstrike\",\n", + " \"military strike\",\n", + " \"missile strike\",\n", + " \"drone strike\",\n", + " \"naval strike\",\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from lightningrod import (\n", + " BinaryAnswerType,\n", + " NewsSeedGenerator,\n", + " ForwardLookingQuestionGenerator,\n", + " NewsContextGenerator,\n", + " WebSearchLabeler,\n", + " QuestionPipeline,\n", + ")\n", + "\n", + "answer_type = BinaryAnswerType()\n", + "\n", + "pipeline = QuestionPipeline(\n", + " seed_generator=NewsSeedGenerator(\n", + " start_date=datetime(2024, 6, 1),\n", + " end_date=datetime(2026, 3, 1),\n", + " interval_duration_days=7,\n", + " search_query=search_queries,\n", + " articles_per_search=10,\n", + " ),\n", + " question_generator=ForwardLookingQuestionGenerator(\n", + " instructions=instructions,\n", + " examples=good_examples,\n", + " bad_examples=bad_examples,\n", + 
" answer_type=answer_type,\n", + " questions_per_seed=5,\n", + " ),\n", + " context_generators=[\n", + " NewsContextGenerator(\n", + " articles_per_query=3,\n", + " num_search_queries=3,\n", + " num_articles=5,\n", + " )\n", + " ],\n", + " labeler=WebSearchLabeler(answer_type=answer_type),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run the pipeline\n", + "\n", + "Collect news articles, generate questions, and label answers. Set `max_questions=10000` for a full production run; reduce for testing." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = lr.transforms.run(pipeline, max_questions=10000, name=\"Military strikes forecasting\")\n", + "\n", + "samples = dataset.download()\n", + "pct = (sum(1 for s in samples if s.is_valid is True) / len(samples) * 100) if samples else 0\n", + "print(f\"{len(samples)} samples ({pct:.1f}% valid)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prepare the dataset\n", + "\n", + "Filter valid samples, deduplicate, and split into train/test sets using a temporal strategy." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from lightningrod import filter_and_split\n", + "\n", + "train_dataset, test_dataset = filter_and_split(\n", + " dataset,\n", + " test_size=0.2,\n", + " split_strategy=\"temporal\",\n", + " days_to_resolution_range=(1, 90),\n", + ")\n", + "\n", + "for name, ds in [(\"Train\", train_dataset), (\"Test\", test_dataset)]:\n", + " data = ds.flattened()\n", + " yes_count = sum(1 for s in data if s.get(\"label\") in (1, \"1\", 1.0))\n", + " print(f\"{name}: {len(data)} rows, {yes_count/len(data)*100:.1f}% yes\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train the model\n", + "\n", + "Fine-tune `openai/gpt-oss-120b` via RL using the training parameters from our golf and WWTD experiments." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from lightningrod import TrainingConfig\n", + "BATCH_SIZE = 32 \n", + "\n", + "train_data = train_dataset.flattened() \n", + "training_steps = max(10, len(train_data) // BATCH_SIZE)\n", + "\n", + "training_config = TrainingConfig(\n", + " base_model=\"openai/gpt-oss-120b\",\n", + " training_steps=training_steps,\n", + " lora_rank=32,\n", + " batch_size=BATCH_SIZE,\n", + " num_rollouts=8,\n", + " max_response_length=16384,\n", + " learning_rate=4e-5,\n", + ")\n", + "\n", + "cost_estimate = lr.training.estimate_cost(training_config, dataset=train_dataset)\n", + "print(f\"Estimated cost: ${cost_estimate.total_cost_dollars:.2f}\")\n", + "print(f\"Effective steps: {cost_estimate.effective_steps}\")\n", + "print(f\"Train tokens: {cost_estimate.train_tokens:,}\")\n", + "if cost_estimate.notes:\n", + " print(f\"Notes: {cost_estimate.notes}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "job = lr.training.run(training_config, dataset=train_dataset, 
name=\"military-strikes-v1\")\n", + "print(f\"Job {job.id} completed with status: {job.status}\")\n", + "print(f\"Model ID: {job.model_id}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluate\n", + "\n", + "Run the trained model against the test set, benchmarked against GPT-5.4." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from lightningrod import EvalModel, training\n", + "\n", + "eval_job = lr.evals.run(\n", + " models=[\n", + " EvalModel(model_id=training_config.base_model, label=\"Base\"),\n", + " EvalModel(model_id=job.model_id, label=\"Fine-tuned\"),\n", + " EvalModel(model_id=\"openai/gpt-5.4\", label=\"GPT-5.4\"),\n", + " ],\n", + " dataset=test_dataset,\n", + ")\n", + "\n", + "training.print_eval(eval_job)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Quick inference test\n", + "print(lr.predict(job.model_id, \"Will Israel conduct airstrikes in southern Lebanon before April 2026?\"))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python (lightningrod-sdk)", + "language": "python", + "name": "lightningrod-sdk" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" + } }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%pip install lightningrod-ai python-dotenv pandas\n", - "\n", - "from IPython.display import clear_output\n", - "clear_output()\n", - "\n", - "from datetime import datetime\n", - "\n", - "import pandas as pd\n", - "from dotenv import load_dotenv\n", - "\n", - "load_dotenv()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Set up the client\n", - "\n", - "Sign up 
at [dashboard.lightningrod.ai](https://dashboard.lightningrod.ai/sign-up?redirect=/api) to get your API key and **$50 of free credits**." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from lightningrod import LightningRod\n", - "from lightningrod.utils import config\n", - "\n", - "api_key = config.get_config_value(\"LIGHTNINGROD_API_KEY\")\n", - "lr = LightningRod(api_key=api_key)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Build the pipeline\n", - "\n", - "Configure the pipeline with domain-specific instructions and examples for military strike forecasting. Covers airstrikes, missile strikes, drone strikes, and naval strikes across state and non-state actors globally." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "instructions = \"\"\"\n", - "Generate binary forecasting questions specifically about military strikes and attack operations.\n", - "\n", - "Cover all strike types:\n", - "- Airstrikes (fighter jets, bombers, helicopter gunships)\n", - "- Missile strikes (ballistic, cruise, hypersonic)\n", - "- Drone strikes (kamikaze drones, armed UAVs, drone swarms)\n", - "- Naval strikes (ship-launched missiles, naval gunfire, submarine attacks)\n", - "\n", - "Cover both state and non-state actors. 
Use the natural language of news reporting:\n", - "- State actors: country names, leader names, named military units (IDF, IRGC, Pentagon)\n", - "- Non-state actors: group names (Hamas, Houthis, Hezbollah, ISIS, Wagner)\n", - "\n", - "Questions must:\n", - "- Name the specific actor conducting the strike\n", - "- Name the specific target (location, infrastructure, military asset, or group)\n", - "- Have a specific date or event milestone as resolution criteria\n", - "- Be objectively verifiable from open-source news\n", - "- Span the full probability spectrum \\u2014 some likely, some unlikely, some that won't happen\n", - "\"\"\"\n", - "\n", - "good_examples = [\n", - " \"Will the IDF conduct airstrikes on Hezbollah weapons depots in the Bekaa Valley before November 2024?\",\n", - " \"Will US Air Force B-52s conduct strikes on Houthi military infrastructure in Yemen before March 2024?\",\n", - " \"Will Russian Su-34s carry out airstrikes on Kharkiv civilian infrastructure before June 2024?\",\n", - " \"Will Iran launch a direct ballistic missile strike on Israeli territory before May 2024?\",\n", - " \"Will Houthi forces fire anti-ship missiles at US Navy destroyers in the Red Sea before February 2024?\",\n", - " \"Will Ukraine conduct drone strikes on Russian oil refineries inside Russian territory before April 2024?\",\n", - " \"Will North Korea launch an ICBM test over Japanese waters before January 2024?\",\n", - " \"Will Houthi forces attack a commercial vessel with drone boats in the Red Sea before March 2024?\",\n", - " \"Will the US Navy conduct ship-launched Tomahawk strikes on Houthi radar sites before February 2024?\",\n", - " \"Will Russian forces launch an armored offensive toward Chasiv Yar before April 2024?\",\n", - " \"Will Israel conduct airstrikes on Iranian nuclear facilities at Natanz before December 2024?\",\n", - " \"Will NATO aircraft conduct strikes inside Russian territory before January 2025?\",\n", - " \"Will China conduct missile 
strikes on Taiwanese military bases before the end of 2024?\",\n", - " \"Will Pakistan conduct airstrikes on Afghan Taliban positions before March 2024?\",\n", - "]\n", - "\n", - "bad_examples = [\n", - " \"Will there be an attack somewhere? (no specific actor, target, or location)\",\n", - " \"Will violence increase in the Middle East? (vague, not a specific strike event)\",\n", - " \"Will conflict continue in Ukraine? (trivially obvious, not a specific strike)\",\n", - " \"Will missiles be fired? (no actor, no target, no date)\",\n", - " \"Will tensions escalate? (not a verifiable strike event)\",\n", - " \"Will there be drone activity near the border? (too vague to verify)\",\n", - " \"Will the situation get worse? (subjective, not measurable)\",\n", - " \"Will airstrikes happen in 2025? (no actor, no target, too broad)\",\n", - " \"Will someone retaliate? (no specific actor or method)\",\n", - " \"Will the war end? (not a strike event, different question type)\",\n", - "]\n", - "\n", - "search_queries = [\n", - " \"military airstrike\",\n", - " \"military strike\",\n", - " \"missile strike\",\n", - " \"drone strike\",\n", - " \"naval strike\",\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from lightningrod import (\n", - " BinaryAnswerType,\n", - " NewsSeedGenerator,\n", - " ForwardLookingQuestionGenerator,\n", - " NewsContextGenerator,\n", - " WebSearchLabeler,\n", - " QuestionPipeline,\n", - ")\n", - "\n", - "answer_type = BinaryAnswerType()\n", - "\n", - "pipeline = QuestionPipeline(\n", - " seed_generator=NewsSeedGenerator(\n", - " start_date=datetime(2024, 6, 1),\n", - " end_date=datetime(2026, 3, 1),\n", - " interval_duration_days=7,\n", - " search_query=search_queries,\n", - " articles_per_search=10,\n", - " ),\n", - " question_generator=ForwardLookingQuestionGenerator(\n", - " instructions=instructions,\n", - " examples=good_examples,\n", - " bad_examples=bad_examples,\n", - 
" answer_type=answer_type,\n", - " questions_per_seed=5,\n", - " ),\n", - " context_generators=[\n", - " NewsContextGenerator(\n", - " articles_per_query=3,\n", - " num_search_queries=3,\n", - " num_articles=5,\n", - " )\n", - " ],\n", - " labeler=WebSearchLabeler(answer_type=answer_type),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Run the pipeline\n", - "\n", - "Collect news articles, generate questions, and label answers. Set `max_questions=10000` for a full production run; reduce for testing." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dataset = lr.transforms.run(pipeline, max_questions=10000, name=\"Military strikes forecasting\")\n", - "\n", - "samples = dataset.download()\n", - "pct = (sum(1 for s in samples if s.is_valid is True) / len(samples) * 100) if samples else 0\n", - "print(f\"{len(samples)} samples ({pct:.1f}% valid)\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Prepare the dataset\n", - "\n", - "Filter valid samples, deduplicate, and split into train/test sets using a temporal strategy." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from lightningrod import filter_and_split\n", - "\n", - "train_dataset, test_dataset = filter_and_split(\n", - " dataset,\n", - " test_size=0.2,\n", - " split_strategy=\"temporal\",\n", - " days_to_resolution_range=(1, 90),\n", - ")\n", - "\n", - "for name, ds in [(\"Train\", train_dataset), (\"Test\", test_dataset)]:\n", - " data = ds.flattened()\n", - " yes_count = sum(1 for s in data if s.get(\"label\") in (1, \"1\", 1.0))\n", - " print(f\"{name}: {len(data)} rows, {yes_count/len(data)*100:.1f}% yes\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Train the model\n", - "\n", - "Fine-tune `openai/gpt-oss-120b` via RL using the training parameters from our golf and WWTD experiments." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from lightningrod import TrainingConfig\n", - "BATCH_SIZE = 32 \n", - "\n", - "train_data = train_dataset.flattened() \n", - "training_steps = max(10, len(train_data) // BATCH_SIZE)\n", - "\n", - "training_config = TrainingConfig(\n", - " base_model=\"openai/gpt-oss-120b\",\n", - " training_steps=training_steps,\n", - " lora_rank=32,\n", - " batch_size=BATCH_SIZE,\n", - " num_rollouts=8,\n", - " max_response_length=16384,\n", - " learning_rate=4e-5,\n", - ")\n", - "\n", - "cost_estimate = lr.training.estimate_cost(training_config, dataset=train_dataset)\n", - "print(f\"Estimated cost: ${cost_estimate.total_cost_dollars:.2f}\")\n", - "print(f\"Effective steps: {cost_estimate.effective_steps}\")\n", - "print(f\"Train tokens: {cost_estimate.train_tokens:,}\")\n", - "if cost_estimate.notes:\n", - " print(f\"Notes: {cost_estimate.notes}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "job = lr.training.run(training_config, dataset=train_dataset, 
name=\"military-strikes-v1\")\n", - "print(f\"Job {job.id} completed with status: {job.status}\")\n", - "print(f\"Model ID: {job.model_id}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Evaluate\n", - "\n", - "Run the trained model against the test set, benchmarked against GPT-5.4." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "eval_job = lr.evals.run(\n", - " model_id=job.model_id,\n", - " dataset=test_dataset,\n", - " benchmark_model_id=\"openai/gpt-5.4\",\n", - ")\n", - "\n", - "print(f\"Eval completed: {eval_job.id}\")\n", - "if eval_job.metrics:\n", - " print(eval_job.metrics)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Quick inference test\n", - "print(lr.predict(job.model_id, \"Will Israel conduct airstrikes in southern Lebanon before April 2026?\"))" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.13.1" - } - }, - "nbformat": 4, - "nbformat_minor": 4 + "nbformat": 4, + "nbformat_minor": 4 } From 74844ce6441c135746b5ce9cea02cdae19ebf94f Mon Sep 17 00:00:00 2001 From: Bartolomej Kozorog Date: Fri, 27 Mar 2026 19:14:37 +0100 Subject: [PATCH 3/3] regenerate client --- openapi/openapi.json | 195 +++--------------- .../_generated/models/__init__.py | 8 +- .../models/create_transform_job_request.py | 52 +---- .../_generated/models/csv_seed_generator.py | 152 -------------- .../models/estimate_cost_request.py | 52 +---- .../_generated/models/question_pipeline.py | 39 +--- .../models/topic_tree_seed_generator.py | 149 ------------- .../_generated/models/training_job.py | 54 
+++-- .../_generated/models/validation_error.py | 35 +++- 9 files changed, 125 insertions(+), 611 deletions(-) delete mode 100644 src/lightningrod/_generated/models/csv_seed_generator.py delete mode 100644 src/lightningrod/_generated/models/topic_tree_seed_generator.py diff --git a/openapi/openapi.json b/openapi/openapi.json index 251ffbb..459405b 100644 --- a/openapi/openapi.json +++ b/openapi/openapi.json @@ -5,11 +5,6 @@ "description": "Generate verified, grounded datasets at scale for LLM fine-tuning and evaluation", "version": "1.0.0" }, - "servers": [ - { - "url": "/api/public/v1" - } - ], "paths": { "/openai/chat/completions": { "post": { @@ -2481,12 +2476,6 @@ { "$ref": "#/components/schemas/NewsSeedGenerator" }, - { - "$ref": "#/components/schemas/TopicTreeSeedGenerator" - }, - { - "$ref": "#/components/schemas/CsvSeedGenerator" - }, { "$ref": "#/components/schemas/QuestionAndLabelGenerator" }, @@ -2506,7 +2495,6 @@ "discriminator": { "propertyName": "config_type", "mapping": { - "CSV_SEED_GENERATOR": "#/components/schemas/CsvSeedGenerator", "FILESET_QUERY_SEED_GENERATOR": "#/components/schemas/FileSetQuerySeedGenerator", "FILESET_SEED_GENERATOR": "#/components/schemas/FileSetSeedGenerator", "FORWARD_LOOKING_QUESTION_GENERATOR": "#/components/schemas/ForwardLookingQuestionGenerator", @@ -2516,7 +2504,6 @@ "QUESTION_GENERATOR": "#/components/schemas/QuestionGenerator", "QUESTION_PIPELINE": "#/components/schemas/QuestionPipeline", "QUESTION_RENDERER": "#/components/schemas/QuestionRenderer", - "TOPIC_TREE_SEED_GENERATOR": "#/components/schemas/TopicTreeSeedGenerator", "WEB_SEARCH_LABELER": "#/components/schemas/WebSearchLabeler" } } @@ -2588,73 +2575,6 @@ ], "title": "CreateTransformJobRequest" }, - "CsvSeedGenerator": { - "properties": { - "config_type": { - "type": "string", - "const": "CSV_SEED_GENERATOR", - "title": "Config Type", - "description": "Type of transform configuration", - "default": "CSV_SEED_GENERATOR" - }, - "file_id": { - "anyOf": [ - { - 
"type": "string" - }, - { - "items": { - "type": "string" - }, - "type": "array" - } - ], - "title": "File Id", - "description": "OrgFile ID(s) from POST /files/ upload response" - }, - "seed_text_column": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Seed Text Column", - "description": "Column name for Seed.seed_text; if None, serialize entire row as JSON" - }, - "label_column": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Label Column", - "description": "Column name for pre-existing labels (populates Sample.label)" - }, - "date_column": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Date Column", - "description": "Column name for Seed.seed_creation_date" - } - }, - "type": "object", - "required": [ - "file_id" - ], - "title": "CsvSeedGenerator" - }, "DatasetMetadata": { "properties": { "id": { @@ -5293,12 +5213,6 @@ { "$ref": "#/components/schemas/FileSetQuerySeedGenerator" }, - { - "$ref": "#/components/schemas/TopicTreeSeedGenerator" - }, - { - "$ref": "#/components/schemas/CsvSeedGenerator" - }, { "$ref": "#/components/schemas/MockTransformConfig" } @@ -5307,13 +5221,11 @@ "propertyName": "config_type", "mapping": { "BIGQUERY_SEED_GENERATOR": "#/components/schemas/BigQuerySeedGenerator", - "CSV_SEED_GENERATOR": "#/components/schemas/CsvSeedGenerator", "FILESET_QUERY_SEED_GENERATOR": "#/components/schemas/FileSetQuerySeedGenerator", "FILESET_SEED_GENERATOR": "#/components/schemas/FileSetSeedGenerator", "GDELT_SEED_GENERATOR": "#/components/schemas/GdeltSeedGenerator", "MOCK": "#/components/schemas/MockTransformConfig", - "NEWS_SEED_GENERATOR": "#/components/schemas/NewsSeedGenerator", - "TOPIC_TREE_SEED_GENERATOR": "#/components/schemas/TopicTreeSeedGenerator" + "NEWS_SEED_GENERATOR": "#/components/schemas/NewsSeedGenerator" } } }, @@ -6092,72 +6004,6 @@ "title": "TemporalConstraint", "description": "Temporal filtering direction relative to the 
seed document's date.\n\nUses the `file_date` metadata key (unix timestamp int) stored on each\nGemini document by fileset_file_processor.\n\nBEFORE: file_date <= seed_timestamp (context: no lookahead bias)\nAFTER: file_date > seed_timestamp (labels: find resolutions)" }, - "TopicTreeSeedGenerator": { - "properties": { - "config_type": { - "type": "string", - "const": "TOPIC_TREE_SEED_GENERATOR", - "title": "Config Type", - "description": "Type of transform configuration", - "default": "TOPIC_TREE_SEED_GENERATOR" - }, - "topic": { - "anyOf": [ - { - "type": "string" - }, - { - "items": { - "type": "string" - }, - "type": "array" - } - ], - "title": "Topic", - "description": "Root topic(s) to decompose into subtopic trees" - }, - "tree_depth": { - "type": "integer", - "maximum": 10.0, - "minimum": 1.0, - "title": "Tree Depth", - "description": "Levels of recursive expansion", - "default": 2 - }, - "tree_degree": { - "type": "integer", - "maximum": 20.0, - "minimum": 2.0, - "title": "Tree Degree", - "description": "Subtopics generated per node", - "default": 5 - }, - "model_name": { - "type": "string", - "title": "Model Name", - "description": "LLM model for subtopic generation", - "default": "google/gemini-3-flash-preview" - }, - "model_system_prompt": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Model System Prompt", - "description": "Optional system prompt for the LLM" - } - }, - "type": "object", - "required": [ - "topic" - ], - "title": "TopicTreeSeedGenerator", - "description": "Generates diverse seeds by recursively decomposing broad topics into specific subtopics.\n\nGiven one or more root topics, uses an LLM to break each into `tree_degree` subtopics,\nthen repeats `tree_depth` levels deep. 
The leaf paths become seeds for downstream transforms.\nThis produces tree_degree^tree_depth seeds per root topic, each more specific than the root.\n\nExample:\n TopicTreeSeedGenerator(topic=\"AI Regulation\", tree_depth=2, tree_degree=4)\n # Produces 16 specific seeds like \"AI Regulation \u2192 Healthcare \u2192 FDA approval of\n # diagnostic algorithms\" that feed into QuestionGenerator or other downstream transforms." - }, "TrainingConfig": { "properties": { "dataset": { @@ -6317,6 +6163,20 @@ ], "title": "Model Id" }, + "model_id_by_step": { + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Model Id By Step" + }, "reward_history": { "anyOf": [ { @@ -6364,17 +6224,6 @@ ], "title": "Cost Dollars" }, - "dataset_hf_repo": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Dataset Hf Repo" - }, "error_message": { "anyOf": [ { @@ -6807,6 +6656,13 @@ "type": { "type": "string", "title": "Error Type" + }, + "input": { + "title": "Input" + }, + "ctx": { + "type": "object", + "title": "Context" } }, "type": "object", @@ -6883,5 +6739,10 @@ "scheme": "bearer" } } - } + }, + "servers": [ + { + "url": "/api/public/v1" + } + ] } \ No newline at end of file diff --git a/src/lightningrod/_generated/models/__init__.py b/src/lightningrod/_generated/models/__init__.py index 21eb37d..ce0db44 100644 --- a/src/lightningrod/_generated/models/__init__.py +++ b/src/lightningrod/_generated/models/__init__.py @@ -21,7 +21,6 @@ from .create_file_upload_response_metadata_type_0 import CreateFileUploadResponseMetadataType0 from .create_training_job_request import CreateTrainingJobRequest from .create_transform_job_request import CreateTransformJobRequest -from .csv_seed_generator import CsvSeedGenerator from .dataset_metadata import DatasetMetadata from .estimate_cost_request import EstimateCostRequest from .estimate_cost_response import EstimateCostResponse @@ -99,10 
+98,10 @@ from .step_cost_breakdown import StepCostBreakdown from .template_question_generator import TemplateQuestionGenerator from .temporal_constraint import TemporalConstraint -from .topic_tree_seed_generator import TopicTreeSeedGenerator from .training_config import TrainingConfig from .training_job import TrainingJob from .training_job_list_response import TrainingJobListResponse +from .training_job_model_id_by_step_type_0 import TrainingJobModelIdByStepType0 from .training_job_status import TrainingJobStatus from .transform_job import TransformJob from .transform_job_status import TransformJobStatus @@ -116,6 +115,7 @@ from .usage_summary_llm_by_model import UsageSummaryLlmByModel from .validate_sample_response import ValidateSampleResponse from .validation_error import ValidationError +from .validation_error_context import ValidationErrorContext from .web_search_labeler import WebSearchLabeler __all__ = ( @@ -140,7 +140,6 @@ "CreateFileUploadResponseMetadataType0", "CreateTrainingJobRequest", "CreateTransformJobRequest", - "CsvSeedGenerator", "DatasetMetadata", "EstimateCostRequest", "EstimateCostResponse", @@ -218,10 +217,10 @@ "StepCostBreakdown", "TemplateQuestionGenerator", "TemporalConstraint", - "TopicTreeSeedGenerator", "TrainingConfig", "TrainingJob", "TrainingJobListResponse", + "TrainingJobModelIdByStepType0", "TrainingJobStatus", "TransformJob", "TransformJobStatus", @@ -235,5 +234,6 @@ "UsageSummaryLlmByModel", "ValidateSampleResponse", "ValidationError", + "ValidationErrorContext", "WebSearchLabeler", ) diff --git a/src/lightningrod/_generated/models/create_transform_job_request.py b/src/lightningrod/_generated/models/create_transform_job_request.py index 802f164..dc567f4 100644 --- a/src/lightningrod/_generated/models/create_transform_job_request.py +++ b/src/lightningrod/_generated/models/create_transform_job_request.py @@ -9,7 +9,6 @@ from ..types import UNSET, Unset if TYPE_CHECKING: - from ..models.csv_seed_generator import 
CsvSeedGenerator from ..models.file_set_query_seed_generator import FileSetQuerySeedGenerator from ..models.file_set_seed_generator import FileSetSeedGenerator from ..models.forward_looking_question_generator import ForwardLookingQuestionGenerator @@ -19,7 +18,6 @@ from ..models.question_generator import QuestionGenerator from ..models.question_pipeline import QuestionPipeline from ..models.question_renderer import QuestionRenderer - from ..models.topic_tree_seed_generator import TopicTreeSeedGenerator from ..models.web_search_labeler import WebSearchLabeler @@ -30,9 +28,9 @@ class CreateTransformJobRequest: """ Attributes: - config (CsvSeedGenerator | FileSetQuerySeedGenerator | FileSetSeedGenerator | ForwardLookingQuestionGenerator | - GdeltSeedGenerator | NewsSeedGenerator | QuestionAndLabelGenerator | QuestionGenerator | QuestionPipeline | - QuestionRenderer | TopicTreeSeedGenerator | WebSearchLabeler): + config (FileSetQuerySeedGenerator | FileSetSeedGenerator | ForwardLookingQuestionGenerator | GdeltSeedGenerator + | NewsSeedGenerator | QuestionAndLabelGenerator | QuestionGenerator | QuestionPipeline | QuestionRenderer | + WebSearchLabeler): input_dataset_id (None | str | Unset): max_questions (int | None | Unset): max_cost_dollars (float | None | Unset): @@ -41,8 +39,7 @@ class CreateTransformJobRequest: """ config: ( - CsvSeedGenerator - | FileSetQuerySeedGenerator + FileSetQuerySeedGenerator | FileSetSeedGenerator | ForwardLookingQuestionGenerator | GdeltSeedGenerator @@ -51,7 +48,6 @@ class CreateTransformJobRequest: | QuestionGenerator | QuestionPipeline | QuestionRenderer - | TopicTreeSeedGenerator | WebSearchLabeler ) input_dataset_id: None | str | Unset = UNSET @@ -62,7 +58,6 @@ class CreateTransformJobRequest: additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) def to_dict(self) -> dict[str, Any]: - from ..models.csv_seed_generator import CsvSeedGenerator from ..models.file_set_query_seed_generator import 
FileSetQuerySeedGenerator from ..models.file_set_seed_generator import FileSetSeedGenerator from ..models.forward_looking_question_generator import ForwardLookingQuestionGenerator @@ -72,7 +67,6 @@ def to_dict(self) -> dict[str, Any]: from ..models.question_generator import QuestionGenerator from ..models.question_pipeline import QuestionPipeline from ..models.question_renderer import QuestionRenderer - from ..models.topic_tree_seed_generator import TopicTreeSeedGenerator config: dict[str, Any] if isinstance(self.config, ForwardLookingQuestionGenerator): @@ -85,10 +79,6 @@ def to_dict(self) -> dict[str, Any]: config = self.config.to_dict() elif isinstance(self.config, NewsSeedGenerator): config = self.config.to_dict() - elif isinstance(self.config, TopicTreeSeedGenerator): - config = self.config.to_dict() - elif isinstance(self.config, CsvSeedGenerator): - config = self.config.to_dict() elif isinstance(self.config, QuestionAndLabelGenerator): config = self.config.to_dict() elif isinstance(self.config, QuestionGenerator): @@ -152,7 +142,6 @@ def to_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: - from ..models.csv_seed_generator import CsvSeedGenerator from ..models.file_set_query_seed_generator import FileSetQuerySeedGenerator from ..models.file_set_seed_generator import FileSetSeedGenerator from ..models.forward_looking_question_generator import ForwardLookingQuestionGenerator @@ -162,7 +151,6 @@ def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: from ..models.question_generator import QuestionGenerator from ..models.question_pipeline import QuestionPipeline from ..models.question_renderer import QuestionRenderer - from ..models.topic_tree_seed_generator import TopicTreeSeedGenerator from ..models.web_search_labeler import WebSearchLabeler d = dict(src_dict) @@ -170,8 +158,7 @@ def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: def _parse_config( data: object, ) -> ( - 
CsvSeedGenerator - | FileSetQuerySeedGenerator + FileSetQuerySeedGenerator | FileSetSeedGenerator | ForwardLookingQuestionGenerator | GdeltSeedGenerator @@ -180,7 +167,6 @@ def _parse_config( | QuestionGenerator | QuestionPipeline | QuestionRenderer - | TopicTreeSeedGenerator | WebSearchLabeler ): try: @@ -226,7 +212,7 @@ def _parse_config( try: if not isinstance(data, dict): raise TypeError() - componentsschemas_create_transform_config_type_5 = TopicTreeSeedGenerator.from_dict(data) + componentsschemas_create_transform_config_type_5 = QuestionAndLabelGenerator.from_dict(data) return componentsschemas_create_transform_config_type_5 except (TypeError, ValueError, AttributeError, KeyError): @@ -234,7 +220,7 @@ def _parse_config( try: if not isinstance(data, dict): raise TypeError() - componentsschemas_create_transform_config_type_6 = CsvSeedGenerator.from_dict(data) + componentsschemas_create_transform_config_type_6 = QuestionGenerator.from_dict(data) return componentsschemas_create_transform_config_type_6 except (TypeError, ValueError, AttributeError, KeyError): @@ -242,7 +228,7 @@ def _parse_config( try: if not isinstance(data, dict): raise TypeError() - componentsschemas_create_transform_config_type_7 = QuestionAndLabelGenerator.from_dict(data) + componentsschemas_create_transform_config_type_7 = QuestionPipeline.from_dict(data) return componentsschemas_create_transform_config_type_7 except (TypeError, ValueError, AttributeError, KeyError): @@ -250,32 +236,16 @@ def _parse_config( try: if not isinstance(data, dict): raise TypeError() - componentsschemas_create_transform_config_type_8 = QuestionGenerator.from_dict(data) + componentsschemas_create_transform_config_type_8 = QuestionRenderer.from_dict(data) return componentsschemas_create_transform_config_type_8 except (TypeError, ValueError, AttributeError, KeyError): pass - try: - if not isinstance(data, dict): - raise TypeError() - componentsschemas_create_transform_config_type_9 = QuestionPipeline.from_dict(data) 
- - return componentsschemas_create_transform_config_type_9 - except (TypeError, ValueError, AttributeError, KeyError): - pass - try: - if not isinstance(data, dict): - raise TypeError() - componentsschemas_create_transform_config_type_10 = QuestionRenderer.from_dict(data) - - return componentsschemas_create_transform_config_type_10 - except (TypeError, ValueError, AttributeError, KeyError): - pass if not isinstance(data, dict): raise TypeError() - componentsschemas_create_transform_config_type_11 = WebSearchLabeler.from_dict(data) + componentsschemas_create_transform_config_type_9 = WebSearchLabeler.from_dict(data) - return componentsschemas_create_transform_config_type_11 + return componentsschemas_create_transform_config_type_9 config = _parse_config(d.pop("config")) diff --git a/src/lightningrod/_generated/models/csv_seed_generator.py b/src/lightningrod/_generated/models/csv_seed_generator.py deleted file mode 100644 index 6229e71..0000000 --- a/src/lightningrod/_generated/models/csv_seed_generator.py +++ /dev/null @@ -1,152 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping -from typing import Any, Literal, TypeVar, cast - -from attrs import define as _attrs_define -from attrs import field as _attrs_field - -from ..types import UNSET, Unset - -T = TypeVar("T", bound="CsvSeedGenerator") - - -@_attrs_define -class CsvSeedGenerator: - """ - Attributes: - file_id (list[str] | str): OrgFile ID(s) from POST /files/ upload response - config_type (Literal['CSV_SEED_GENERATOR'] | Unset): Type of transform configuration Default: - 'CSV_SEED_GENERATOR'. 
- seed_text_column (None | str | Unset): Column name for Seed.seed_text; if None, serialize entire row as JSON - label_column (None | str | Unset): Column name for pre-existing labels (populates Sample.label) - date_column (None | str | Unset): Column name for Seed.seed_creation_date - """ - - file_id: list[str] | str - config_type: Literal["CSV_SEED_GENERATOR"] | Unset = "CSV_SEED_GENERATOR" - seed_text_column: None | str | Unset = UNSET - label_column: None | str | Unset = UNSET - date_column: None | str | Unset = UNSET - additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) - - def to_dict(self) -> dict[str, Any]: - file_id: list[str] | str - if isinstance(self.file_id, list): - file_id = self.file_id - - else: - file_id = self.file_id - - config_type = self.config_type - - seed_text_column: None | str | Unset - if isinstance(self.seed_text_column, Unset): - seed_text_column = UNSET - else: - seed_text_column = self.seed_text_column - - label_column: None | str | Unset - if isinstance(self.label_column, Unset): - label_column = UNSET - else: - label_column = self.label_column - - date_column: None | str | Unset - if isinstance(self.date_column, Unset): - date_column = UNSET - else: - date_column = self.date_column - - field_dict: dict[str, Any] = {} - field_dict.update(self.additional_properties) - field_dict.update( - { - "file_id": file_id, - } - ) - if config_type is not UNSET: - field_dict["config_type"] = config_type - if seed_text_column is not UNSET: - field_dict["seed_text_column"] = seed_text_column - if label_column is not UNSET: - field_dict["label_column"] = label_column - if date_column is not UNSET: - field_dict["date_column"] = date_column - - return field_dict - - @classmethod - def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: - d = dict(src_dict) - - def _parse_file_id(data: object) -> list[str] | str: - try: - if not isinstance(data, list): - raise TypeError() - file_id_type_1 = cast(list[str], data) - - 
return file_id_type_1 - except (TypeError, ValueError, AttributeError, KeyError): - pass - return cast(list[str] | str, data) - - file_id = _parse_file_id(d.pop("file_id")) - - config_type = cast(Literal["CSV_SEED_GENERATOR"] | Unset, d.pop("config_type", UNSET)) - if config_type != "CSV_SEED_GENERATOR" and not isinstance(config_type, Unset): - raise ValueError(f"config_type must match const 'CSV_SEED_GENERATOR', got '{config_type}'") - - def _parse_seed_text_column(data: object) -> None | str | Unset: - if data is None: - return data - if isinstance(data, Unset): - return data - return cast(None | str | Unset, data) - - seed_text_column = _parse_seed_text_column(d.pop("seed_text_column", UNSET)) - - def _parse_label_column(data: object) -> None | str | Unset: - if data is None: - return data - if isinstance(data, Unset): - return data - return cast(None | str | Unset, data) - - label_column = _parse_label_column(d.pop("label_column", UNSET)) - - def _parse_date_column(data: object) -> None | str | Unset: - if data is None: - return data - if isinstance(data, Unset): - return data - return cast(None | str | Unset, data) - - date_column = _parse_date_column(d.pop("date_column", UNSET)) - - csv_seed_generator = cls( - file_id=file_id, - config_type=config_type, - seed_text_column=seed_text_column, - label_column=label_column, - date_column=date_column, - ) - - csv_seed_generator.additional_properties = d - return csv_seed_generator - - @property - def additional_keys(self) -> list[str]: - return list(self.additional_properties.keys()) - - def __getitem__(self, key: str) -> Any: - return self.additional_properties[key] - - def __setitem__(self, key: str, value: Any) -> None: - self.additional_properties[key] = value - - def __delitem__(self, key: str) -> None: - del self.additional_properties[key] - - def __contains__(self, key: str) -> bool: - return key in self.additional_properties diff --git a/src/lightningrod/_generated/models/estimate_cost_request.py 
b/src/lightningrod/_generated/models/estimate_cost_request.py index a6ca1dc..ec63f70 100644 --- a/src/lightningrod/_generated/models/estimate_cost_request.py +++ b/src/lightningrod/_generated/models/estimate_cost_request.py @@ -9,7 +9,6 @@ from ..types import UNSET, Unset if TYPE_CHECKING: - from ..models.csv_seed_generator import CsvSeedGenerator from ..models.file_set_query_seed_generator import FileSetQuerySeedGenerator from ..models.file_set_seed_generator import FileSetSeedGenerator from ..models.forward_looking_question_generator import ForwardLookingQuestionGenerator @@ -19,7 +18,6 @@ from ..models.question_generator import QuestionGenerator from ..models.question_pipeline import QuestionPipeline from ..models.question_renderer import QuestionRenderer - from ..models.topic_tree_seed_generator import TopicTreeSeedGenerator from ..models.web_search_labeler import WebSearchLabeler @@ -30,15 +28,14 @@ class EstimateCostRequest: """ Attributes: - config (CsvSeedGenerator | FileSetQuerySeedGenerator | FileSetSeedGenerator | ForwardLookingQuestionGenerator | - GdeltSeedGenerator | NewsSeedGenerator | QuestionAndLabelGenerator | QuestionGenerator | QuestionPipeline | - QuestionRenderer | TopicTreeSeedGenerator | WebSearchLabeler): + config (FileSetQuerySeedGenerator | FileSetSeedGenerator | ForwardLookingQuestionGenerator | GdeltSeedGenerator + | NewsSeedGenerator | QuestionAndLabelGenerator | QuestionGenerator | QuestionPipeline | QuestionRenderer | + WebSearchLabeler): max_questions (int | None | Unset): """ config: ( - CsvSeedGenerator - | FileSetQuerySeedGenerator + FileSetQuerySeedGenerator | FileSetSeedGenerator | ForwardLookingQuestionGenerator | GdeltSeedGenerator @@ -47,14 +44,12 @@ class EstimateCostRequest: | QuestionGenerator | QuestionPipeline | QuestionRenderer - | TopicTreeSeedGenerator | WebSearchLabeler ) max_questions: int | None | Unset = UNSET additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) def to_dict(self) -> 
dict[str, Any]: - from ..models.csv_seed_generator import CsvSeedGenerator from ..models.file_set_query_seed_generator import FileSetQuerySeedGenerator from ..models.file_set_seed_generator import FileSetSeedGenerator from ..models.forward_looking_question_generator import ForwardLookingQuestionGenerator @@ -64,7 +59,6 @@ def to_dict(self) -> dict[str, Any]: from ..models.question_generator import QuestionGenerator from ..models.question_pipeline import QuestionPipeline from ..models.question_renderer import QuestionRenderer - from ..models.topic_tree_seed_generator import TopicTreeSeedGenerator config: dict[str, Any] if isinstance(self.config, ForwardLookingQuestionGenerator): @@ -77,10 +71,6 @@ def to_dict(self) -> dict[str, Any]: config = self.config.to_dict() elif isinstance(self.config, NewsSeedGenerator): config = self.config.to_dict() - elif isinstance(self.config, TopicTreeSeedGenerator): - config = self.config.to_dict() - elif isinstance(self.config, CsvSeedGenerator): - config = self.config.to_dict() elif isinstance(self.config, QuestionAndLabelGenerator): config = self.config.to_dict() elif isinstance(self.config, QuestionGenerator): @@ -112,7 +102,6 @@ def to_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: - from ..models.csv_seed_generator import CsvSeedGenerator from ..models.file_set_query_seed_generator import FileSetQuerySeedGenerator from ..models.file_set_seed_generator import FileSetSeedGenerator from ..models.forward_looking_question_generator import ForwardLookingQuestionGenerator @@ -122,7 +111,6 @@ def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: from ..models.question_generator import QuestionGenerator from ..models.question_pipeline import QuestionPipeline from ..models.question_renderer import QuestionRenderer - from ..models.topic_tree_seed_generator import TopicTreeSeedGenerator from ..models.web_search_labeler import WebSearchLabeler d = dict(src_dict) @@ -130,8 
+118,7 @@ def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: def _parse_config( data: object, ) -> ( - CsvSeedGenerator - | FileSetQuerySeedGenerator + FileSetQuerySeedGenerator | FileSetSeedGenerator | ForwardLookingQuestionGenerator | GdeltSeedGenerator @@ -140,7 +127,6 @@ def _parse_config( | QuestionGenerator | QuestionPipeline | QuestionRenderer - | TopicTreeSeedGenerator | WebSearchLabeler ): try: @@ -186,7 +172,7 @@ def _parse_config( try: if not isinstance(data, dict): raise TypeError() - componentsschemas_create_transform_config_type_5 = TopicTreeSeedGenerator.from_dict(data) + componentsschemas_create_transform_config_type_5 = QuestionAndLabelGenerator.from_dict(data) return componentsschemas_create_transform_config_type_5 except (TypeError, ValueError, AttributeError, KeyError): @@ -194,7 +180,7 @@ def _parse_config( try: if not isinstance(data, dict): raise TypeError() - componentsschemas_create_transform_config_type_6 = CsvSeedGenerator.from_dict(data) + componentsschemas_create_transform_config_type_6 = QuestionGenerator.from_dict(data) return componentsschemas_create_transform_config_type_6 except (TypeError, ValueError, AttributeError, KeyError): @@ -202,7 +188,7 @@ def _parse_config( try: if not isinstance(data, dict): raise TypeError() - componentsschemas_create_transform_config_type_7 = QuestionAndLabelGenerator.from_dict(data) + componentsschemas_create_transform_config_type_7 = QuestionPipeline.from_dict(data) return componentsschemas_create_transform_config_type_7 except (TypeError, ValueError, AttributeError, KeyError): @@ -210,32 +196,16 @@ def _parse_config( try: if not isinstance(data, dict): raise TypeError() - componentsschemas_create_transform_config_type_8 = QuestionGenerator.from_dict(data) + componentsschemas_create_transform_config_type_8 = QuestionRenderer.from_dict(data) return componentsschemas_create_transform_config_type_8 except (TypeError, ValueError, AttributeError, KeyError): pass - try: - if not 
isinstance(data, dict): - raise TypeError() - componentsschemas_create_transform_config_type_9 = QuestionPipeline.from_dict(data) - - return componentsschemas_create_transform_config_type_9 - except (TypeError, ValueError, AttributeError, KeyError): - pass - try: - if not isinstance(data, dict): - raise TypeError() - componentsschemas_create_transform_config_type_10 = QuestionRenderer.from_dict(data) - - return componentsschemas_create_transform_config_type_10 - except (TypeError, ValueError, AttributeError, KeyError): - pass if not isinstance(data, dict): raise TypeError() - componentsschemas_create_transform_config_type_11 = WebSearchLabeler.from_dict(data) + componentsschemas_create_transform_config_type_9 = WebSearchLabeler.from_dict(data) - return componentsschemas_create_transform_config_type_11 + return componentsschemas_create_transform_config_type_9 config = _parse_config(d.pop("config")) diff --git a/src/lightningrod/_generated/models/question_pipeline.py b/src/lightningrod/_generated/models/question_pipeline.py index 9d7cdf9..bab0bf7 100644 --- a/src/lightningrod/_generated/models/question_pipeline.py +++ b/src/lightningrod/_generated/models/question_pipeline.py @@ -10,7 +10,6 @@ if TYPE_CHECKING: from ..models.big_query_seed_generator import BigQuerySeedGenerator - from ..models.csv_seed_generator import CsvSeedGenerator from ..models.file_set_context_generator import FileSetContextGenerator from ..models.file_set_query_seed_generator import FileSetQuerySeedGenerator from ..models.file_set_rag_labeler import FileSetRAGLabeler @@ -26,7 +25,6 @@ from ..models.rollout_generator import RolloutGenerator from ..models.rollout_scorer import RolloutScorer from ..models.template_question_generator import TemplateQuestionGenerator - from ..models.topic_tree_seed_generator import TopicTreeSeedGenerator from ..models.web_search_labeler import WebSearchLabeler @@ -39,9 +37,8 @@ class QuestionPipeline: Attributes: config_type (Literal['QUESTION_PIPELINE'] | Unset): 
Type of transform configuration Default: 'QUESTION_PIPELINE'. - seed_generator (BigQuerySeedGenerator | CsvSeedGenerator | FileSetQuerySeedGenerator | FileSetSeedGenerator | - GdeltSeedGenerator | MockTransformConfig | NewsSeedGenerator | None | TopicTreeSeedGenerator | Unset): - Configuration for seed generation + seed_generator (BigQuerySeedGenerator | FileSetQuerySeedGenerator | FileSetSeedGenerator | GdeltSeedGenerator | + MockTransformConfig | NewsSeedGenerator | None | Unset): Configuration for seed generation question_generator (ForwardLookingQuestionGenerator | MockTransformConfig | None | QuestionAndLabelGenerator | QuestionGenerator | TemplateQuestionGenerator | Unset): Configuration for question generation labeler (FileSetRAGLabeler | MockTransformConfig | None | Unset | WebSearchLabeler): Configuration for labeling. @@ -59,14 +56,12 @@ class QuestionPipeline: config_type: Literal["QUESTION_PIPELINE"] | Unset = "QUESTION_PIPELINE" seed_generator: ( BigQuerySeedGenerator - | CsvSeedGenerator | FileSetQuerySeedGenerator | FileSetSeedGenerator | GdeltSeedGenerator | MockTransformConfig | NewsSeedGenerator | None - | TopicTreeSeedGenerator | Unset ) = UNSET question_generator: ( @@ -89,7 +84,6 @@ class QuestionPipeline: def to_dict(self) -> dict[str, Any]: from ..models.big_query_seed_generator import BigQuerySeedGenerator - from ..models.csv_seed_generator import CsvSeedGenerator from ..models.file_set_context_generator import FileSetContextGenerator from ..models.file_set_query_seed_generator import FileSetQuerySeedGenerator from ..models.file_set_rag_labeler import FileSetRAGLabeler @@ -105,7 +99,6 @@ def to_dict(self) -> dict[str, Any]: from ..models.rollout_generator import RolloutGenerator from ..models.rollout_scorer import RolloutScorer from ..models.template_question_generator import TemplateQuestionGenerator - from ..models.topic_tree_seed_generator import TopicTreeSeedGenerator from ..models.web_search_labeler import WebSearchLabeler config_type = 
self.config_type @@ -123,10 +116,6 @@ def to_dict(self) -> dict[str, Any]: seed_generator = self.seed_generator.to_dict() elif isinstance(self.seed_generator, FileSetQuerySeedGenerator): seed_generator = self.seed_generator.to_dict() - elif isinstance(self.seed_generator, TopicTreeSeedGenerator): - seed_generator = self.seed_generator.to_dict() - elif isinstance(self.seed_generator, CsvSeedGenerator): - seed_generator = self.seed_generator.to_dict() elif isinstance(self.seed_generator, MockTransformConfig): seed_generator = self.seed_generator.to_dict() else: @@ -234,7 +223,6 @@ def to_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: from ..models.big_query_seed_generator import BigQuerySeedGenerator - from ..models.csv_seed_generator import CsvSeedGenerator from ..models.file_set_context_generator import FileSetContextGenerator from ..models.file_set_query_seed_generator import FileSetQuerySeedGenerator from ..models.file_set_rag_labeler import FileSetRAGLabeler @@ -250,7 +238,6 @@ def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: from ..models.rollout_generator import RolloutGenerator from ..models.rollout_scorer import RolloutScorer from ..models.template_question_generator import TemplateQuestionGenerator - from ..models.topic_tree_seed_generator import TopicTreeSeedGenerator from ..models.web_search_labeler import WebSearchLabeler d = dict(src_dict) @@ -262,14 +249,12 @@ def _parse_seed_generator( data: object, ) -> ( BigQuerySeedGenerator - | CsvSeedGenerator | FileSetQuerySeedGenerator | FileSetSeedGenerator | GdeltSeedGenerator | MockTransformConfig | NewsSeedGenerator | None - | TopicTreeSeedGenerator | Unset ): if data is None: @@ -319,37 +304,19 @@ def _parse_seed_generator( try: if not isinstance(data, dict): raise TypeError() - seed_generator_type_0_type_5 = TopicTreeSeedGenerator.from_dict(data) + seed_generator_type_0_type_5 = MockTransformConfig.from_dict(data) return 
seed_generator_type_0_type_5 except (TypeError, ValueError, AttributeError, KeyError): pass - try: - if not isinstance(data, dict): - raise TypeError() - seed_generator_type_0_type_6 = CsvSeedGenerator.from_dict(data) - - return seed_generator_type_0_type_6 - except (TypeError, ValueError, AttributeError, KeyError): - pass - try: - if not isinstance(data, dict): - raise TypeError() - seed_generator_type_0_type_7 = MockTransformConfig.from_dict(data) - - return seed_generator_type_0_type_7 - except (TypeError, ValueError, AttributeError, KeyError): - pass return cast( BigQuerySeedGenerator - | CsvSeedGenerator | FileSetQuerySeedGenerator | FileSetSeedGenerator | GdeltSeedGenerator | MockTransformConfig | NewsSeedGenerator | None - | TopicTreeSeedGenerator | Unset, data, ) diff --git a/src/lightningrod/_generated/models/topic_tree_seed_generator.py b/src/lightningrod/_generated/models/topic_tree_seed_generator.py deleted file mode 100644 index 29de29a..0000000 --- a/src/lightningrod/_generated/models/topic_tree_seed_generator.py +++ /dev/null @@ -1,149 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping -from typing import Any, Literal, TypeVar, cast - -from attrs import define as _attrs_define -from attrs import field as _attrs_field - -from ..types import UNSET, Unset - -T = TypeVar("T", bound="TopicTreeSeedGenerator") - - -@_attrs_define -class TopicTreeSeedGenerator: - """Generates diverse seeds by recursively decomposing broad topics into specific subtopics. - - Given one or more root topics, uses an LLM to break each into `tree_degree` subtopics, - then repeats `tree_depth` levels deep. The leaf paths become seeds for downstream transforms. - This produces tree_degree^tree_depth seeds per root topic, each more specific than the root. 
- - Example: - TopicTreeSeedGenerator(topic="AI Regulation", tree_depth=2, tree_degree=4) - # Produces 16 specific seeds like "AI Regulation → Healthcare → FDA approval of - # diagnostic algorithms" that feed into QuestionGenerator or other downstream transforms. - - Attributes: - topic (list[str] | str): Root topic(s) to decompose into subtopic trees - config_type (Literal['TOPIC_TREE_SEED_GENERATOR'] | Unset): Type of transform configuration Default: - 'TOPIC_TREE_SEED_GENERATOR'. - tree_depth (int | Unset): Levels of recursive expansion Default: 2. - tree_degree (int | Unset): Subtopics generated per node Default: 5. - model_name (str | Unset): LLM model for subtopic generation Default: 'google/gemini-3-flash-preview'. - model_system_prompt (None | str | Unset): Optional system prompt for the LLM - """ - - topic: list[str] | str - config_type: Literal["TOPIC_TREE_SEED_GENERATOR"] | Unset = "TOPIC_TREE_SEED_GENERATOR" - tree_depth: int | Unset = 2 - tree_degree: int | Unset = 5 - model_name: str | Unset = "google/gemini-3-flash-preview" - model_system_prompt: None | str | Unset = UNSET - additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) - - def to_dict(self) -> dict[str, Any]: - topic: list[str] | str - if isinstance(self.topic, list): - topic = self.topic - - else: - topic = self.topic - - config_type = self.config_type - - tree_depth = self.tree_depth - - tree_degree = self.tree_degree - - model_name = self.model_name - - model_system_prompt: None | str | Unset - if isinstance(self.model_system_prompt, Unset): - model_system_prompt = UNSET - else: - model_system_prompt = self.model_system_prompt - - field_dict: dict[str, Any] = {} - field_dict.update(self.additional_properties) - field_dict.update( - { - "topic": topic, - } - ) - if config_type is not UNSET: - field_dict["config_type"] = config_type - if tree_depth is not UNSET: - field_dict["tree_depth"] = tree_depth - if tree_degree is not UNSET: - field_dict["tree_degree"] = 
tree_degree - if model_name is not UNSET: - field_dict["model_name"] = model_name - if model_system_prompt is not UNSET: - field_dict["model_system_prompt"] = model_system_prompt - - return field_dict - - @classmethod - def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: - d = dict(src_dict) - - def _parse_topic(data: object) -> list[str] | str: - try: - if not isinstance(data, list): - raise TypeError() - topic_type_1 = cast(list[str], data) - - return topic_type_1 - except (TypeError, ValueError, AttributeError, KeyError): - pass - return cast(list[str] | str, data) - - topic = _parse_topic(d.pop("topic")) - - config_type = cast(Literal["TOPIC_TREE_SEED_GENERATOR"] | Unset, d.pop("config_type", UNSET)) - if config_type != "TOPIC_TREE_SEED_GENERATOR" and not isinstance(config_type, Unset): - raise ValueError(f"config_type must match const 'TOPIC_TREE_SEED_GENERATOR', got '{config_type}'") - - tree_depth = d.pop("tree_depth", UNSET) - - tree_degree = d.pop("tree_degree", UNSET) - - model_name = d.pop("model_name", UNSET) - - def _parse_model_system_prompt(data: object) -> None | str | Unset: - if data is None: - return data - if isinstance(data, Unset): - return data - return cast(None | str | Unset, data) - - model_system_prompt = _parse_model_system_prompt(d.pop("model_system_prompt", UNSET)) - - topic_tree_seed_generator = cls( - topic=topic, - config_type=config_type, - tree_depth=tree_depth, - tree_degree=tree_degree, - model_name=model_name, - model_system_prompt=model_system_prompt, - ) - - topic_tree_seed_generator.additional_properties = d - return topic_tree_seed_generator - - @property - def additional_keys(self) -> list[str]: - return list(self.additional_properties.keys()) - - def __getitem__(self, key: str) -> Any: - return self.additional_properties[key] - - def __setitem__(self, key: str, value: Any) -> None: - self.additional_properties[key] = value - - def __delitem__(self, key: str) -> None: - del self.additional_properties[key] - - 
def __contains__(self, key: str) -> bool: - return key in self.additional_properties diff --git a/src/lightningrod/_generated/models/training_job.py b/src/lightningrod/_generated/models/training_job.py index c1a1be1..8578147 100644 --- a/src/lightningrod/_generated/models/training_job.py +++ b/src/lightningrod/_generated/models/training_job.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: from ..models.training_config import TrainingConfig + from ..models.training_job_model_id_by_step_type_0 import TrainingJobModelIdByStepType0 T = TypeVar("T", bound="TrainingJob") @@ -30,11 +31,11 @@ class TrainingJob: updated_at (datetime.datetime): name (None | str | Unset): model_id (None | str | Unset): + model_id_by_step (None | TrainingJobModelIdByStepType0 | Unset): reward_history (list[float] | None | Unset): current_step (int | None | Unset): total_steps (int | None | Unset): cost_dollars (float | None | Unset): - dataset_hf_repo (None | str | Unset): error_message (None | str | Unset): """ @@ -46,15 +47,17 @@ class TrainingJob: updated_at: datetime.datetime name: None | str | Unset = UNSET model_id: None | str | Unset = UNSET + model_id_by_step: None | TrainingJobModelIdByStepType0 | Unset = UNSET reward_history: list[float] | None | Unset = UNSET current_step: int | None | Unset = UNSET total_steps: int | None | Unset = UNSET cost_dollars: float | None | Unset = UNSET - dataset_hf_repo: None | str | Unset = UNSET error_message: None | str | Unset = UNSET additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) def to_dict(self) -> dict[str, Any]: + from ..models.training_job_model_id_by_step_type_0 import TrainingJobModelIdByStepType0 + id = self.id organization_id = self.organization_id @@ -79,6 +82,14 @@ def to_dict(self) -> dict[str, Any]: else: model_id = self.model_id + model_id_by_step: dict[str, Any] | None | Unset + if isinstance(self.model_id_by_step, Unset): + model_id_by_step = UNSET + elif isinstance(self.model_id_by_step, 
TrainingJobModelIdByStepType0): + model_id_by_step = self.model_id_by_step.to_dict() + else: + model_id_by_step = self.model_id_by_step + reward_history: list[float] | None | Unset if isinstance(self.reward_history, Unset): reward_history = UNSET @@ -106,12 +117,6 @@ def to_dict(self) -> dict[str, Any]: else: cost_dollars = self.cost_dollars - dataset_hf_repo: None | str | Unset - if isinstance(self.dataset_hf_repo, Unset): - dataset_hf_repo = UNSET - else: - dataset_hf_repo = self.dataset_hf_repo - error_message: None | str | Unset if isinstance(self.error_message, Unset): error_message = UNSET @@ -134,6 +139,8 @@ def to_dict(self) -> dict[str, Any]: field_dict["name"] = name if model_id is not UNSET: field_dict["model_id"] = model_id + if model_id_by_step is not UNSET: + field_dict["model_id_by_step"] = model_id_by_step if reward_history is not UNSET: field_dict["reward_history"] = reward_history if current_step is not UNSET: @@ -142,8 +149,6 @@ def to_dict(self) -> dict[str, Any]: field_dict["total_steps"] = total_steps if cost_dollars is not UNSET: field_dict["cost_dollars"] = cost_dollars - if dataset_hf_repo is not UNSET: - field_dict["dataset_hf_repo"] = dataset_hf_repo if error_message is not UNSET: field_dict["error_message"] = error_message @@ -152,6 +157,7 @@ def to_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: from ..models.training_config import TrainingConfig + from ..models.training_job_model_id_by_step_type_0 import TrainingJobModelIdByStepType0 d = dict(src_dict) id = d.pop("id") @@ -184,6 +190,23 @@ def _parse_model_id(data: object) -> None | str | Unset: model_id = _parse_model_id(d.pop("model_id", UNSET)) + def _parse_model_id_by_step(data: object) -> None | TrainingJobModelIdByStepType0 | Unset: + if data is None: + return data + if isinstance(data, Unset): + return data + try: + if not isinstance(data, dict): + raise TypeError() + model_id_by_step_type_0 = 
TrainingJobModelIdByStepType0.from_dict(data) + + return model_id_by_step_type_0 + except (TypeError, ValueError, AttributeError, KeyError): + pass + return cast(None | TrainingJobModelIdByStepType0 | Unset, data) + + model_id_by_step = _parse_model_id_by_step(d.pop("model_id_by_step", UNSET)) + def _parse_reward_history(data: object) -> list[float] | None | Unset: if data is None: return data @@ -228,15 +251,6 @@ def _parse_cost_dollars(data: object) -> float | None | Unset: cost_dollars = _parse_cost_dollars(d.pop("cost_dollars", UNSET)) - def _parse_dataset_hf_repo(data: object) -> None | str | Unset: - if data is None: - return data - if isinstance(data, Unset): - return data - return cast(None | str | Unset, data) - - dataset_hf_repo = _parse_dataset_hf_repo(d.pop("dataset_hf_repo", UNSET)) - def _parse_error_message(data: object) -> None | str | Unset: if data is None: return data @@ -255,11 +269,11 @@ def _parse_error_message(data: object) -> None | str | Unset: updated_at=updated_at, name=name, model_id=model_id, + model_id_by_step=model_id_by_step, reward_history=reward_history, current_step=current_step, total_steps=total_steps, cost_dollars=cost_dollars, - dataset_hf_repo=dataset_hf_repo, error_message=error_message, ) diff --git a/src/lightningrod/_generated/models/validation_error.py b/src/lightningrod/_generated/models/validation_error.py index cb0708f..6df5a45 100644 --- a/src/lightningrod/_generated/models/validation_error.py +++ b/src/lightningrod/_generated/models/validation_error.py @@ -1,11 +1,17 @@ from __future__ import annotations from collections.abc import Mapping -from typing import Any, TypeVar, cast +from typing import TYPE_CHECKING, Any, TypeVar, cast from attrs import define as _attrs_define from attrs import field as _attrs_field +from ..types import UNSET, Unset + +if TYPE_CHECKING: + from ..models.validation_error_context import ValidationErrorContext + + T = TypeVar("T", bound="ValidationError") @@ -16,11 +22,15 @@ class 
ValidationError: loc (list[int | str]): msg (str): type_ (str): + input_ (Any | Unset): + ctx (ValidationErrorContext | Unset): """ loc: list[int | str] msg: str type_: str + input_: Any | Unset = UNSET + ctx: ValidationErrorContext | Unset = UNSET additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) def to_dict(self) -> dict[str, Any]: @@ -34,6 +44,12 @@ def to_dict(self) -> dict[str, Any]: type_ = self.type_ + input_ = self.input_ + + ctx: dict[str, Any] | Unset = UNSET + if not isinstance(self.ctx, Unset): + ctx = self.ctx.to_dict() + field_dict: dict[str, Any] = {} field_dict.update(self.additional_properties) field_dict.update( @@ -43,11 +59,17 @@ def to_dict(self) -> dict[str, Any]: "type": type_, } ) + if input_ is not UNSET: + field_dict["input"] = input_ + if ctx is not UNSET: + field_dict["ctx"] = ctx return field_dict @classmethod def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: + from ..models.validation_error_context import ValidationErrorContext + d = dict(src_dict) loc = [] _loc = d.pop("loc") @@ -64,10 +86,21 @@ def _parse_loc_item(data: object) -> int | str: type_ = d.pop("type") + input_ = d.pop("input", UNSET) + + _ctx = d.pop("ctx", UNSET) + ctx: ValidationErrorContext | Unset + if isinstance(_ctx, Unset): + ctx = UNSET + else: + ctx = ValidationErrorContext.from_dict(_ctx) + validation_error = cls( loc=loc, msg=msg, type_=type_, + input_=input_, + ctx=ctx, ) validation_error.additional_properties = d