From c294ba89eb7560463b31cb25f686813775f41f95 Mon Sep 17 00:00:00 2001 From: timepointai Date: Mon, 16 Mar 2026 10:08:55 -0600 Subject: [PATCH] feat: add Stability AI SD3.5 provider for permissive image generation Add Stability AI as a new provider for image generation, enabling the permissive/free-distillable pipeline to generate images using SD3.5 Large which allows downstream distillation. - Create StabilityProvider with SD3.5 REST API integration - Register stability-ai/sd3.5-large in image model registry - Add STABILITY to ProviderType enum and Settings - Add stability-ai/ to PERMISSIVE_PREFIXES in model policy - Update FREE_DISTILLABLE preset with SD3.5 image model - Wire Stability provider into LLM router with fallback to OpenRouter --- app/config.py | 36 +++- app/core/llm_router.py | 34 +++- app/core/model_capabilities.py | 32 +++- app/core/model_policy.py | 3 + app/core/providers/__init__.py | 2 + app/core/providers/stability.py | 313 ++++++++++++++++++++++++++++++++ 6 files changed, 404 insertions(+), 16 deletions(-) create mode 100644 app/core/providers/stability.py diff --git a/app/config.py b/app/config.py index 27c289a..c1b76a7 100644 --- a/app/config.py +++ b/app/config.py @@ -30,6 +30,7 @@ class ProviderType(str, Enum): GOOGLE = "google" OPENROUTER = "openrouter" + STABILITY = "stability" class ParallelismMode(str, Enum): @@ -100,6 +101,10 @@ class VerifiedModels: "gemini-3-pro-image-preview", # Nano Banana Pro - 2K/4K, best quality ] + STABILITY_IMAGE = [ + "stability-ai/sd3.5-large", # SD3.5 Large - distillation-permissive + ] + # OpenRouter API (via openrouter.ai) # These work with OPENROUTER_API_KEY OPENROUTER_TEXT = [ @@ -147,7 +152,7 @@ def is_verified_text_model(cls, model: str) -> bool: @classmethod def is_verified_image_model(cls, model: str) -> bool: """Check if an image model is verified.""" - return model in cls.GOOGLE_IMAGE + return model in cls.GOOGLE_IMAGE or model in cls.STABILITY_IMAGE @classmethod def get_safe_text_model(cls, provider: "ProviderType") -> str: @@ -177,6 +182,9 @@ def is_verified_or_available(cls, model: str, provider: "ProviderType") -> bool: if provider == ProviderType.GOOGLE: if model in cls.GOOGLE_TEXT or model in cls.GOOGLE_IMAGE: return True + elif provider == ProviderType.STABILITY: + if model in cls.STABILITY_IMAGE: + return True else: if model in cls.OPENROUTER_TEXT: return True @@ -252,15 +260,15 @@ def is_verified_or_available(cls, model: str, provider: "ProviderType") -> bool: }, QualityPreset.FREE_DISTILLABLE: { "name": "Free Distillable", - "description": "Free models with distillation rights — $0 cost, text-only (no image gen)", + "description": "Free distillable models — text via OpenRouter, images via Stability AI SD3.5", "text_model": "openrouter/hunter-alpha", "judge_model": "openrouter/healer-alpha", - "image_model": None, # No free distillable image models available yet - "image_provider": None, + "image_model": "stability-ai/sd3.5-large", + "image_provider": ProviderType.STABILITY, "text_provider": ProviderType.OPENROUTER, "max_tokens": 4096, "thinking_level": None, - "image_supported": False, # Text-only mode + "image_supported": True, }, } @@ -286,6 +294,10 @@ def is_verified_or_available(cls, model: str, provider: "ProviderType") -> bool: "rpm": 30, # Conservative default (varies by model) "max_concurrent": 5, # Safe concurrent calls }, + ProviderType.STABILITY: { + "rpm": 10, # Stability AI conservative default + "max_concurrent": 3, # Image gen is resource-heavy + }, } # Tier-based concurrent limits for each parallelism mode @@ -377,6 +389,10 @@ class Settings(BaseSettings): default=None, description="OpenRouter API key", ) + STABILITY_API_KEY: str | None = Field( + default=None, + description="Stability AI API key for SD3.5 image generation", + ) # Provider Selection PRIMARY_PROVIDER: ProviderType = Field( @@ -527,13 +543,13 @@ def validate_providers(self) -> "Settings": """ # Soft validation - just track if any providers are available # The app will start but providers will be marked as unavailable - self._has_any_provider = bool(self.GOOGLE_API_KEY or self.OPENROUTER_API_KEY) + self._has_any_provider = bool(self.GOOGLE_API_KEY or self.OPENROUTER_API_KEY or self.STABILITY_API_KEY) return self @property def has_any_provider(self) -> bool: """Check if any provider API key is configured.""" - return bool(self.GOOGLE_API_KEY or self.OPENROUTER_API_KEY) + return bool(self.GOOGLE_API_KEY or self.OPENROUTER_API_KEY or self.STABILITY_API_KEY) @property def is_production(self) -> bool: @@ -574,6 +590,8 @@ def has_provider(self, provider: ProviderType) -> bool: return bool(self.GOOGLE_API_KEY) elif provider == ProviderType.OPENROUTER: return bool(self.OPENROUTER_API_KEY) + elif provider == ProviderType.STABILITY: + return bool(self.STABILITY_API_KEY) return False def get_api_key(self, provider: ProviderType) -> str: @@ -596,6 +614,10 @@ def get_api_key(self, provider: ProviderType) -> str: if not self.OPENROUTER_API_KEY: raise ValueError("OPENROUTER_API_KEY not configured") return self.OPENROUTER_API_KEY + elif provider == ProviderType.STABILITY: + if not self.STABILITY_API_KEY: + raise ValueError("STABILITY_API_KEY not configured") + return self.STABILITY_API_KEY raise ValueError(f"Unknown provider: {provider}") def get_model_config(self) -> dict[str, Any]: diff --git a/app/core/llm_router.py b/app/core/llm_router.py index 5c52d8e..78d4cba 100644 --- a/app/core/llm_router.py +++ b/app/core/llm_router.py @@ -56,6 +56,7 @@ ) from app.core.providers.google import GoogleProvider from app.core.providers.openrouter import OpenRouterProvider +from app.core.providers.stability import StabilityProvider from app.core.rate_limiter import acquire_rate_limit, get_tier_from_model logger = logging.getLogger(__name__) @@ -281,6 +282,12 @@ def _init_providers(self, settings: Any) -> None: ) logger.info("Initialized OpenRouter provider") + if settings.has_provider(ProviderType.STABILITY): + self.providers[ProviderType.STABILITY] = StabilityProvider( + api_key=settings.STABILITY_API_KEY + ) + logger.info("Initialized Stability AI provider") + def _get_provider(self, provider_type: ProviderType) -> LLMProvider: """Get provider instance by type. @@ -941,11 +948,23 @@ async def generate_image( is_permissive = bool( self._model_policy and self._model_policy.lower() == "permissive" ) - if is_permissive and ProviderType.OPENROUTER in self.providers: - # Permissive mode: always use OpenRouter for images (Google-free) + if is_permissive and ProviderType.STABILITY in self.providers: + # Permissive mode: prefer Stability AI for distillation-safe images + image_provider = ProviderType.STABILITY + elif is_permissive and ProviderType.OPENROUTER in self.providers: + # Permissive mode fallback: use OpenRouter (Google-free) image_provider = ProviderType.OPENROUTER elif self._preset_config and "image_provider" in self._preset_config: - image_provider = self._preset_config["image_provider"] + preset_image_provider = self._preset_config["image_provider"] + # Use preset provider if available, otherwise fall through + if preset_image_provider and preset_image_provider in self.providers: + image_provider = preset_image_provider + elif ProviderType.GOOGLE in self.providers: + image_provider = ProviderType.GOOGLE + elif ProviderType.OPENROUTER in self.providers: + image_provider = ProviderType.OPENROUTER + else: + image_provider = self.config.primary elif ProviderType.GOOGLE in self.providers: image_provider = ProviderType.GOOGLE elif ProviderType.OPENROUTER in self.providers: @@ -973,7 +992,6 @@ async def generate_image( should_fallback = ( image_provider != ProviderType.OPENROUTER and ProviderType.OPENROUTER in self.providers - and not is_permissive # Already on OpenRouter in permissive mode ) if not should_fallback: @@ -984,11 +1002,11 @@ async def generate_image( ) from e # Log appropriately based on error type - image_fallback = get_image_fallback_model() + image_fallback = get_image_fallback_model(permissive_only=is_permissive) if isinstance(e, QuotaExhaustedError): logger.warning( - f"Google quota exhausted - immediately falling back to OpenRouter " - f"with {image_fallback}" + f"Quota exhausted on {image_provider.value} - falling back to " + f"OpenRouter with {image_fallback}" ) else: logger.warning( @@ -998,7 +1016,7 @@ async def generate_image( try: fallback_provider = self._get_provider(ProviderType.OPENROUTER) - # Remove Google-specific params that may not work with Flux + # Remove provider-specific params that may not work with fallback fallback_kwargs = { k: v for k, v in image_kwargs.items() if k not in ("image_size",) # Flux uses different params diff --git a/app/core/model_capabilities.py b/app/core/model_capabilities.py index 3fb1cc2..38ef9ff 100644 --- a/app/core/model_capabilities.py +++ b/app/core/model_capabilities.py @@ -25,6 +25,7 @@ class ImageModelType(str, Enum): GEMINI_NATIVE = "gemini_native" # Gemini with native image gen (Nano Banana) GEMINI_PRO = "gemini_pro" # Gemini 3 Pro Image (Nano Banana Pro) IMAGEN = "imagen" # Legacy Imagen API + STABILITY = "stability" # Stability AI SD3.5 REST API @dataclass @@ -95,6 +96,20 @@ class ImageModelConfig: timeout_multiplier=3.0, # Higher quality takes longer notes="Preview model, best quality. Supports 1K/2K/4K.", ), + # Stability AI SD3.5 Large - Distillation-permissive + "stability-ai/sd3.5-large": ImageModelConfig( + model_id="stability-ai/sd3.5-large", + model_type=ImageModelType.STABILITY, + response_modalities=[], # Not applicable - uses Stability REST API + supports_image_size=False, + supported_sizes=[], + max_resolution=1024, + supports_aspect_ratio=True, + use_camel_case_params=False, # Stability uses snake_case + fallback_models=["black-forest-labs/flux.2-pro"], + timeout_multiplier=2.0, + notes="Stability AI SD3.5 Large. Distillation-permissive license. Uses REST API.", + ), # Legacy Imagen - Uses different API "imagen-3.0-generate-002": ImageModelConfig( model_id="imagen-3.0-generate-002", @@ -232,6 +247,19 @@ def is_imagen_model(model_id: str) -> bool: return config.model_type == ImageModelType.IMAGEN +def is_stability_model(model_id: str) -> bool: + """Check if model uses Stability AI API. + + Args: + model_id: The model identifier. + + Returns: + True if model uses Stability AI REST API. + """ + config = get_image_model_config(model_id) + return config.model_type == ImageModelType.STABILITY + + def is_gemini_image_model(model_id: str) -> bool: """Check if model uses Gemini image generation (generate_content with IMAGE modality). @@ -711,7 +739,9 @@ def infer_provider_from_model_id(model_id: str) -> str: return TEXT_MODEL_REGISTRY[model_id].provider # Infer from model ID pattern - if "/" in model_id: + if model_id.startswith("stability-ai/"): + return "stability" + elif "/" in model_id: # OpenRouter format: provider/model-name return "openrouter" elif model_id.startswith("gemini-"): diff --git a/app/core/model_policy.py b/app/core/model_policy.py index 4135983..678661a 100644 --- a/app/core/model_policy.py +++ b/app/core/model_policy.py @@ -17,6 +17,7 @@ "nousresearch/", # Hermes family (open-weight) "black-forest-labs/", # FLUX open-weight image models "openrouter/", # OpenRouter free distillable models (Hunter, Healer) + "stability-ai/", # Stability AI SD3.5 (distillation-permissive) ) # Google-native model prefixes (always restricted) @@ -31,6 +32,8 @@ def derive_model_provider(model_id: str | None) -> str: if not model_id: return "unknown" lower = model_id.lower() + if lower.startswith("stability-ai/"): + return "stability" if any(lower.startswith(p) for p in GOOGLE_MODEL_PREFIXES): return "google" if any(lower.startswith(p) for p in OPENROUTER_PREFIXES): diff --git a/app/core/providers/__init__.py b/app/core/providers/__init__.py index e78eea7..8a662b7 100644 --- a/app/core/providers/__init__.py +++ b/app/core/providers/__init__.py @@ -19,6 +19,7 @@ # Provider implementations - lazy imports to avoid circular issues from app.core.providers.google import GoogleProvider from app.core.providers.openrouter import OpenRouterProvider +from app.core.providers.stability import StabilityProvider __all__ = [ # Base classes @@ -34,4 +35,5 @@ # Implementations "GoogleProvider", "OpenRouterProvider", + "StabilityProvider", ] diff --git a/app/core/providers/stability.py b/app/core/providers/stability.py new file mode 100644 index 0000000..dafd107 --- /dev/null +++ b/app/core/providers/stability.py @@ -0,0 +1,313 @@ +"""Stability AI REST API provider implementation. + +This module provides integration with Stability AI's image generation models +via the REST API. Supports SD3.5 Large for permissive/distillable image generation. + +Stability AI API docs: https://platform.stability.ai/docs/api-reference + +Examples: + >>> from app.core.providers.stability import StabilityProvider + >>> provider = StabilityProvider(api_key="sk-...") + >>> response = await provider.generate_image( + ... prompt="A sunset over mountains", + ... model="stability-ai/sd3.5-large" + ... ) + +Tests: + - tests/unit/test_providers.py::test_stability_provider_init + - tests/unit/test_providers.py::test_stability_provider_generate_image +""" + +import base64 +import logging +import time +from typing import Any, TypeVar + +import httpx +from pydantic import BaseModel + +from app.config import ProviderType +from app.core.providers.base import ( + AuthenticationError, + LLMProvider, + LLMResponse, + ProviderError, + RateLimitError, +) + +logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=BaseModel) + +# Stability AI API configuration +STABILITY_API_BASE = "https://api.stability.ai" +STABILITY_SD3_ENDPOINT = f"{STABILITY_API_BASE}/v2beta/stable-image/generate/sd3" + +# Model ID to Stability API model parameter mapping +STABILITY_MODEL_MAP = { + "stability-ai/sd3.5-large": "sd3.5-large", +} + +# Default generation parameters +DEFAULT_ASPECT_RATIO = "16:9" +DEFAULT_OUTPUT_FORMAT = "png" + + +class StabilityProvider(LLMProvider): + """Stability AI REST API provider for image generation. + + Uses the Stability AI REST API for SD3.5 image generation. + SD3.5 Large allows downstream distillation, making it suitable + for the permissive/free-distillable pipeline. + + Attributes: + provider_type: ProviderType.STABILITY + api_key: Stability AI API key + timeout: Request timeout in seconds + + Available Models: + - stability-ai/sd3.5-large: SD3.5 Large (distillation-permissive) + + Examples: + >>> provider = StabilityProvider(api_key="sk-...") + >>> response = await provider.generate_image( + ... prompt="A photorealistic landscape", + ... model="stability-ai/sd3.5-large", + ... aspect_ratio="16:9" + ... ) + """ + + provider_type = ProviderType.STABILITY + + DEFAULT_TIMEOUT = 120 # Image generation can take a while + + def __init__(self, api_key: str, timeout: float = DEFAULT_TIMEOUT) -> None: + """Initialize Stability AI provider. + + Args: + api_key: Stability AI API key (STABILITY_API_KEY). + timeout: Request timeout in seconds (default: 120). + """ + super().__init__(api_key) + self.timeout = timeout + self._client: httpx.AsyncClient | None = None + + @property + def client(self) -> httpx.AsyncClient: + """Get httpx async client (lazy initialization).""" + if self._client is None or self._client.is_closed: + self._client = httpx.AsyncClient( + timeout=self.timeout, + headers={ + "Authorization": f"Bearer {self.api_key}", + "Accept": "application/json", + }, + ) + return self._client + + async def close(self) -> None: + """Close the HTTP client.""" + if self._client is not None and not self._client.is_closed: + await self._client.aclose() + self._client = None + + def _handle_error(self, response: httpx.Response) -> None: + """Convert HTTP errors to provider errors. + + Args: + response: The HTTP response. + + Raises: + AuthenticationError: For 401/403 errors. + RateLimitError: For 429 errors. + ProviderError: For other errors. + """ + if response.status_code in (401, 403): + raise AuthenticationError(ProviderType.STABILITY) + elif response.status_code == 429: + retry_after = response.headers.get("Retry-After") + raise RateLimitError( + ProviderType.STABILITY, + retry_after=int(retry_after) if retry_after else None, + ) + else: + try: + error_data = response.json() + message = error_data.get("message", response.text) + except Exception: + message = response.text + + raise ProviderError( + message=f"Stability AI error: {message}", + provider=ProviderType.STABILITY, + status_code=response.status_code, + retryable=response.status_code >= 500, + ) + + async def call_text( + self, + prompt: str, + model: str, + response_model: type[T] | None = None, + **kwargs: Any, + ) -> LLMResponse[T] | LLMResponse[str]: + """Text generation is not supported by Stability AI. + + Raises: + ProviderError: Always, as Stability AI is image-only. + """ + raise ProviderError( + message="Stability AI does not support text generation", + provider=ProviderType.STABILITY, + retryable=False, + ) + + async def generate_image( + self, + prompt: str, + model: str, + **kwargs: Any, + ) -> LLMResponse[str]: + """Generate an image using Stability AI SD3.5 API. + + Sends a multipart/form-data request to the Stability AI REST API + and returns the generated image as base64-encoded PNG data. + + Args: + prompt: The image generation prompt. + model: Model ID (e.g., "stability-ai/sd3.5-large"). + **kwargs: Additional parameters: + - aspect_ratio: Image aspect ratio ("1:1", "16:9", "3:2", etc.) + - output_format: Output format ("png", "jpeg", "webp") + - negative_prompt: Negative prompt for things to avoid + + Returns: + LLMResponse containing base64-encoded image data. + + Raises: + AuthenticationError: If API key is invalid. + RateLimitError: If rate limit is hit. + ProviderError: If the API call fails. + + Examples: + >>> response = await provider.generate_image( + ... prompt="A sunset over mountains", + ... model="stability-ai/sd3.5-large", + ... aspect_ratio="16:9" + ... ) + """ + start_time = time.perf_counter() + + # Map our model ID to Stability API model parameter + api_model = STABILITY_MODEL_MAP.get(model, "sd3.5-large") + + # Build multipart form data + aspect_ratio = kwargs.get("aspect_ratio", DEFAULT_ASPECT_RATIO) + output_format = kwargs.get("output_format", DEFAULT_OUTPUT_FORMAT) + + form_data = { + "prompt": prompt, + "model": api_model, + "output_format": output_format, + "aspect_ratio": aspect_ratio, + } + + # Add optional negative prompt + if "negative_prompt" in kwargs: + form_data["negative_prompt"] = kwargs["negative_prompt"] + + logger.debug( + f"Calling Stability AI: model={api_model}, " + f"aspect_ratio={aspect_ratio}, format={output_format}" + ) + + try: + response = await self.client.post( + STABILITY_SD3_ENDPOINT, + data=form_data, + ) + + if response.status_code != 200: + self._handle_error(response) + + latency_ms = int((time.perf_counter() - start_time) * 1000) + + # Parse JSON response containing base64 image + data = response.json() + image_b64 = data.get("image") + + if not image_b64: + raise ProviderError( + message=f"No image in Stability AI response: {str(data)[:500]}", + provider=ProviderType.STABILITY, + retryable=False, + ) + + logger.info( + f"Stability AI image generated in {latency_ms}ms " + f"(model={api_model}, format={output_format})" + ) + + return LLMResponse( + content=image_b64, + model=model, + provider=self.provider_type, + latency_ms=latency_ms, + metadata={ + "mime_type": f"image/{output_format}", + "aspect_ratio": aspect_ratio, + }, + ) + + except httpx.HTTPError as e: + logger.error(f"Stability AI HTTP error: {e}") + raise ProviderError( + message=str(e), + provider=ProviderType.STABILITY, + retryable=True, + ) from e + + async def analyze_image( + self, + image: str, + prompt: str, + model: str, + **kwargs: Any, + ) -> LLMResponse[dict[str, Any]]: + """Image analysis is not supported by Stability AI. + + Raises: + ProviderError: Always, as Stability AI is generation-only. + """ + raise ProviderError( + message="Stability AI does not support image analysis", + provider=ProviderType.STABILITY, + retryable=False, + ) + + async def health_check(self) -> bool: + """Check if Stability AI provider is accessible. + + Makes a minimal request to verify the API key is valid. + We use a very short prompt to minimize credit usage. + + Returns: + bool: True if provider is healthy. + """ + try: + # Just check that auth works by making a small request + # We'll catch errors - any non-auth error means the API is reachable + response = await self.client.post( + STABILITY_SD3_ENDPOINT, + data={ + "prompt": "test", + "model": "sd3.5-large", + "output_format": "png", + }, + ) + # 200 = works, 402 = payment required (but API is reachable) + # Only 401/403 means unhealthy + return response.status_code not in (401, 403) + except Exception as e: + logger.warning(f"Stability AI health check failed: {e}") + return False