Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class ProviderType(str, Enum):

GOOGLE = "google"
OPENROUTER = "openrouter"
STABILITY = "stability"


class ParallelismMode(str, Enum):
Expand Down Expand Up @@ -100,6 +101,10 @@ class VerifiedModels:
"gemini-3-pro-image-preview", # Nano Banana Pro - 2K/4K, best quality
]

STABILITY_IMAGE = [
"stability-ai/sd3.5-large", # SD3.5 Large - distillation-permissive
]

# OpenRouter API (via openrouter.ai)
# These work with OPENROUTER_API_KEY
OPENROUTER_TEXT = [
Expand Down Expand Up @@ -147,7 +152,7 @@ def is_verified_text_model(cls, model: str) -> bool:
@classmethod
def is_verified_image_model(cls, model: str) -> bool:
"""Check if an image model is verified."""
return model in cls.GOOGLE_IMAGE
return model in cls.GOOGLE_IMAGE or model in cls.STABILITY_IMAGE

@classmethod
def get_safe_text_model(cls, provider: "ProviderType") -> str:
Expand Down Expand Up @@ -177,6 +182,9 @@ def is_verified_or_available(cls, model: str, provider: "ProviderType") -> bool:
if provider == ProviderType.GOOGLE:
if model in cls.GOOGLE_TEXT or model in cls.GOOGLE_IMAGE:
return True
elif provider == ProviderType.STABILITY:
if model in cls.STABILITY_IMAGE:
return True
else:
if model in cls.OPENROUTER_TEXT:
return True
Expand Down Expand Up @@ -252,15 +260,15 @@ def is_verified_or_available(cls, model: str, provider: "ProviderType") -> bool:
},
QualityPreset.FREE_DISTILLABLE: {
"name": "Free Distillable",
"description": "Free models with distillation rights — $0 cost, text-only (no image gen)",
"description": "Free distillable models — text via OpenRouter, images via Stability AI SD3.5",
"text_model": "openrouter/hunter-alpha",
"judge_model": "openrouter/healer-alpha",
"image_model": None, # No free distillable image models available yet
"image_provider": None,
"image_model": "stability-ai/sd3.5-large",
"image_provider": ProviderType.STABILITY,
"text_provider": ProviderType.OPENROUTER,
"max_tokens": 4096,
"thinking_level": None,
"image_supported": False, # Text-only mode
"image_supported": True,
},
}

Expand All @@ -286,6 +294,10 @@ def is_verified_or_available(cls, model: str, provider: "ProviderType") -> bool:
"rpm": 30, # Conservative default (varies by model)
"max_concurrent": 5, # Safe concurrent calls
},
ProviderType.STABILITY: {
"rpm": 10, # Stability AI conservative default
"max_concurrent": 3, # Image gen is resource-heavy
},
}

# Tier-based concurrent limits for each parallelism mode
Expand Down Expand Up @@ -377,6 +389,10 @@ class Settings(BaseSettings):
default=None,
description="OpenRouter API key",
)
STABILITY_API_KEY: str | None = Field(
default=None,
description="Stability AI API key for SD3.5 image generation",
)

# Provider Selection
PRIMARY_PROVIDER: ProviderType = Field(
Expand Down Expand Up @@ -527,13 +543,13 @@ def validate_providers(self) -> "Settings":
"""
# Soft validation - just track if any providers are available
# The app will start but providers will be marked as unavailable
self._has_any_provider = bool(self.GOOGLE_API_KEY or self.OPENROUTER_API_KEY)
self._has_any_provider = bool(self.GOOGLE_API_KEY or self.OPENROUTER_API_KEY or self.STABILITY_API_KEY)
return self

@property
def has_any_provider(self) -> bool:
"""Check if any provider API key is configured."""
return bool(self.GOOGLE_API_KEY or self.OPENROUTER_API_KEY)
return bool(self.GOOGLE_API_KEY or self.OPENROUTER_API_KEY or self.STABILITY_API_KEY)

@property
def is_production(self) -> bool:
Expand Down Expand Up @@ -574,6 +590,8 @@ def has_provider(self, provider: ProviderType) -> bool:
return bool(self.GOOGLE_API_KEY)
elif provider == ProviderType.OPENROUTER:
return bool(self.OPENROUTER_API_KEY)
elif provider == ProviderType.STABILITY:
return bool(self.STABILITY_API_KEY)
return False

def get_api_key(self, provider: ProviderType) -> str:
Expand All @@ -596,6 +614,10 @@ def get_api_key(self, provider: ProviderType) -> str:
if not self.OPENROUTER_API_KEY:
raise ValueError("OPENROUTER_API_KEY not configured")
return self.OPENROUTER_API_KEY
elif provider == ProviderType.STABILITY:
if not self.STABILITY_API_KEY:
raise ValueError("STABILITY_API_KEY not configured")
return self.STABILITY_API_KEY
raise ValueError(f"Unknown provider: {provider}")

def get_model_config(self) -> dict[str, Any]:
Expand Down
34 changes: 26 additions & 8 deletions app/core/llm_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
)
from app.core.providers.google import GoogleProvider
from app.core.providers.openrouter import OpenRouterProvider
from app.core.providers.stability import StabilityProvider
from app.core.rate_limiter import acquire_rate_limit, get_tier_from_model

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -281,6 +282,12 @@ def _init_providers(self, settings: Any) -> None:
)
logger.info("Initialized OpenRouter provider")

if settings.has_provider(ProviderType.STABILITY):
self.providers[ProviderType.STABILITY] = StabilityProvider(
api_key=settings.STABILITY_API_KEY
)
logger.info("Initialized Stability AI provider")

def _get_provider(self, provider_type: ProviderType) -> LLMProvider:
"""Get provider instance by type.

Expand Down Expand Up @@ -941,11 +948,23 @@ async def generate_image(
is_permissive = bool(
self._model_policy and self._model_policy.lower() == "permissive"
)
if is_permissive and ProviderType.OPENROUTER in self.providers:
# Permissive mode: always use OpenRouter for images (Google-free)
if is_permissive and ProviderType.STABILITY in self.providers:
# Permissive mode: prefer Stability AI for distillation-safe images
image_provider = ProviderType.STABILITY
elif is_permissive and ProviderType.OPENROUTER in self.providers:
# Permissive mode fallback: use OpenRouter (Google-free)
image_provider = ProviderType.OPENROUTER
elif self._preset_config and "image_provider" in self._preset_config:
image_provider = self._preset_config["image_provider"]
preset_image_provider = self._preset_config["image_provider"]
# Use preset provider if available, otherwise fall through
if preset_image_provider and preset_image_provider in self.providers:
image_provider = preset_image_provider
elif ProviderType.GOOGLE in self.providers:
image_provider = ProviderType.GOOGLE
elif ProviderType.OPENROUTER in self.providers:
image_provider = ProviderType.OPENROUTER
else:
image_provider = self.config.primary
elif ProviderType.GOOGLE in self.providers:
image_provider = ProviderType.GOOGLE
elif ProviderType.OPENROUTER in self.providers:
Expand Down Expand Up @@ -973,7 +992,6 @@ async def generate_image(
should_fallback = (
image_provider != ProviderType.OPENROUTER
and ProviderType.OPENROUTER in self.providers
and not is_permissive # Already on OpenRouter in permissive mode
)

if not should_fallback:
Expand All @@ -984,11 +1002,11 @@ async def generate_image(
) from e

# Log appropriately based on error type
image_fallback = get_image_fallback_model()
image_fallback = get_image_fallback_model(permissive_only=is_permissive)
if isinstance(e, QuotaExhaustedError):
logger.warning(
f"Google quota exhausted - immediately falling back to OpenRouter "
f"with {image_fallback}"
f"Quota exhausted on {image_provider.value} - falling back to "
f"OpenRouter with {image_fallback}"
)
else:
logger.warning(
Expand All @@ -998,7 +1016,7 @@ async def generate_image(

try:
fallback_provider = self._get_provider(ProviderType.OPENROUTER)
# Remove Google-specific params that may not work with Flux
# Remove provider-specific params that may not work with fallback
fallback_kwargs = {
k: v for k, v in image_kwargs.items()
if k not in ("image_size",) # Flux uses different params
Expand Down
32 changes: 31 additions & 1 deletion app/core/model_capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class ImageModelType(str, Enum):
GEMINI_NATIVE = "gemini_native" # Gemini with native image gen (Nano Banana)
GEMINI_PRO = "gemini_pro" # Gemini 3 Pro Image (Nano Banana Pro)
IMAGEN = "imagen" # Legacy Imagen API
STABILITY = "stability" # Stability AI SD3.5 REST API


@dataclass
Expand Down Expand Up @@ -95,6 +96,20 @@ class ImageModelConfig:
timeout_multiplier=3.0, # Higher quality takes longer
notes="Preview model, best quality. Supports 1K/2K/4K.",
),
# Stability AI SD3.5 Large - Distillation-permissive
"stability-ai/sd3.5-large": ImageModelConfig(
model_id="stability-ai/sd3.5-large",
model_type=ImageModelType.STABILITY,
response_modalities=[], # Not applicable - uses Stability REST API
supports_image_size=False,
supported_sizes=[],
max_resolution=1024,
supports_aspect_ratio=True,
use_camel_case_params=False, # Stability uses snake_case
fallback_models=["black-forest-labs/flux.2-pro"],
timeout_multiplier=2.0,
notes="Stability AI SD3.5 Large. Distillation-permissive license. Uses REST API.",
),
# Legacy Imagen - Uses different API
"imagen-3.0-generate-002": ImageModelConfig(
model_id="imagen-3.0-generate-002",
Expand Down Expand Up @@ -232,6 +247,19 @@ def is_imagen_model(model_id: str) -> bool:
return config.model_type == ImageModelType.IMAGEN


def is_stability_model(model_id: str) -> bool:
"""Check if model uses Stability AI API.

Args:
model_id: The model identifier.

Returns:
True if model uses Stability AI REST API.
"""
config = get_image_model_config(model_id)
return config.model_type == ImageModelType.STABILITY


def is_gemini_image_model(model_id: str) -> bool:
"""Check if model uses Gemini image generation (generate_content with IMAGE modality).

Expand Down Expand Up @@ -711,7 +739,9 @@ def infer_provider_from_model_id(model_id: str) -> str:
return TEXT_MODEL_REGISTRY[model_id].provider

# Infer from model ID pattern
if "/" in model_id:
if model_id.startswith("stability-ai/"):
return "stability"
elif "/" in model_id:
# OpenRouter format: provider/model-name
return "openrouter"
elif model_id.startswith("gemini-"):
Expand Down
3 changes: 3 additions & 0 deletions app/core/model_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"nousresearch/", # Hermes family (open-weight)
"black-forest-labs/", # FLUX open-weight image models
"openrouter/", # OpenRouter free distillable models (Hunter, Healer)
"stability-ai/", # Stability AI SD3.5 (distillation-permissive)
)

# Google-native model prefixes (always restricted)
Expand All @@ -31,6 +32,8 @@ def derive_model_provider(model_id: str | None) -> str:
if not model_id:
return "unknown"
lower = model_id.lower()
if lower.startswith("stability-ai/"):
return "stability"
if any(lower.startswith(p) for p in GOOGLE_MODEL_PREFIXES):
return "google"
if any(lower.startswith(p) for p in OPENROUTER_PREFIXES):
Expand Down
2 changes: 2 additions & 0 deletions app/core/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# Provider implementations - lazy imports to avoid circular issues
from app.core.providers.google import GoogleProvider
from app.core.providers.openrouter import OpenRouterProvider
from app.core.providers.stability import StabilityProvider

__all__ = [
# Base classes
Expand All @@ -34,4 +35,5 @@
# Implementations
"GoogleProvider",
"OpenRouterProvider",
"StabilityProvider",
]
Loading
Loading