Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
278 changes: 269 additions & 9 deletions src/graphstore/bonsai_ingestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,126 @@ def _scrape_belief_updates(
facts[fact_id] = st


# --------------------------------------------------------------------
# Compact output mode: LLM emits 3 tagged lines (ENTS/BELIEFS/RETRACTS);
# we synthesize the full DSL in Python. 3-5x fewer output tokens than the
# full-DSL mode, measured on 4B TQ1_0. See tools/skills/graphstore-bonsai-
# dsl-compact/SKILL.md for the exact output contract.
# --------------------------------------------------------------------

# One "key"="value" pair, capturing both sides. Each quoted literal allows
# backslash escapes (\" and \\) inside it: the pattern matches runs of
# non-quote/non-backslash chars interleaved with backslash-escaped chars,
# so an escaped quote does not terminate the capture. The key group uses +
# (non-empty); the value group uses * (may be empty).
_COMPACT_KV_RE = re.compile(r'"([^"\\]+(?:\\.[^"\\]*)*)"\s*=\s*"([^"\\]*(?:\\.[^"\\]*)*)"')
# Bare quoted id with the same escape handling (RETRACTS lists these).
_COMPACT_ID_RE = re.compile(r'"([^"\\]+(?:\\.[^"\\]*)*)"')


@dataclass
class CompactTurn:
    """Structured result of one compact-mode LLM call.

    Each list preserves the order in which items appeared in the model
    output; every instance gets fresh, independent lists.
    """

    # (ent_id, display name) pairs from the ENTS: line.
    entities: list[tuple[str, str]] = field(default_factory=list)
    # (fact_id, value) pairs from the BELIEFS: line.
    beliefs: list[tuple[str, str]] = field(default_factory=list)
    # fact ids from the RETRACTS: line.
    retracts: list[str] = field(default_factory=list)


def _parse_compact_output(cleaned: str) -> CompactTurn:
    """Parse the 3-line ENTS/BELIEFS/RETRACTS compact-mode reply.

    Lenient by design: absent sections stay empty, section labels match
    case-insensitively, a literal `none` body counts as empty, and lines
    with unrecognized prefixes (or code fences) are skipped outright.
    """
    turn = CompactTurn()
    for raw in cleaned.splitlines():
        line = raw.strip()
        if not line or _FENCE_RE.match(line):
            continue
        folded = line.lower()
        for prefix in ("ents:", "beliefs:", "retracts:"):
            if not folded.startswith(prefix):
                continue
            # Slice from the original line so value case is preserved.
            body = line[len(prefix):].strip()
            if body.lower() in ("none", ""):
                break
            if prefix == "retracts:":
                turn.retracts.extend(
                    m.group(1) for m in _COMPACT_ID_RE.finditer(body)
                )
            else:
                dest = turn.entities if prefix == "ents:" else turn.beliefs
                dest.extend(
                    (m.group(1), m.group(2))
                    for m in _COMPACT_KV_RE.finditer(body)
                )
            break
    return turn


def _dsl_escape(s: str) -> str:
"""Escape a Python string for safe embedding inside a DSL "..." literal."""
return s.replace("\\", "\\\\").replace('"', '\\"')


def _synthesize_dsl(
    turn: CompactTurn,
    *,
    msg_id: str,
    session_id: str,
    role: str,
    text: str,
) -> list[str]:
    """Expand a parsed CompactTurn into the full DSL statement list.

    Purely deterministic: identical inputs always produce the identical
    statement sequence, in this fixed order:
      1. CREATE NODE for the message (DOCUMENT = user text).
      2. One UPSERT NODE per unique entity (first occurrence of an id
         wins), then one CREATE EDGE kind = "mentions" per unique entity.
      3. One RETRACT per retract id (always ahead of any ASSERT).
      4. One ASSERT per belief.
    """
    esc = _dsl_escape
    msg = esc(msg_id)

    stmts = [
        f'CREATE NODE "{msg}" kind = "message" '
        f'session = "{esc(session_id)}" role = "{esc(role)}" '
        f'DOCUMENT "{esc(text)}"'
    ]

    # dict preserves insertion order; setdefault keeps the first name
    # seen for a duplicated entity id.
    unique: dict[str, str] = {}
    for ent_id, name in turn.entities:
        unique.setdefault(ent_id, name)

    for ent_id, name in unique.items():
        stmts.append(
            f'UPSERT NODE "{esc(ent_id)}" kind = "entity" name = "{esc(name)}"'
        )
    for ent_id in unique:
        stmts.append(
            f'CREATE EDGE "{msg}" -> "{esc(ent_id)}" kind = "mentions"'
        )

    for fact_id in turn.retracts:
        stmts.append(
            f'RETRACT "{esc(fact_id)}" REASON "superseded by {msg}"'
        )

    for fact_id, value in turn.beliefs:
        stmts.append(
            f'ASSERT "{esc(fact_id)}" kind = "belief" '
            f'value = "{esc(value)}" CONFIDENCE 0.9 SOURCE "{msg}"'
        )

    return stmts


def _render_known_facts_block(facts: dict[str, FactState], max_facts: int = 40) -> str:
"""Format non-retracted facts into a block the LLM reads before the input.

Expand Down Expand Up @@ -265,6 +385,11 @@ def _render_known_facts_block(facts: dict[str, FactState], max_facts: int = 40)
/ "tools" / "skills" / "graphstore-bonsai-dsl" / "SKILL.md"
)

# Compact-mode skill doc (the ENTS/BELIEFS/RETRACTS output contract),
# resolved relative to this file: three levels up to the repo root, then
# tools/skills/graphstore-bonsai-dsl-compact/SKILL.md.
_DEFAULT_COMPACT_SKILL_PATH = (
    Path(__file__).resolve().parent.parent.parent
    / "tools" / "skills" / "graphstore-bonsai-dsl-compact" / "SKILL.md"
)


class BonsaiIngestor:
"""NL -> DSL via a local llama.cpp GGUF, with correctness guards.
Expand Down Expand Up @@ -303,19 +428,27 @@ def __init__(
*,
gs: Any | None = None,
skill_path: str | Path | None = None,
compact: bool = False,
n_ctx: int = 2048,
n_threads: int | None = None,
chat_format: str = "qwen",
max_output_tokens: int = 400,
temperature: float = 0.0,
kv_cache_path: str | Path | None = None,
) -> None:
self._model_path = Path(model_path)
if not self._model_path.exists():
raise FileNotFoundError(f"bonsai model not found: {self._model_path}")
self._gs = gs
self._skill_path = Path(skill_path) if skill_path else _DEFAULT_SKILL_PATH
self._compact = compact
if skill_path:
self._skill_path = Path(skill_path)
else:
self._skill_path = _DEFAULT_COMPACT_SKILL_PATH if compact else _DEFAULT_SKILL_PATH
self._n_ctx = n_ctx
self._max_output_tokens = max_output_tokens
# Compact mode emits ~30 tokens of structured output. Cap lower so
# stray model verbosity doesn't burn decode time.
self._max_output_tokens = max_output_tokens if not compact else min(max_output_tokens, 160)
self._temperature = temperature
self._chat_format = chat_format
self._n_threads = n_threads
Expand All @@ -332,6 +465,12 @@ def __init__(
# user message of the next ingest so the model reuses ids.
self._facts: dict[str, FactState] = {}

# Optional persistent KV cache. Eliminates the ~10s cold penalty on
# process restarts. File holds a pickled (meta, LlamaState) tuple;
# meta guards against loading stale state when the skill or config
# changed since the cache was written.
self._kv_cache_path = Path(kv_cache_path) if kv_cache_path else None

# ------------------------------------------------------------------
# Lifecycle
# ------------------------------------------------------------------
Expand All @@ -354,6 +493,83 @@ def _reload_skill(self) -> None:
self._skill_fingerprint = hashlib.sha256(body.encode()).hexdigest()[:12]
self._system_prompt = f"# skill-sha256={self._skill_fingerprint}\n\n{body}"

def _kv_meta(self) -> dict[str, Any]:
"""What the current config looks like. Written alongside the KV cache
so we can refuse to load state if any of these changed."""
return {
"model_path": str(self._model_path),
"model_size_bytes": self._model_path.stat().st_size,
"skill_fingerprint": self._skill_fingerprint,
"n_ctx": self._n_ctx,
"chat_format": self._chat_format,
}

def _try_load_kv_cache(self, llm: Any) -> bool:
"""Load a persisted KV cache into `llm` if one exists and is valid.

Returns True on successful load, False otherwise. Invalid cache is
silently ignored - the caller warms up normally.
"""
if not self._kv_cache_path or not self._kv_cache_path.exists():
return False
import pickle

try:
with self._kv_cache_path.open("rb") as f:
payload = pickle.load(f)
except Exception as err:
_log.warning("bonsai: KV cache unreadable (%s); skipping", err)
return False

meta = payload.get("meta") if isinstance(payload, dict) else None
state = payload.get("state") if isinstance(payload, dict) else None
if not meta or state is None:
_log.warning("bonsai: KV cache shape invalid; skipping")
return False

cur = self._kv_meta()
if meta != cur:
diff = {k: (meta.get(k), cur.get(k)) for k in cur if meta.get(k) != cur.get(k)}
_log.info(
"bonsai: KV cache stale (diff=%s); warming fresh",
diff,
)
return False

try:
llm.load_state(state)
except Exception as err:
_log.warning("bonsai: KV cache load_state failed (%s); warming fresh", err)
return False

_log.info("bonsai: KV cache loaded from %s (skipped warmup)", self._kv_cache_path)
return True

def save_kv_cache(self) -> None:
    """Persist the live Llama KV state to `kv_cache_path` atomically.

    Best called after `warmup()` (or after one real ingest) so the
    skill-prefix tokens are captured. The file holds a pickled
    {"meta": ..., "state": ...} dict; the meta half lets a later
    `_try_load_kv_cache` reject it if the config has changed.

    Silently does nothing when no cache path is configured or the
    Llama instance has not been constructed yet.
    """
    if self._kv_cache_path is None or self._llm is None:
        return
    import pickle

    state = self._llm.save_state()
    path = self._kv_cache_path
    path.parent.mkdir(parents=True, exist_ok=True)
    # Write to a sibling tmp file then rename, so a reader can never
    # observe a half-written cache.
    tmp = path.with_suffix(path.suffix + ".tmp")
    with tmp.open("wb") as fh:
        pickle.dump({"meta": self._kv_meta(), "state": state}, fh)
    tmp.replace(path)
    _log.info(
        "bonsai: KV cache saved to %s (%.1f MB)",
        path,
        path.stat().st_size / 1e6,
    )

def _ensure_llm(self) -> Any:
"""Lazy-load the Llama instance on first use."""
if self._llm is not None:
Expand All @@ -372,6 +588,7 @@ def _ensure_llm(self) -> Any:
self._model_path.name, self._n_ctx, self._n_threads, self._chat_format,
)
self._llm = Llama(**kwargs)
self._try_load_kv_cache(self._llm)
return self._llm

def reset(self) -> None:
Expand Down Expand Up @@ -429,22 +646,57 @@ def warmup(self) -> None:
temperature=0.0,
)

def ingest(self, text: str, *, dry_run: bool = False) -> IngestResult:
def ingest(
    self,
    text: str,
    *,
    msg_id: str | None = None,
    session_id: str = "default",
    role: str = "user",
    dry_run: bool = False,
) -> IngestResult:
    """Convert `text` to DSL statements and (optionally) execute them.

    In full-DSL mode (compact=False) the LLM emits DSL directly; msg_id
    and session_id come from the text the caller supplies ("Session s1,
    msg m:s1:0, user: ...") so the extra kwargs are unused.

    In compact mode (compact=True) the LLM emits ENTS/BELIEFS/RETRACTS
    and Python synthesizes the DSL. The caller must pass msg_id (and
    may override session_id / role); these become the identifiers in
    the synthesized CREATE NODE / CREATE EDGE statements.

    `dry_run=True` returns the DSL without touching the store.

    Raises:
        IngestEmpty: `text` is empty or whitespace-only.
        ValueError: no GraphStore configured and dry_run=False, or
            compact mode was requested without an explicit msg_id.
    """
    if not text or not text.strip():
        raise IngestEmpty("input text is empty or whitespace-only")
    if not dry_run and self._gs is None:
        raise ValueError("ingest requires a GraphStore (pass gs=...) or dry_run=True")
    if self._compact and not msg_id:
        raise ValueError(
            "compact=True ingest requires an explicit msg_id "
            "(DSL synthesis needs the exact CREATE NODE id)"
        )

    # Re-read the skill file before taking the lock so a changed skill
    # is picked up on the very next ingest.
    self._reload_skill()
    with self._lock:
        return self._ingest_locked(
            text,
            msg_id=msg_id,
            session_id=session_id,
            role=role,
            dry_run=dry_run,
        )

def _ingest_locked(self, text: str, *, dry_run: bool) -> IngestResult:
def _ingest_locked(
self,
text: str,
*,
msg_id: str | None,
session_id: str,
role: str,
dry_run: bool,
) -> IngestResult:
t0 = time.perf_counter()
llm = self._ensure_llm()

Expand Down Expand Up @@ -485,8 +737,16 @@ def _ingest_locked(self, text: str, *, dry_run: bool) -> IngestResult:
f"raw={raw!r}"
)

raw_lines = _split_lines(cleaned)
deduped, dup_dropped = _dedupe_upserts(raw_lines)
if self._compact:
assert msg_id is not None # guarded in ingest()
turn = _parse_compact_output(cleaned)
deduped = _synthesize_dsl(
turn, msg_id=msg_id, session_id=session_id, role=role, text=text,
)
dup_dropped: list[tuple[str, str]] = []
else:
raw_lines = _split_lines(cleaned)
deduped, dup_dropped = _dedupe_upserts(raw_lines)

from graphstore.dsl.parser import parse as _dsl_parse

Expand Down
Loading
Loading