diff --git a/model/real_reference_tokens.json b/model/real_reference_tokens.json new file mode 120000 index 00000000..152f3be9 --- /dev/null +++ b/model/real_reference_tokens.json @@ -0,0 +1 @@ +/opt/gobed/model/real_reference_tokens.json \ No newline at end of file diff --git a/model/tokenizer.json b/model/tokenizer.json new file mode 120000 index 00000000..b2dd43ec --- /dev/null +++ b/model/tokenizer.json @@ -0,0 +1 @@ +/opt/gobed/model/tokenizer.json \ No newline at end of file diff --git a/src/alpaca_account_lock.py b/src/alpaca_account_lock.py new file mode 100644 index 00000000..5d16e602 --- /dev/null +++ b/src/alpaca_account_lock.py @@ -0,0 +1,129 @@ +"""Alpaca account locking and live-trading guardrails. + +Provides: +- require_explicit_live_trading_enable: env-var gated live-trading guardrail. +- acquire_alpaca_account_lock: acquire a per-account advisory file lock so that + only one writer bot can trade a given Alpaca account at a time. +""" + +from __future__ import annotations + +import fcntl +import logging +import os +from pathlib import Path + +logger = logging.getLogger(__name__) + +_TRUTHY_ENV_VALUES = {"1", "true", "yes", "y", "on"} + +# Directory under which lock files are stored. +_LOCKS_DIR = Path(os.environ.get("ALPACA_LOCKS_DIR", "/tmp/.alpaca_locks")) + + +def _is_truthy_env(name: str, *, default: bool = False) -> bool: + value = os.getenv(name) + if value is None: + return default + return value.strip().lower() in _TRUTHY_ENV_VALUES + + +def require_explicit_live_trading_enable(bot_name: str, env_var: str = "ALPACA_ENABLE_LIVE_TRADING") -> None: + """Guardrail: raise SystemExit unless live trading is explicitly enabled. + + Args: + bot_name: Human-readable name of the bot (used in the error message). + env_var: Environment variable that must be set to a truthy value. + """ + if _is_truthy_env(env_var, default=False): + return + raise SystemExit( + f"Alpaca live trading for '{bot_name}' is disabled by default. " + f"To enable intentionally: set {env_var}=1 and re-run with --live." + ) + + +class AlpacaAccountLock: + """Advisory file lock that prevents concurrent writes to a single Alpaca account. + + Attributes: + path: Path of the lock file that is currently held. + """ + + def __init__(self, path: Path) -> None: + self.path = path + self._handle = None + + # ------------------------------------------------------------------ + # Context-manager protocol (optional – can also be used standalone) + # ------------------------------------------------------------------ + + def __enter__(self) -> "AlpacaAccountLock": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + self.release() + + # ------------------------------------------------------------------ + # Explicit release + # ------------------------------------------------------------------ + + def release(self) -> None: + if self._handle is not None: + try: + fcntl.flock(self._handle, fcntl.LOCK_UN) + self._handle.close() + except Exception: + pass + finally: + self._handle = None + + +def acquire_alpaca_account_lock( + bot_name: str, + *, + account_name: str = "default", + locks_dir: Path | None = None, +) -> AlpacaAccountLock: + """Acquire an exclusive advisory file lock for the given Alpaca account. + + Only one process/bot may hold the lock for a given *account_name* at any + time. The lock is released when the returned :class:`AlpacaAccountLock` + object is garbage-collected, its :meth:`~AlpacaAccountLock.release` method + is called, or it is used as a context manager. + + Args: + bot_name: Human-readable identifier for the bot acquiring the lock. + account_name: Logical name of the Alpaca account (e.g. ``"alpaca_live_writer"``). + locks_dir: Optional override for the directory that holds lock files. + + Returns: + An :class:`AlpacaAccountLock` whose ``.path`` attribute is the path of + the acquired lock file. + + Raises: + RuntimeError: If the lock cannot be acquired (e.g. because another + process already holds it). + """ + base_dir: Path = locks_dir or _LOCKS_DIR + base_dir.mkdir(parents=True, exist_ok=True) + + safe_account = "".join(c if c.isalnum() else "_" for c in account_name) + lock_path = base_dir / f"{safe_account}.lock" + + lock = AlpacaAccountLock(lock_path) + handle = lock_path.open("w") + try: + fcntl.flock(handle, fcntl.LOCK_EX | fcntl.LOCK_NB) + except BlockingIOError as exc: + handle.close() + raise RuntimeError( + f"Could not acquire Alpaca account lock for '{account_name}' " + f"(bot='{bot_name}'). Another process may already hold it: {lock_path}" + ) from exc + + handle.write(f"bot={bot_name}\n") + handle.flush() + lock._handle = handle + logger.info("Acquired Alpaca account lock: bot=%s account=%s path=%s", bot_name, account_name, lock_path) + return lock diff --git a/src/trading_server/__init__.py b/src/trading_server/__init__.py new file mode 100644 index 00000000..815f456e --- /dev/null +++ b/src/trading_server/__init__.py @@ -0,0 +1 @@ +"""Trading server package – HTTP client and in-process engine.""" diff --git a/src/trading_server/client.py b/src/trading_server/client.py new file mode 100644 index 00000000..ed203417 --- /dev/null +++ b/src/trading_server/client.py @@ -0,0 +1,122 @@ +"""HTTP client for the TradingServer REST API. + +Provides :class:`TradingServerClient` – a thin wrapper around ``requests`` +that talks to a running :class:`~src.trading_server.server.TradingServerEngine` +via HTTP. +""" + +from __future__ import annotations + +import logging +from typing import Iterable, Optional + +logger = logging.getLogger(__name__) + + +class TradingServerClient: + """REST client for the TradingServer. + + Args: + base_url: Base URL of the trading server (e.g. ``"http://localhost:8080"``). + If ``None``, defaults to ``"http://localhost:8080"``. + account: Logical account name (e.g. ``"paper_daily_sortino"``). + bot_id: Bot identifier used for writer-lock claims. + session_id: Stable session identifier for idempotent operations. + execution_mode: ``"paper"`` or ``"live"``. + """ + + def __init__( + self, + *, + base_url: Optional[str] = None, + account: str, + bot_id: str, + session_id: Optional[str] = None, + execution_mode: str = "paper", + ) -> None: + self.base_url = (base_url or "http://localhost:8080").rstrip("/") + self.account = account + self.bot_id = bot_id + self.session_id = session_id or f"{bot_id}-session" + self.execution_mode = execution_mode + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _post(self, path: str, body: dict) -> dict: + import requests # lazy import so the module is importable without requests + + url = f"{self.base_url}{path}" + response = requests.post(url, json=body, timeout=30) + response.raise_for_status() + return response.json() + + def _get(self, path: str, params: Optional[dict] = None) -> dict: + import requests + + url = f"{self.base_url}{path}" + response = requests.get(url, params=params or {}, timeout=30) + response.raise_for_status() + return response.json() + + # ------------------------------------------------------------------ + # API methods + # ------------------------------------------------------------------ + + def claim_writer(self, *, ttl_seconds: int = 120) -> dict: + """Claim exclusive writer access for this bot on the account.""" + return self._post( + "/writer/claim", + { + "account": self.account, + "bot_id": self.bot_id, + "session_id": self.session_id, + "ttl_seconds": ttl_seconds, + }, + ) + + def refresh_prices(self, *, symbols: Optional[Iterable[str]] = None) -> dict: + """Request the server to refresh cached prices for the given symbols.""" + return self._post( + "/prices/refresh", + { + "account": self.account, + "symbols": list(symbols or []), + }, + ) + + def get_account(self) -> dict: + """Return a snapshot of the account state (cash, positions, history).""" + return self._get("/account/snapshot", {"account": self.account}) + + def submit_limit_order( + self, + *, + symbol: str, + side: str, + qty: float, + limit_price: float, + allow_loss_exit: bool = False, + force_exit_reason: Optional[str] = None, + live_ack: Optional[str] = None, + metadata: Optional[dict] = None, + ) -> dict: + """Submit a limit order to the trading server.""" + return self._post( + "/orders/submit", + { + "account": self.account, + "bot_id": self.bot_id, + "session_id": self.session_id, + "symbol": symbol, + "side": side, + "qty": qty, + "limit_price": limit_price, + "execution_mode": self.execution_mode, + "allow_loss_exit": allow_loss_exit, + "force_exit_reason": force_exit_reason, + "live_ack": live_ack, + "metadata": metadata or {}, + }, + ) diff --git a/src/trading_server/server.py b/src/trading_server/server.py new file mode 100644 index 00000000..766a2985 --- /dev/null +++ b/src/trading_server/server.py @@ -0,0 +1,249 @@ +"""In-process TradingServerEngine for deterministic paper-trading backtests. + +:class:`TradingServerEngine` implements the same interface as the remote HTTP +trading server but runs entirely in memory – no network calls, no persistence +between instantiations (unless *state_dir* is provided). +""" + +from __future__ import annotations + +import json +import logging +import uuid +from collections import defaultdict +from datetime import datetime, timezone +from pathlib import Path +from types import SimpleNamespace +from typing import Any, Callable, Dict, Iterable, List, Optional + +logger = logging.getLogger(__name__) + + +class _AccountState: + """Mutable state for a single paper-trading account.""" + + def __init__( + self, + *, + starting_cash: float, + allowed_bot_id: str, + symbols: List[str], + sell_loss_cooldown_seconds: int = 0, + min_sell_markup_pct: float = 0.0, + ) -> None: + self.cash: float = starting_cash + self.allowed_bot_id = allowed_bot_id + self.symbols = [str(s).upper() for s in symbols] + self.sell_loss_cooldown_seconds = sell_loss_cooldown_seconds + self.min_sell_markup_pct = min_sell_markup_pct + # symbol -> {qty, avg_entry_price, opened_at} + self.positions: Dict[str, Dict[str, Any]] = {} + self.order_history: List[Dict[str, Any]] = [] + # Current writer session + self.writer_bot_id: Optional[str] = None + self.writer_session_id: Optional[str] = None + + +class TradingServerEngine: + """In-process trading server engine used for deterministic backtesting. + + The engine supports multiple named accounts, each loaded from the + *registry_path* JSON file. State is held entirely in memory; optional + *state_dir* is reserved for future persistence support. + + Args: + registry_path: Path to a JSON file that describes accounts. Format:: + + { + "accounts": { + "": { + "mode": "paper", + "allowed_bot_id": "", + "starting_cash": 10000.0, + "symbols": ["AAPL", "MSFT"], + "sell_loss_cooldown_seconds": 0, + "min_sell_markup_pct": 0.0 + } + } + } + + state_dir: Optional directory for state snapshots (not yet used). + quote_provider: Callable ``(symbol: str) -> dict | None`` that returns + the current quote for a symbol. + now_fn: Callable ``() -> datetime`` used to determine the current + time (injectable for deterministic backtests). + """ + + def __init__( + self, + *, + registry_path: Path, + state_dir: Optional[Path] = None, + quote_provider: Optional[Callable[[str], Optional[Dict[str, Any]]]] = None, + now_fn: Optional[Callable[[], datetime]] = None, + ) -> None: + self._quote_provider = quote_provider or (lambda symbol: None) + self._now_fn = now_fn or (lambda: datetime.now(timezone.utc)) + self._accounts: Dict[str, _AccountState] = {} + self._prices: Dict[str, float] = {} + + registry = json.loads(Path(registry_path).read_text(encoding="utf-8")) + for acct_name, cfg in registry.get("accounts", {}).items(): + self._accounts[acct_name] = _AccountState( + starting_cash=float(cfg.get("starting_cash", 10_000.0)), + allowed_bot_id=str(cfg.get("allowed_bot_id", "")), + symbols=cfg.get("symbols", []), + sell_loss_cooldown_seconds=int(cfg.get("sell_loss_cooldown_seconds", 0)), + min_sell_markup_pct=float(cfg.get("min_sell_markup_pct", 0.0)), + ) + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _now(self) -> datetime: + return self._now_fn() + + def _get_account(self, account: str) -> _AccountState: + state = self._accounts.get(account) + if state is None: + raise ValueError(f"Unknown account: {account!r}") + return state + + def _current_price(self, symbol: str) -> float: + sym = str(symbol).upper() + cached = self._prices.get(sym) + if cached is not None: + return cached + quote = self._quote_provider(sym) + if quote and isinstance(quote, dict): + for key in ("last_price", "bid_price", "ask_price"): + v = quote.get(key) + if v: + return float(v) + return 0.0 + + # ------------------------------------------------------------------ + # Engine API (mirroring the InMemoryTradingServerClient adapter) + # ------------------------------------------------------------------ + + def claim_writer(self, request: Any) -> Dict[str, Any]: + """Claim exclusive write access for a bot on an account.""" + account = str(getattr(request, "account", "")) + bot_id = str(getattr(request, "bot_id", "")) + session_id = str(getattr(request, "session_id", "")) + state = self._get_account(account) + state.writer_bot_id = bot_id + state.writer_session_id = session_id + return {"status": "ok", "account": account, "bot_id": bot_id, "session_id": session_id} + + def refresh_prices(self, *, account: str, symbols: Iterable[str]) -> Dict[str, Any]: + """Refresh cached prices for the given symbols.""" + updated: Dict[str, float] = {} + for sym in symbols: + sym = str(sym).upper() + quote = self._quote_provider(sym) + if quote and isinstance(quote, dict): + for key in ("last_price", "bid_price", "ask_price"): + v = quote.get(key) + if v: + self._prices[sym] = float(v) + updated[sym] = float(v) + break + return {"status": "ok", "updated": updated} + + def get_account_snapshot(self, account: str) -> Dict[str, Any]: + """Return a snapshot of the account: cash, positions, order_history.""" + state = self._get_account(account) + positions_out: Dict[str, Any] = {} + for sym, pos in state.positions.items(): + qty = float(pos.get("qty", 0.0)) + if qty == 0.0: + continue + avg_entry_price = float(pos.get("avg_entry_price", 0.0)) + current_price = self._current_price(sym) or avg_entry_price + positions_out[sym] = { + "qty": qty, + "avg_entry_price": avg_entry_price, + "current_price": current_price, + "opened_at": pos.get("opened_at"), + } + equity = state.cash + sum( + float(p["qty"]) * self._current_price(sym) + for sym, p in state.positions.items() + ) + return { + "account": account, + "cash": state.cash, + "equity": equity, + "positions": positions_out, + "order_history": list(state.order_history), + } + + def submit_order(self, request: Any) -> Dict[str, Any]: + """Execute a limit order immediately (paper mode).""" + account = str(getattr(request, "account", "")) + symbol = str(getattr(request, "symbol", "")).upper() + side = str(getattr(request, "side", "")).lower() + qty = float(getattr(request, "qty", 0.0)) + limit_price = float(getattr(request, "limit_price", 0.0)) + metadata = getattr(request, "metadata", {}) or {} + + state = self._get_account(account) + now_iso = self._now().isoformat() + order_id = str(uuid.uuid4()) + + if side == "buy": + cost = qty * limit_price + if cost > state.cash: + # Reduce qty to what cash allows + qty = max(0.0, state.cash / limit_price) + cost = qty * limit_price + if qty > 0: + state.cash -= cost + existing = state.positions.get(symbol) + if existing: + total_qty = float(existing["qty"]) + qty + existing["avg_entry_price"] = ( + float(existing["avg_entry_price"]) * float(existing["qty"]) + limit_price * qty + ) / total_qty + existing["qty"] = total_qty + else: + state.positions[symbol] = { + "qty": qty, + "avg_entry_price": limit_price, + "opened_at": now_iso, + } + + elif side == "sell": + existing = state.positions.get(symbol) + if existing: + sell_qty = min(qty, float(existing["qty"])) + proceeds = sell_qty * limit_price + state.cash += proceeds + remaining = float(existing["qty"]) - sell_qty + if remaining <= 1e-6: + del state.positions[symbol] + else: + existing["qty"] = remaining + + order_record = { + "id": order_id, + "symbol": symbol, + "side": side, + "qty": qty, + "limit_price": limit_price, + "filled_price": limit_price, + "filled_at": now_iso, + "metadata": metadata, + } + state.order_history.append(order_record) + logger.debug( + "TradingServerEngine: %s %s %.4f @ %.4f (account=%s)", + side.upper(), + symbol, + qty, + limit_price, + account, + ) + return {"status": "ok", "order": order_record}