diff --git a/hindsight-api-slim/hindsight_api/engine/search/link_expansion_retrieval.py b/hindsight-api-slim/hindsight_api/engine/search/link_expansion_retrieval.py index 3b3f434ad..32e81c8e2 100644 --- a/hindsight-api-slim/hindsight_api/engine/search/link_expansion_retrieval.py +++ b/hindsight-api-slim/hindsight_api/engine/search/link_expansion_retrieval.py @@ -59,7 +59,7 @@ async def _find_semantic_seeds( rows = await conn.fetch( f""" SELECT id, text, context, event_date, occurred_start, occurred_end, - mentioned_at, fact_type, document_id, chunk_id, tags, + mentioned_at, fact_type, document_id, chunk_id, tags, proof_count, 1 - (embedding <=> $1::vector) AS similarity FROM {fq_table("memory_units")} WHERE bank_id = $2 @@ -274,7 +274,7 @@ async def _expand_combined( -- Score = COUNT(DISTINCT shared entities), mapped to [0,1] via tanh. SELECT mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start, mu.occurred_end, mu.mentioned_at, - mu.fact_type, mu.document_id, mu.chunk_id, mu.tags, + mu.fact_type, mu.document_id, mu.chunk_id, mu.tags, mu.proof_count, COUNT(DISTINCT ue_seed.entity_id)::float AS score, 'entity'::text AS source FROM {ue} ue_seed @@ -298,14 +298,14 @@ async def _expand_combined( SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, - fact_type, document_id, chunk_id, tags, + fact_type, document_id, chunk_id, tags, proof_count, MAX(weight) AS score, 'semantic'::text AS source FROM ( SELECT mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start, mu.occurred_end, mu.mentioned_at, - mu.fact_type, mu.document_id, mu.chunk_id, mu.tags, + mu.fact_type, mu.document_id, mu.chunk_id, mu.tags, mu.proof_count, ml.weight FROM {ml} ml JOIN {mu} mu ON mu.id = ml.to_unit_id @@ -317,7 +317,7 @@ async def _expand_combined( SELECT mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start, mu.occurred_end, mu.mentioned_at, - mu.fact_type, mu.document_id, mu.chunk_id, mu.tags, + mu.fact_type, mu.document_id, mu.chunk_id, mu.tags, mu.proof_count, ml.weight FROM {ml} ml JOIN {mu} mu ON mu.id = ml.from_unit_id @@ -328,7 +328,7 @@ async def _expand_combined( ) sem_raw GROUP BY id, text, context, event_date, occurred_start, occurred_end, mentioned_at, - fact_type, document_id, chunk_id, tags + fact_type, document_id, chunk_id, tags, proof_count ORDER BY score DESC LIMIT $3 ), @@ -339,7 +339,7 @@ async def _expand_combined( SELECT DISTINCT ON (mu.id) mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start, mu.occurred_end, mu.mentioned_at, - mu.fact_type, mu.document_id, mu.chunk_id, mu.tags, + mu.fact_type, mu.document_id, mu.chunk_id, mu.tags, mu.proof_count, ml.weight AS score, 'causal'::text AS source FROM {ml} ml @@ -429,7 +429,7 @@ async def _expand_observations( SELECT mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start, mu.occurred_end, mu.mentioned_at, - mu.fact_type, mu.document_id, mu.chunk_id, mu.tags, + mu.fact_type, mu.document_id, mu.chunk_id, mu.tags, mu.proof_count, (SELECT COUNT(DISTINCT s) FROM unnest(mu.source_memory_ids) s WHERE s = ANY(ca.source_ids))::float AS score FROM {fq_table("memory_units")} mu, connected_array ca WHERE mu.fact_type = 'observation' @@ -453,13 +453,13 @@ async def _expand_observations( SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, - fact_type, document_id, chunk_id, tags, + fact_type, document_id, chunk_id, tags, proof_count, MAX(weight) AS score, 'semantic'::text AS source FROM ( SELECT mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start, mu.occurred_end, mu.mentioned_at, mu.fact_type, mu.document_id, - mu.chunk_id, mu.tags, ml.weight + mu.chunk_id, mu.tags, mu.proof_count, ml.weight FROM {ml} ml JOIN {mu} mu ON mu.id = ml.to_unit_id WHERE ml.from_unit_id = ANY($1::uuid[]) AND ml.link_type = 'semantic' AND mu.fact_type = 'observation' @@ -467,21 +467,21 @@ async def _expand_observations( UNION ALL SELECT mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start, mu.occurred_end, 
mu.mentioned_at, mu.fact_type, mu.document_id, - mu.chunk_id, mu.tags, ml.weight + mu.chunk_id, mu.tags, mu.proof_count, ml.weight FROM {ml} ml JOIN {mu} mu ON mu.id = ml.from_unit_id WHERE ml.to_unit_id = ANY($1::uuid[]) AND ml.link_type = 'semantic' AND mu.fact_type = 'observation' AND mu.id != ALL($1::uuid[]) ) sem_raw GROUP BY id, text, context, event_date, occurred_start, occurred_end, - mentioned_at, fact_type, document_id, chunk_id, tags + mentioned_at, fact_type, document_id, chunk_id, tags, proof_count ORDER BY score DESC LIMIT $2 ), causal_expanded AS ( SELECT DISTINCT ON (mu.id) mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start, mu.occurred_end, mu.mentioned_at, mu.fact_type, mu.document_id, - mu.chunk_id, mu.tags, ml.weight AS score, 'causal'::text AS source + mu.chunk_id, mu.tags, mu.proof_count, ml.weight AS score, 'causal'::text AS source FROM {ml} ml JOIN {mu} mu ON ml.to_unit_id = mu.id WHERE ml.from_unit_id = ANY($1::uuid[]) AND ml.link_type IN ('causes', 'caused_by', 'enables', 'prevents') diff --git a/hindsight-api-slim/hindsight_api/engine/search/reranking.py b/hindsight-api-slim/hindsight_api/engine/search/reranking.py index 7d0c37cfe..ae9e0fc93 100644 --- a/hindsight-api-slim/hindsight_api/engine/search/reranking.py +++ b/hindsight-api-slim/hindsight_api/engine/search/reranking.py @@ -2,6 +2,7 @@ Cross-encoder neural reranking for search results. """ +import math from datetime import datetime, timezone from .types import MergedCandidate, ScoredResult @@ -13,6 +14,7 @@ # so the max combined boost is (1 + alpha/2)^2 ≈ +21% and min is (1 - alpha/2)^2 ≈ -19%. 
_RECENCY_ALPHA: float = 0.2 _TEMPORAL_ALPHA: float = 0.2 +_PROOF_COUNT_ALPHA: float = 0.1 # Conservative: max ±5% for evidence strength def apply_combined_scoring( @@ -20,28 +22,40 @@ def apply_combined_scoring( now: datetime, recency_alpha: float = _RECENCY_ALPHA, temporal_alpha: float = _TEMPORAL_ALPHA, + proof_count_alpha: float = _PROOF_COUNT_ALPHA, ) -> None: """Apply combined scoring to a list of ScoredResults in-place. - Uses the cross-encoder score as the primary relevance signal, with recency - and temporal proximity applied as multiplicative boosts. This ensures the - influence of these secondary signals is always proportional to the base - relevance score, regardless of the cross-encoder model's score calibration. + Uses the cross-encoder score as the primary relevance signal, with recency, + temporal proximity, and proof count applied as multiplicative boosts. This + ensures the influence of these secondary signals is always proportional to + the base relevance score, regardless of the cross-encoder model's score + calibration. Formula:: - recency_boost = 1 + recency_alpha * (recency - 0.5) # in [1-α/2, 1+α/2] - temporal_boost = 1 + temporal_alpha * (temporal - 0.5) # in [1-α/2, 1+α/2] - combined_score = cross_encoder_score_normalized * recency_boost * temporal_boost + recency_boost = 1 + recency_alpha * (recency - 0.5) # in [1-α/2, 1+α/2] + temporal_boost = 1 + temporal_alpha * (temporal - 0.5) # in [1-α/2, 1+α/2] + proof_count_boost = 1 + proof_count_alpha * (proof_norm - 0.5) # in [1-α/2, 1+α/2] + combined_score = CE_normalized * recency_boost * temporal_boost * proof_count_boost + + proof_norm maps proof_count using a smooth logarithmic curve centered at 0.5, + clamped to [0, 1]: + proof_count=1 → 0.5 + 0 = 0.5 (neutral multiplier) + proof_count=150 → clamped to 1.0 (max +5% boost) Temporal proximity is treated as neutral (0.5) when not set by temporal retrieval, so temporal_boost collapses to 1.0 for non-temporal queries. 
+ Proof count is treated as neutral (0.5) when not available (non-observation facts), + so proof_count_boost collapses to 1.0 for world/experience/opinion facts. + Args: scored_results: Results from the cross-encoder reranker. Mutated in place. now: Current UTC datetime for recency calculation. recency_alpha: Max relative recency adjustment (default 0.2 → ±10%). temporal_alpha: Max relative temporal adjustment (default 0.2 → ±10%). + proof_count_alpha: Max relative proof count adjustment (default 0.1 → ±5%). """ if now.tzinfo is None: now = now.replace(tzinfo=UTC) @@ -59,13 +73,23 @@ def apply_combined_scoring( # Temporal proximity: meaningful only for temporal queries; neutral otherwise. sr.temporal = sr.retrieval.temporal_proximity if sr.retrieval.temporal_proximity is not None else 0.5 + # Proof count: log-normalized evidence strength; neutral for non-observations. + proof_count = sr.retrieval.proof_count + if proof_count is not None and proof_count >= 1: + # Clamp to [0, 1] so extreme counts stay within documented ±5% range + proof_norm = min(1.0, max(0.0, 0.5 + (math.log(proof_count) / 10.0))) + else: + # Neutral baseline is precisely 0.5, ensuring neutral multiplier (1.0) + proof_norm = 0.5 + # RRF: kept at 0.0 for trace continuity but excluded from scoring. # RRF is batch-relative (min-max normalised) and redundant after reranking. 
sr.rrf_normalized = 0.0 recency_boost = 1.0 + recency_alpha * (sr.recency - 0.5) temporal_boost = 1.0 + temporal_alpha * (sr.temporal - 0.5) - sr.combined_score = sr.cross_encoder_score_normalized * recency_boost * temporal_boost + proof_count_boost = 1.0 + proof_count_alpha * (proof_norm - 0.5) + sr.combined_score = sr.cross_encoder_score_normalized * recency_boost * temporal_boost * proof_count_boost sr.weight = sr.combined_score diff --git a/hindsight-api-slim/hindsight_api/engine/search/retrieval.py b/hindsight-api-slim/hindsight_api/engine/search/retrieval.py index 78e6bed7a..f047bc4da 100644 --- a/hindsight-api-slim/hindsight_api/engine/search/retrieval.py +++ b/hindsight-api-slim/hindsight_api/engine/search/retrieval.py @@ -141,7 +141,7 @@ async def retrieve_semantic_bm25_combined( cols = ( "id, text, context, event_date, occurred_start, occurred_end, mentioned_at, " - "fact_type, document_id, chunk_id, tags, metadata" + "fact_type, document_id, chunk_id, tags, metadata, proof_count" ) table = fq_table("memory_units") @@ -336,7 +336,7 @@ async def retrieve_temporal_combined( {groups_clause} ), sim_ranked AS ( - SELECT mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start, mu.occurred_end, mu.mentioned_at, mu.fact_type, mu.document_id, mu.chunk_id, mu.tags, mu.metadata, + SELECT mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start, mu.occurred_end, mu.mentioned_at, mu.fact_type, mu.proof_count, mu.document_id, mu.chunk_id, mu.tags, mu.metadata, 1 - (mu.embedding <=> $1::vector) AS similarity, ROW_NUMBER() OVER (PARTITION BY mu.fact_type ORDER BY mu.embedding <=> $1::vector) AS sim_rn FROM date_ranked dr WHERE dr.rn <= 50 AND (1 - (mu.embedding <=> $1::vector)) >= $6 ) - SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, fact_type, document_id, chunk_id, tags, metadata, similarity + SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, fact_type, proof_count, document_id, chunk_id, tags, metadata, similarity FROM sim_ranked WHERE sim_rn <= 10 """, diff --git a/hindsight-api-slim/hindsight_api/engine/search/types.py b/hindsight-api-slim/hindsight_api/engine/search/types.py index f49741ee3..9b464fc90 100644 --- a/hindsight-api-slim/hindsight_api/engine/search/types.py +++ b/hindsight-api-slim/hindsight_api/engine/search/types.py @@ -48,6 +48,7 @@ class RetrievalResult: chunk_id: str | None = None tags: list[str] | None = None # Visibility scope tags metadata: dict[str, str] | None = None # User-provided metadata + proof_count: int | None = None # Number of supporting memories (observations only) # Retrieval-specific scores (only one will be set depending on retrieval method) similarity: float | None = None # Semantic retrieval @@ -72,6 +73,7 @@ def from_db_row(cls, row: dict[str, Any]) -> "RetrievalResult": chunk_id=row.get("chunk_id"), tags=row.get("tags"), metadata=row.get("metadata"), + proof_count=row.get("proof_count"), similarity=row.get("similarity"), bm25_score=row.get("bm25_score"), activation=row.get("activation"), diff --git a/hindsight-api-slim/tests/test_reranking_proof_count.py b/hindsight-api-slim/tests/test_reranking_proof_count.py new file mode 100644 index 000000000..3fe3e341f --- /dev/null +++ b/hindsight-api-slim/tests/test_reranking_proof_count.py @@ -0,0 +1,96 @@ +""" +Unit tests for proof_count boost in reranking. 
+""" + +from datetime import datetime, timezone +import pytest +from uuid import uuid4 + +from hindsight_api.engine.search.types import RetrievalResult, MergedCandidate, ScoredResult +from hindsight_api.engine.search.reranking import apply_combined_scoring + +UTC = timezone.utc + +def create_mock_scored_result(proof_count: int | None = None, ce_score: float = 0.8) -> ScoredResult: + """Helper to create a minimal ScoredResult suitable for scoring tests.""" + retrieval = RetrievalResult( + id=uuid4(), + text="Test mock fact", + fact_type="observation" if proof_count is not None else "world", + document_id=uuid4(), + chunk_id=uuid4(), + embedding=[0.1]*384, + similarity=0.9, + proof_count=proof_count, + # Default neutral dates for testing so only proof_count changes score + occurred_start=datetime.now(UTC), + occurred_end=datetime.now(UTC) + ) + candidate = MergedCandidate( + id=retrieval.id, + retrieval=retrieval, + semantic_rank=1, + bm25_rank=1, + rrf_score=0.1 + ) + return ScoredResult( + candidate=candidate, + cross_encoder_score=ce_score, + cross_encoder_score_normalized=ce_score, + weight=ce_score, + ) + +def test_proof_count_neutral_when_none(): + """Test that when proof_count is None (e.g. non-observation), it gets neutral 0.5 norm.""" + sr = create_mock_scored_result(proof_count=None, ce_score=0.8) + now = datetime.now(UTC) + + apply_combined_scoring([sr], now, proof_count_alpha=0.1) + + # Neutral multiplier means score shouldn't be boosted by proof_count + # Since recency is neutral (just created) and temporal is neutral, score should remain unchanged + assert sr.combined_score == pytest.approx(0.8, rel=1e-3) + +def test_proof_count_neutral_at_one(): + """Test that proof_count=1 gives neutral multiplier.""" + sr = create_mock_scored_result(proof_count=1, ce_score=0.8) + now = datetime.now(UTC) + + apply_combined_scoring([sr], now, proof_count_alpha=0.1) + + # proof_count=1 -> math.log(1) = 0 -> 0.5 + 0/10 = 0.5 (neutral) -> multiplier 1.0 + assert sr.combined_score == pytest.approx(0.8, rel=1e-3) + +def test_proof_count_increases_with_higher_counts(): + """Test that higher proof counts yield strictly higher scores.""" + now = datetime.now(UTC) + + # Create results with increasing proof counts + sr_5 = create_mock_scored_result(proof_count=5, ce_score=0.8) + sr_50 = create_mock_scored_result(proof_count=50, ce_score=0.8) + sr_100 = create_mock_scored_result(proof_count=100, ce_score=0.8) + + # Process them + apply_combined_scoring([sr_5, sr_50, sr_100], now, proof_count_alpha=0.1) + + # Assure scores strictly increase + assert sr_5.combined_score > 0.8 + assert sr_50.combined_score > sr_5.combined_score + assert sr_100.combined_score > sr_50.combined_score + +def test_proof_count_no_hardcoded_cap_at_100(): + """Test that proof_count continues to scale within the clamped [0, 1] range.""" + now = datetime.now(UTC) + + # Use values that stay below the clamp ceiling (proof_norm < 1.0) + # log(5)/10=0.16, log(20)/10=0.30, log(100)/10=0.46 → all below 0.5 headroom + sr_5 = create_mock_scored_result(proof_count=5, ce_score=0.8) + sr_20 = create_mock_scored_result(proof_count=20, ce_score=0.8) + sr_100 = create_mock_scored_result(proof_count=100, ce_score=0.8) + + apply_combined_scoring([sr_5, sr_20, sr_100], now, proof_count_alpha=0.1) + + # Must strictly increase within the valid range + assert sr_20.combined_score > sr_5.combined_score + assert sr_100.combined_score > sr_20.combined_score +