diff --git a/deployments/api/src/stitch/api/db/init_job.py b/deployments/api/src/stitch/api/db/init_job.py index 5a256c2..3556121 100644 --- a/deployments/api/src/stitch/api/db/init_job.py +++ b/deployments/api/src/stitch/api/db/init_job.py @@ -281,6 +281,9 @@ def create_seed_sources(): GemSourceModel.from_entity( GemData(name="North Sea Platform", country="GBR", lat=57.5, lon=1.5) ), + GemSourceModel.from_entity( + GemData(name="Merge Target Field", country="YYZ", lat=13.37, lon=13.37) + ), ] for i, src in enumerate(gem_sources, start=1): src.id = i @@ -294,6 +297,13 @@ def create_seed_sources(): WMSourceModel.from_entity( WMData(field_name="Ghawar Field", field_country="SAU", production=500000.0) ), + WMSourceModel.from_entity( + WMData( + field_name="Merge Consumed Field", + field_country="YYZ", + production=1337.0, + ) + ), ] for i, src in enumerate(wm_sources, start=1): src.id = i @@ -324,6 +334,8 @@ def create_seed_resources(user: UserEntity) -> list[ResourceModel]: resources = [ ResourceModel.create(user, name="Multi-Source Asset", country="USA"), ResourceModel.create(user, name="Single Source Asset", country="GBR"), + ResourceModel.create(user, name="Merge Demo", country="YYZ"), + ResourceModel.create(user, name="Merge Demo", country="YYZ"), ] for i, res in enumerate(resources, start=1): res.id = i @@ -342,6 +354,8 @@ def create_seed_memberships( MembershipModel.create(user, resources[0], "wm", wm_sources[0].id), MembershipModel.create(user, resources[0], "rmi", rmi_sources[0].id), MembershipModel.create(user, resources[1], "gem", gem_sources[1].id), + MembershipModel.create(user, resources[2], "gem", gem_sources[2].id), + MembershipModel.create(user, resources[3], "wm", wm_sources[2].id), ] for i, mem in enumerate(memberships, start=1): mem.id = i diff --git a/deployments/api/src/stitch/api/db/resource_actions.py b/deployments/api/src/stitch/api/db/resource_actions.py index c0d19bc..1e0b091 100644 --- a/deployments/api/src/stitch/api/db/resource_actions.py 
+++ b/deployments/api/src/stitch/api/db/resource_actions.py @@ -154,3 +154,52 @@ async def create_source_data(session: AsyncSession, data: CreateSourceData): rmi={rmi.id: rmi.as_entity() for rmi in rmis}, cc={cc.id: cc.as_entity() for cc in ccs}, ) + + +async def merge_resources( + session: AsyncSession, + user: CurrentUser, + ids: Sequence[int], +) -> Resource: + """ + Stub "merge" behavior: + - Treat ids[0] as the canonical/target resource. + - Update all resources in ids[1:] to have repointed_id = ids[0]. + + NOTE: This only updates the resource table repointing field (no membership/source consolidation). + """ + if not ids: + raise HTTPException(status_code=400, detail="No resource IDs provided.") + # preserve order but drop duplicates + unique_ids = list(dict.fromkeys(ids)) + if len(unique_ids) < 2: + raise HTTPException( + status_code=400, detail="Provide at least 2 unique resource IDs." + ) + + target_id = unique_ids[0] + other_ids = unique_ids[1:] + + # Ensure target exists + target_model = await session.get(ResourceModel, target_id) + if target_model is None: + raise HTTPException( + status_code=HTTP_404_NOT_FOUND, + detail=f"No Resource with id `{target_id}` found.", + ) + + # Ensure all others exist, then repoint them + for rid in other_ids: + model = await session.get(ResourceModel, rid) + if model is None: + raise HTTPException( + status_code=HTTP_404_NOT_FOUND, + detail=f"No Resource with id `{rid}` found.", + ) + model.repointed_id = target_id + + await session.flush() + + # Return the canonical resource entity + await session.refresh(target_model, ["memberships"]) + return await resource_model_to_entity(session, target_model) diff --git a/deployments/api/src/stitch/api/routers/resources.py b/deployments/api/src/stitch/api/routers/resources.py index 8d63c15..6c9f53b 100644 --- a/deployments/api/src/stitch/api/routers/resources.py +++ b/deployments/api/src/stitch/api/routers/resources.py @@ -1,12 +1,17 @@ +import logging + from collections.abc import 
Sequence -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException + +from pydantic import BaseModel from stitch.api.db import resource_actions from stitch.api.db.config import UnitOfWorkDep from stitch.api.auth import CurrentUser from stitch.api.entities import CreateResource, Resource +logger = logging.getLogger(__name__) router = APIRouter( prefix="/resources", @@ -33,3 +38,42 @@ async def create_resource( return await resource_actions.create( session=uow.session, user=user, resource=resource_in ) + + +class MergeRequest(BaseModel): + resource_ids: list[int] + + +@router.post("/merge", response_model=Resource) +async def merge_resources_endpoint( + *, uow: UnitOfWorkDep, user: CurrentUser, payload: MergeRequest +) -> Resource: + """ + Merge multiple resources into one (STUB): + repoint resource_ids[1:] -> resource_ids[0] + """ + ids = payload.resource_ids + # preserve order but drop duplicates + unique_ids = list(dict.fromkeys(ids)) + if len(unique_ids) < 2: + raise HTTPException( + status_code=400, detail="Provide at least 2 unique resource IDs" + ) + + logger.info( + "Merge requested by user=%s for resource_ids=%s", + getattr(user, "sub", ""), + unique_ids, + ) + + try: + return await resource_actions.merge_resources( + session=uow.session, + user=user, + ids=unique_ids, + ) + except HTTPException: + raise + except Exception as exc: + logger.exception("Error while merging resources %s: %s", unique_ids, exc) + raise HTTPException(status_code=500, detail="Internal error during merge") diff --git a/deployments/entity-linkage/Dockerfile b/deployments/entity-linkage/Dockerfile new file mode 100644 index 0000000..05d959b --- /dev/null +++ b/deployments/entity-linkage/Dockerfile @@ -0,0 +1,25 @@ +FROM python:3.12-slim-trixie + +ENV PYTHONUNBUFFERED=1 + +WORKDIR /app + +# Small, self-contained runtime (not tied to the uv workspace/lock). 
+RUN python -m venv /opt/venv +ENV PATH="/opt/venv/bin:$PATH" + +RUN pip install --no-cache-dir --upgrade pip \ + && pip install --no-cache-dir httpx==0.28.1 + +COPY deployments/entity-linkage/entity_linkage.py /app/entity_linkage.py + +# Defaults (override via compose/env) +ENV API_URL="http://api:8000" \ + ENTITY_LINKAGE_MODE="oneshot" \ + ENTITY_LINKAGE_SLEEP_SECONDS="10" \ + ENTITY_LINKAGE_TIMEOUT_SECONDS="10" \ + ENTITY_LINKAGE_MAX_RETRIES="60" \ + ENTITY_LINKAGE_RETRY_BACKOFF_SECONDS="1" \ + ENTITY_LINKAGE_LOG_LEVEL="INFO" + +CMD ["python", "/app/entity_linkage.py"] diff --git a/deployments/entity-linkage/README.md b/deployments/entity-linkage/README.md new file mode 100644 index 0000000..c641acf --- /dev/null +++ b/deployments/entity-linkage/README.md @@ -0,0 +1,21 @@ +# entity-linkage (deployment) + +A small client container that: +1) makes a GET request to the Stitch API +2) makes a POST request to the Stitch API + +Note that for now, it does not terminate (runs in a loop looking for resources to +merge) + +Note that the merging logic is trivial at this point (exact match on +resource name and country).
+ +## Configuration + +- `API_URL` (required) + - Example: `http://api:8000` +- `ENTITY_LINKAGE_SLEEP_SECONDS` (default: `10`) +- `ENTITY_LINKAGE_TIMEOUT_SECONDS` (default: `10`) +- `ENTITY_LINKAGE_MAX_RETRIES` (default: `60`) +- `ENTITY_LINKAGE_RETRY_BACKOFF_SECONDS` (default: `1`) +- `ENTITY_LINKAGE_LOG_LEVEL` (default: `INFO`) diff --git a/deployments/entity-linkage/entity_linkage.py b/deployments/entity-linkage/entity_linkage.py new file mode 100644 index 0000000..64d6c22 --- /dev/null +++ b/deployments/entity-linkage/entity_linkage.py @@ -0,0 +1,237 @@ +import json +import logging +import os +import sys +import time +from dataclasses import dataclass +from typing import Any, Dict +from urllib.parse import urljoin, urlparse +from collections import defaultdict +from typing import Tuple + +import httpx + + +@dataclass(frozen=True) +class Config: + api_url: str + sleep_seconds: float + timeout_seconds: float + max_retries: int + retry_backoff_seconds: float + log_level: str + + @staticmethod + def from_env() -> "Config": + api_url = os.getenv("API_URL", "").strip() + if not api_url: + raise ValueError("API_URL must be set (e.g. 
http://api:8000)") + + parsed = urlparse(api_url) + if parsed.scheme not in ("http", "https") or not parsed.netloc: + raise ValueError(f"API_URL must be a valid http(s) URL; got: {api_url!r}") + + def get_float(name: str, default: str) -> float: + raw = os.getenv(name, default).strip() + return float(raw) + + def get_int(name: str, default: str) -> int: + raw = os.getenv(name, default).strip() + return int(raw) + + return Config( + api_url=api_url.rstrip("/") + "/", + sleep_seconds=get_float("ENTITY_LINKAGE_SLEEP_SECONDS", "10"), + timeout_seconds=get_float("ENTITY_LINKAGE_TIMEOUT_SECONDS", "10"), + max_retries=get_int("ENTITY_LINKAGE_MAX_RETRIES", "60"), + retry_backoff_seconds=get_float( + "ENTITY_LINKAGE_RETRY_BACKOFF_SECONDS", "1" + ), + log_level=os.getenv("ENTITY_LINKAGE_LOG_LEVEL", "INFO").strip().upper(), + ) + + +def setup_logging(level: str) -> None: + logging.basicConfig( + level=getattr(logging, level, logging.INFO), + format="%(asctime)s %(levelname)s %(name)s - %(message)s", + ) + + +log = logging.getLogger("entity-linkage") + + +def _safe_json(obj: Any) -> str: + try: + return json.dumps(obj, ensure_ascii=False, default=str) + except Exception: + return "" + + +def wait_for_api(cfg: Config) -> None: + """ + Minimal readiness probe: repeatedly GET health (preferred), + else fall back to GET resources/ if health isn't present. + """ + candidates = [ + urljoin(cfg.api_url, "health"), + urljoin(cfg.api_url, "resources/"), + ] + + timeout = httpx.Timeout(cfg.timeout_seconds) + with httpx.Client(timeout=timeout) as client: + for attempt in range(1, cfg.max_retries + 1): + for url in candidates: + try: + r = client.get(url) + if 200 <= r.status_code < 500: + # 2xx/3xx means good, 4xx means server is up even if route differs. 
+ log.info("API reachable: %s (status=%s)", url, r.status_code) + return + log.warning( + "API not ready yet: %s (status=%s, body=%s)", + url, + r.status_code, + r.text[:300], + ) + except Exception as e: + log.warning("API probe failed: %s (%s)", url, e) + + if attempt < cfg.max_retries: + time.sleep(cfg.retry_backoff_seconds) + + raise RuntimeError( + f"API not reachable after {cfg.max_retries} retries; last tried: {candidates}" + ) + + +def extract_duplicate_groups(items: list[dict]) -> list[Tuple[str, str, list[int]]]: + """ + Groups resources by (name, country) and returns only groups + that contain more than one item. + + Returns: + [ + (name, country, [id1, id2, ...]), + ... + ] + """ + groups: dict[tuple[str, str], list[int]] = defaultdict(list) + + for item in items: + try: + name = item["name"] + country = item["country"] + rid = item["id"] + except KeyError as e: + log.warning("Skipping item missing expected field %s: %s", e, item) + continue + + groups[(name, country)].append(rid) + + duplicates: list[Tuple[str, str, list[int]]] = [] + + for (name, country), ids in groups.items(): + if len(ids) > 1: + duplicates.append((name, country, ids)) + + return duplicates + + +def do_get_then_post(cfg: Config) -> None: + timeout = httpx.Timeout(cfg.timeout_seconds) + + get_url = urljoin(cfg.api_url, "resources/") + post_url = urljoin(cfg.api_url, "resources/merge") + + with httpx.Client(timeout=timeout) as client: + # ---- GET ---- + log.info("GET %s", get_url) + r_get = client.get(get_url) + + log.info( + "GET response status=%s", + r_get.status_code, + ) + + if r_get.status_code >= 500: + raise RuntimeError(f"GET failed with status {r_get.status_code}") + + try: + data = r_get.json() + except Exception: + log.error("GET did not return valid JSON. 
Body=%s", r_get.text[:1000]) + raise + + if not isinstance(data, list): + raise RuntimeError("Expected GET to return a JSON array") + + log.info("Fetched %s resources", len(data)) + + duplicate_groups = extract_duplicate_groups(data) + + if not duplicate_groups: + log.info("No duplicate (name, country) groups found.") + else: + log.info("Found %s duplicate groups.", len(duplicate_groups)) + for name, country, ids in duplicate_groups: + log.info( + "Duplicate group detected: name=%r country=%r ids=%s", + name, + country, + ids, + ) + + # ---- POST (stub) ---- + payload: Dict[str, Any] = {"resource_ids": ids} + + log.info("prepost POST %s payload=%s", post_url, _safe_json(payload)) + r_post = client.post(post_url, json=payload) + + log.info( + "POST response status=%s body=%s", + r_post.status_code, + r_post.text[:1000], + ) + + if r_post.status_code >= 500: + raise RuntimeError(f"POST failed with status {r_post.status_code}") + + +def main() -> int: + try: + cfg = Config.from_env() + except Exception as e: + print(f"[entity-linkage] config error: {e}", file=sys.stderr) + return 2 + + setup_logging(cfg.log_level) + log.info( + "Starting (api_url=%s, sleep=%ss)", + cfg.api_url, + cfg.sleep_seconds, + ) + + try: + wait_for_api(cfg) + except Exception: + log.exception("API did not become reachable") + return 3 + + while True: + start = time.time() + try: + do_get_then_post(cfg) + log.info( + "Iteration ok (elapsed=%.2fs). Sleeping %ss.", + time.time() - start, + cfg.sleep_seconds, + ) + except Exception: + log.exception("Iteration failed. Sleeping %ss.", cfg.sleep_seconds) + + time.sleep(cfg.sleep_seconds) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/docker-compose.yml b/docker-compose.yml index 0175c1f..b167c2d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -82,5 +82,16 @@ services: ports: - "3000:80" + entity-linkage: + build: + context: . 
+ dockerfile: deployments/entity-linkage/Dockerfile + env_file: + - .env + depends_on: + api: + condition: service_started + restart: "no" + volumes: db_data: diff --git a/env.example b/env.example index a51f749..6974379 100644 --- a/env.example +++ b/env.example @@ -12,6 +12,11 @@ STITCH_DB_SEED_MODE="if-needed" STITCH_DB_SEED_PROFILE="dev" FRONTEND_ORIGIN_URL=http://localhost:3000 +API_URL=http://api:8000/api/v1 # Auth (AUTH_DISABLED=true bypasses JWT validation for local dev) AUTH_DISABLED=true + +# entity-linkage +ENTITY_LINKAGE_SLEEP_SECONDS=10 +ENTITY_LINKAGE_LOG_LEVEL=INFO