Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ requires-python = ">=3.8"
dependencies = [
# intentionally loose. perhaps these should be vendored to not collide with user code?
"attrs>=20.1,<24",
"cryptography>=46.0.3",
"fastapi>=0.100,<0.119.0",
"pydantic>=1.9,<3",
"PyYAML",
Expand Down
4 changes: 4 additions & 0 deletions python/cog/__init__.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The existence of both secret and Secret, with very different implementations, is a source of confusion I would really love to avoid 😅 What do you think about a verb prefix like get_secret or load_secret, or rolling it together with non-secret config values, e.g. a cog.getenv that automatically handles encrypted values?

Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

from .base_predictor import BasePredictor
from .mimetypes_ext import install_mime_extensions
from .secret import (
load_secret,
)
from .server.scope import current_scope
from .types import (
AsyncConcatenateIterator,
Expand Down Expand Up @@ -34,5 +37,6 @@
"File",
"Input",
"Path",
"load_secret",
"Secret",
]
88 changes: 88 additions & 0 deletions python/cog/secret.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would love to see COG_ENV_LOCATION and COG_PUBLIC_KEY_LOCATION_ENV_VAR_KEY injected as defaults rather than referencing the module-level variables directly in the __init__.

Is the intended flow for the cog process to generate the RSA key pair and the remote secret_url also has access to the public key? I guess I don't follow the order of operations here, since it seems like the values loaded from .cog/.env need to be encrypted with a key pair that isn't known until just before reading the .cog/.env values 🤔

Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from __future__ import annotations

import base64
import os
from pathlib import Path

import requests
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from dotenv import dotenv_values

__all__ = [
"load_secret",
"default_secret_provider",
]


def load_secret(name: str, secret_provider: SecretProvider | None = None) -> str:
    """Return the secret called *name* from a secret provider.

    Args:
        name: Name of the secret to look up; passed unchanged to the
            provider's ``get_secret``.
        secret_provider: Provider to query. When omitted (or ``None``) the
            module-level ``default_secret_provider`` is used.

    Returns:
        Whatever the provider's ``get_secret`` returns for *name*.
    """
    # Compare against None explicitly so a caller-supplied provider is never
    # discarded just because it happens to evaluate as falsy.
    if secret_provider is None:
        secret_provider = default_secret_provider
    return secret_provider.get_secret(name)


class SecretProvider:
    """Resolve named secrets through a remote endpoint, with a local fallback.

    On construction a fresh 2048-bit RSA key pair is generated. If the
    environment variable named by ``cog_public_key_env_var`` is set, the
    public half is written (PEM, SubjectPublicKeyInfo) to the path it holds.
    Values from the dotenv file at ``cog_env_location`` are kept in
    ``self.env`` and used whenever the remote lookup cannot be completed.

    Attributes:
        env: Mapping loaded from the dotenv file (empty if the file is absent).
        no_public_key: True when no public-key location was configured.
        key: The generated RSA private key used to decrypt remote responses.
        secret_url: Endpoint to POST to; ``None`` until a caller assigns it.
    """

    # Upper bound, in seconds, on the remote request so a stalled endpoint
    # falls through to the local fallback instead of blocking forever.
    request_timeout: float = 10.0

    def __init__(
        self,
        cog_env_location: str = ".cog/.env",
        cog_public_key_env_var: str = "COG_PUBLIC_KEY_LOCATION",
    ) -> None:
        """Generate the key pair, load the local env file, publish the public key.

        Args:
            cog_env_location: Path of the dotenv file holding fallback values.
            cog_public_key_env_var: Name of the environment variable whose
                value is the file path the public key should be written to.
        """
        self.env = {}
        self.no_public_key = False
        self.secret_url: str | None = None
        self.key = rsa.generate_private_key(
            backend=default_backend(),
            public_exponent=65537,
            key_size=2048,
        )

        # Load the fallback values before anything that may return early:
        # the env file must be usable even when no public-key location is
        # configured, otherwise the local fallback in get_secret is dead.
        if os.path.isfile(cog_env_location):
            self.env = dotenv_values(cog_env_location)

        public_key_path_raw = os.getenv(cog_public_key_env_var)
        if not public_key_path_raw:
            self.no_public_key = True
            return

        public_pem = self.key.public_key().public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )
        public_key_path = Path(public_key_path_raw)
        # parents=True so a nested location works; the mode only applies to
        # the final directory when it is newly created.
        public_key_path.parent.mkdir(mode=0o700, parents=True, exist_ok=True)
        # write_bytes creates the file itself, no separate touch() needed.
        public_key_path.write_bytes(public_pem)

    def get_secret(self, secret_name: str) -> str:
        """Return the secret *secret_name*, or ``""`` if it cannot be resolved.

        The remote endpoint is tried first. Any failure there (missing URL,
        missing public key, unset environment variable, HTTP error, timeout,
        decode or decrypt error) falls back to the value loaded from the
        local env file (local development only).

        Args:
            secret_name: Name of the environment variable / env-file key.

        Returns:
            The decrypted secret, the env-file value, or an empty string.
        """
        try:
            return self._fetch_remote_secret(secret_name)
        except Exception:  # pylint: disable=broad-except
            # Deliberate best-effort: never raise to the caller, use the
            # local value (a missing or None entry becomes "").
            return self.env.get(secret_name) or ""

    def _fetch_remote_secret(self, secret_name: str) -> str:
        """POST the raw env value to ``secret_url`` and decrypt the response.

        Raises:
            ValueError: If no URL, no public key, or no env value is available.
            Exception: Anything raised by the HTTP call, base64 decoding,
                RSA decryption or UTF-8 decoding.
        """
        if not self.secret_url:
            raise ValueError("No secret URL passed")
        if self.no_public_key:
            raise ValueError("No public key for encryption")
        raw_secret = os.getenv(secret_name)
        if not raw_secret:
            raise ValueError("No matching secret")

        response = requests.post(
            self.secret_url,
            json={
                "value": raw_secret,
            },
            timeout=self.request_timeout,
        )
        response.raise_for_status()

        # The response body is base64 text wrapping an RSA-OAEP (SHA-256)
        # ciphertext; presumably the service encrypted it to the public key
        # written in __init__ - TODO confirm against the service contract.
        plaintext_bytes = self.key.decrypt(
            base64.b64decode(response.text),
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=None,
            ),
        )
        return plaintext_bytes.decode("utf-8")


# Shared provider that load_secret uses when the caller supplies none.
# NOTE: constructed at import time, so SecretProvider.__init__ (RSA key
# generation and any public-key file write) runs when this module is imported.
default_secret_provider = SecretProvider()
4 changes: 4 additions & 0 deletions python/cog/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from ..json import upload_files
from ..logging import setup_logging
from ..mode import Mode
from ..secret import default_secret_provider
from ..types import PYDANTIC_V2

try:
Expand Down Expand Up @@ -126,12 +127,15 @@ def create_app( # pylint: disable=too-many-arguments,too-many-locals,too-many-s
shutdown_event: Optional[threading.Event], # pylint: disable=redefined-outer-name
app_threads: Optional[int] = None,
upload_url: Optional[str] = None,
secrets_url: Optional[str] = None,
mode: Mode = Mode.PREDICT,
is_build: bool = False,
await_explicit_shutdown: bool = False, # pylint: disable=redefined-outer-name
) -> MyFastAPI:
started_at = datetime.now(tz=timezone.utc)

default_secret_provider.secret_url = secrets_url

@asynccontextmanager
async def lifespan(app: MyFastAPI) -> AsyncGenerator[None, None]:
# Startup code (was previously in @app.on_event("startup"))
Expand Down
5 changes: 4 additions & 1 deletion python/cog/server/runner.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit weird to me that self._secrets_url will only be present if passed into the __init__. Could the design be to always set self._secrets_url and then handle a None value elsewhere?

class Always:
    def __init__(self, bla: str | None):
        self.bla = bla


class Sometimes:
    def __init__(self, bla: str | None):
        if bla is not None:
            self.bla = bla

a0 = Always("ok")
a1 = Always(None)

s0 = Sometimes("ok")
s1 = Sometimes(None)

print("a0", a0.bla)
print("a1", a1.bla)

print("s0", s0.bla)
print("s1", s1.bla)
$ python ~/tmp/always_sometimes.py
a0 ok
a1 None
s0 ok
Traceback (most recent call last):
  File "/Users/me/tmp/always_sometimes.py", line 21, in <module>
    print("s1", s1.bla)
                ^^^^^^
AttributeError: 'Sometimes' object has no attribute 'bla'

Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,10 @@ def result(self) -> T:


class SetupTask(Task[SetupResult]):
def __init__(self, _clock: Optional[Callable[[], datetime]] = None) -> None:
def __init__(
self,
_clock: Optional[Callable[[], datetime]] = None,
) -> None:
log.info("starting setup")
self._clock = _clock
if self._clock is None:
Expand Down
1 change: 1 addition & 0 deletions test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
curl -X GET https://api.coreweave.com/v1beta1/cks/clusters/{id} -H "Content-Type: application/json" -H "Authorization: Bearer {API_ACCESS_TOKEN}"
Loading
Loading