Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ async def _find_semantic_seeds(
rows = await conn.fetch(
f"""
SELECT id, text, context, event_date, occurred_start, occurred_end,
mentioned_at, fact_type, document_id, chunk_id, tags,
mentioned_at, fact_type, document_id, chunk_id, tags, proof_count,
1 - (embedding <=> $1::vector) AS similarity
FROM {fq_table("memory_units")}
WHERE bank_id = $2
Expand Down Expand Up @@ -274,7 +274,7 @@ async def _expand_combined(
-- Score = COUNT(DISTINCT shared entities), mapped to [0,1] via tanh.
SELECT mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start,
mu.occurred_end, mu.mentioned_at,
mu.fact_type, mu.document_id, mu.chunk_id, mu.tags,
mu.fact_type, mu.document_id, mu.chunk_id, mu.tags, mu.proof_count,
COUNT(DISTINCT ue_seed.entity_id)::float AS score,
'entity'::text AS source
FROM {ue} ue_seed
Expand All @@ -298,14 +298,14 @@ async def _expand_combined(
SELECT
id, text, context, event_date, occurred_start,
occurred_end, mentioned_at,
fact_type, document_id, chunk_id, tags,
fact_type, document_id, chunk_id, tags, proof_count,
MAX(weight) AS score,
'semantic'::text AS source
FROM (
SELECT
mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start,
mu.occurred_end, mu.mentioned_at,
mu.fact_type, mu.document_id, mu.chunk_id, mu.tags,
mu.fact_type, mu.document_id, mu.chunk_id, mu.tags, mu.proof_count,
ml.weight
FROM {ml} ml
JOIN {mu} mu ON mu.id = ml.to_unit_id
Expand All @@ -317,7 +317,7 @@ async def _expand_combined(
SELECT
mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start,
mu.occurred_end, mu.mentioned_at,
mu.fact_type, mu.document_id, mu.chunk_id, mu.tags,
mu.fact_type, mu.document_id, mu.chunk_id, mu.tags, mu.proof_count,
ml.weight
FROM {ml} ml
JOIN {mu} mu ON mu.id = ml.from_unit_id
Expand All @@ -328,7 +328,7 @@ async def _expand_combined(
) sem_raw
GROUP BY id, text, context, event_date, occurred_start,
occurred_end, mentioned_at,
fact_type, document_id, chunk_id, tags
fact_type, document_id, chunk_id, tags, proof_count
ORDER BY score DESC
LIMIT $3
),
Expand All @@ -339,7 +339,7 @@ async def _expand_combined(
SELECT DISTINCT ON (mu.id)
mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start,
mu.occurred_end, mu.mentioned_at,
mu.fact_type, mu.document_id, mu.chunk_id, mu.tags,
mu.fact_type, mu.document_id, mu.chunk_id, mu.tags, mu.proof_count,
ml.weight AS score,
'causal'::text AS source
FROM {ml} ml
Expand Down Expand Up @@ -429,7 +429,7 @@ async def _expand_observations(
SELECT
mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start,
mu.occurred_end, mu.mentioned_at,
mu.fact_type, mu.document_id, mu.chunk_id, mu.tags,
mu.fact_type, mu.document_id, mu.chunk_id, mu.tags, mu.proof_count,
(SELECT COUNT(DISTINCT s) FROM unnest(mu.source_memory_ids) s WHERE s = ANY(ca.source_ids))::float AS score
FROM {fq_table("memory_units")} mu, connected_array ca
WHERE mu.fact_type = 'observation'
Expand All @@ -453,35 +453,35 @@ async def _expand_observations(
SELECT
id, text, context, event_date, occurred_start,
occurred_end, mentioned_at,
fact_type, document_id, chunk_id, tags,
fact_type, document_id, chunk_id, tags, proof_count,
MAX(weight) AS score,
'semantic'::text AS source
FROM (
SELECT mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start,
mu.occurred_end, mu.mentioned_at, mu.fact_type, mu.document_id,
mu.chunk_id, mu.tags, ml.weight
mu.chunk_id, mu.tags, mu.proof_count, ml.weight
FROM {ml} ml JOIN {mu} mu ON mu.id = ml.to_unit_id
WHERE ml.from_unit_id = ANY($1::uuid[])
AND ml.link_type = 'semantic' AND mu.fact_type = 'observation'
AND mu.id != ALL($1::uuid[])
UNION ALL
SELECT mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start,
mu.occurred_end, mu.mentioned_at, mu.fact_type, mu.document_id,
mu.chunk_id, mu.tags, ml.weight
mu.chunk_id, mu.tags, mu.proof_count, ml.weight
FROM {ml} ml JOIN {mu} mu ON mu.id = ml.from_unit_id
WHERE ml.to_unit_id = ANY($1::uuid[])
AND ml.link_type = 'semantic' AND mu.fact_type = 'observation'
AND mu.id != ALL($1::uuid[])
) sem_raw
GROUP BY id, text, context, event_date, occurred_start, occurred_end,
mentioned_at, fact_type, document_id, chunk_id, tags
mentioned_at, fact_type, document_id, chunk_id, tags, proof_count
ORDER BY score DESC LIMIT $2
),
causal_expanded AS (
SELECT DISTINCT ON (mu.id)
mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start,
mu.occurred_end, mu.mentioned_at, mu.fact_type, mu.document_id,
mu.chunk_id, mu.tags, ml.weight AS score, 'causal'::text AS source
mu.chunk_id, mu.tags, mu.proof_count, ml.weight AS score, 'causal'::text AS source
FROM {ml} ml JOIN {mu} mu ON ml.to_unit_id = mu.id
WHERE ml.from_unit_id = ANY($1::uuid[])
AND ml.link_type IN ('causes', 'caused_by', 'enables', 'prevents')
Expand Down
40 changes: 32 additions & 8 deletions hindsight-api-slim/hindsight_api/engine/search/reranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Cross-encoder neural reranking for search results.
"""

import math
from datetime import datetime, timezone

from .types import MergedCandidate, ScoredResult
Expand All @@ -13,35 +14,48 @@
# so the max combined boost is (1 + alpha/2)^2 ≈ +21% and min is (1 - alpha/2)^2 ≈ -19%.
_RECENCY_ALPHA: float = 0.2
_TEMPORAL_ALPHA: float = 0.2
_PROOF_COUNT_ALPHA: float = 0.1 # Conservative: max ±5% for evidence strength


def apply_combined_scoring(
scored_results: list[ScoredResult],
now: datetime,
recency_alpha: float = _RECENCY_ALPHA,
temporal_alpha: float = _TEMPORAL_ALPHA,
proof_count_alpha: float = _PROOF_COUNT_ALPHA,
) -> None:
"""Apply combined scoring to a list of ScoredResults in-place.

Uses the cross-encoder score as the primary relevance signal, with recency
and temporal proximity applied as multiplicative boosts. This ensures the
influence of these secondary signals is always proportional to the base
relevance score, regardless of the cross-encoder model's score calibration.
Uses the cross-encoder score as the primary relevance signal, with recency,
temporal proximity, and proof count applied as multiplicative boosts. This
ensures the influence of these secondary signals is always proportional to
the base relevance score, regardless of the cross-encoder model's score
calibration.

Formula::

recency_boost = 1 + recency_alpha * (recency - 0.5) # in [1-α/2, 1+α/2]
temporal_boost = 1 + temporal_alpha * (temporal - 0.5) # in [1-α/2, 1+α/2]
combined_score = cross_encoder_score_normalized * recency_boost * temporal_boost
recency_boost = 1 + recency_alpha * (recency - 0.5) # in [1-α/2, 1+α/2]
temporal_boost = 1 + temporal_alpha * (temporal - 0.5) # in [1-α/2, 1+α/2]
proof_count_boost = 1 + proof_count_alpha * (proof_norm - 0.5) # in [1-α/2, 1+α/2]
combined_score = CE_normalized * recency_boost * temporal_boost * proof_count_boost

proof_norm maps proof_count using a smooth logarithmic curve centered at 0.5,
clamped to [0, 1]:
proof_count=1 → 0.5 + 0 = 0.5 (neutral multiplier)
proof_count=150 → clamped to 1.0 (max +5% boost)

Temporal proximity is treated as neutral (0.5) when not set by temporal retrieval,
so temporal_boost collapses to 1.0 for non-temporal queries.

Proof count is treated as neutral (0.5) when not available (non-observation facts),
so proof_count_boost collapses to 1.0 for world/experience/opinion facts.

Args:
scored_results: Results from the cross-encoder reranker. Mutated in place.
now: Current UTC datetime for recency calculation.
recency_alpha: Max relative recency adjustment (default 0.2 → ±10%).
temporal_alpha: Max relative temporal adjustment (default 0.2 → ±10%).
proof_count_alpha: Max relative proof count adjustment (default 0.1 → ±5%).
"""
if now.tzinfo is None:
now = now.replace(tzinfo=UTC)
Expand All @@ -59,13 +73,23 @@ def apply_combined_scoring(
# Temporal proximity: meaningful only for temporal queries; neutral otherwise.
sr.temporal = sr.retrieval.temporal_proximity if sr.retrieval.temporal_proximity is not None else 0.5

# Proof count: log-normalized evidence strength; neutral for non-observations.
proof_count = sr.retrieval.proof_count
if proof_count is not None and proof_count >= 1:
# Clamp to [0, 1] so extreme counts stay within documented ±5% range
proof_norm = min(1.0, max(0.0, 0.5 + (math.log(proof_count) / 10.0)))
else:
# Neutral baseline is precisely 0.5, ensuring neutral multiplier (1.0)
proof_norm = 0.5

# RRF: kept at 0.0 for trace continuity but excluded from scoring.
# RRF is batch-relative (min-max normalised) and redundant after reranking.
sr.rrf_normalized = 0.0

recency_boost = 1.0 + recency_alpha * (sr.recency - 0.5)
temporal_boost = 1.0 + temporal_alpha * (sr.temporal - 0.5)
sr.combined_score = sr.cross_encoder_score_normalized * recency_boost * temporal_boost
proof_count_boost = 1.0 + proof_count_alpha * (proof_norm - 0.5)
sr.combined_score = sr.cross_encoder_score_normalized * recency_boost * temporal_boost * proof_count_boost
sr.weight = sr.combined_score


Expand Down
6 changes: 3 additions & 3 deletions hindsight-api-slim/hindsight_api/engine/search/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ async def retrieve_semantic_bm25_combined(

cols = (
"id, text, context, event_date, occurred_start, occurred_end, mentioned_at, "
"fact_type, document_id, chunk_id, tags, metadata"
"fact_type, document_id, chunk_id, tags, metadata, proof_count"
)
table = fq_table("memory_units")

Expand Down Expand Up @@ -336,15 +336,15 @@ async def retrieve_temporal_combined(
{groups_clause}
),
sim_ranked AS (
SELECT mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start, mu.occurred_end, mu.mentioned_at, mu.fact_type, mu.document_id, mu.chunk_id, mu.tags, mu.metadata,
SELECT mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start, mu.occurred_end, mu.mentioned_at, mu.fact_type, mu.proof_count, mu.document_id, mu.chunk_id, mu.tags, mu.metadata,
1 - (mu.embedding <=> $1::vector) AS similarity,
ROW_NUMBER() OVER (PARTITION BY mu.fact_type ORDER BY mu.embedding <=> $1::vector) AS sim_rn
FROM date_ranked dr
JOIN {fq_table("memory_units")} mu ON mu.id = dr.id
WHERE dr.rn <= 50
AND (1 - (mu.embedding <=> $1::vector)) >= $6
)
SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, fact_type, document_id, chunk_id, tags, metadata, similarity
SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, fact_type, proof_count, document_id, chunk_id, tags, metadata, similarity
FROM sim_ranked
WHERE sim_rn <= 10
""",
Expand Down
2 changes: 2 additions & 0 deletions hindsight-api-slim/hindsight_api/engine/search/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class RetrievalResult:
chunk_id: str | None = None
tags: list[str] | None = None # Visibility scope tags
metadata: dict[str, str] | None = None # User-provided metadata
proof_count: int | None = None # Number of supporting memories (observations only)

# Retrieval-specific scores (only one will be set depending on retrieval method)
similarity: float | None = None # Semantic retrieval
Expand All @@ -72,6 +73,7 @@ def from_db_row(cls, row: dict[str, Any]) -> "RetrievalResult":
chunk_id=row.get("chunk_id"),
tags=row.get("tags"),
metadata=row.get("metadata"),
proof_count=row.get("proof_count"),
similarity=row.get("similarity"),
bm25_score=row.get("bm25_score"),
activation=row.get("activation"),
Expand Down
96 changes: 96 additions & 0 deletions hindsight-api-slim/tests/test_reranking_proof_count.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""
Unit tests for proof_count boost in reranking.
"""

from datetime import datetime, timezone
import pytest
from uuid import uuid4

from hindsight_api.engine.search.types import RetrievalResult, MergedCandidate, ScoredResult
from hindsight_api.engine.search.reranking import apply_combined_scoring

UTC = timezone.utc

def create_mock_scored_result(proof_count: int | None = None, ce_score: float = 0.8) -> ScoredResult:
    """Build a minimal ScoredResult suitable for combined-scoring tests.

    The fixture is constructed so that only the proof_count signal varies
    between instances: occurred dates are "now" (neutral recency) and no
    temporal proximity is set, so recency/temporal boosts collapse to 1.0.

    Args:
        proof_count: Evidence count; None simulates a non-observation fact
            (fact_type falls back to "world" in that case).
        ce_score: Cross-encoder score used as the base relevance signal.

    Returns:
        A ScoredResult wrapping a MergedCandidate with the given proof_count.
    """
    retrieval = RetrievalResult(
        id=uuid4(),
        text="Test mock fact",
        fact_type="observation" if proof_count is not None else "world",
        document_id=uuid4(),
        # RetrievalResult declares chunk_id as `str | None`, so pass a string
        # rather than a raw UUID object.
        chunk_id=str(uuid4()),
        embedding=[0.1] * 384,
        similarity=0.9,
        proof_count=proof_count,
        # Neutral dates so only proof_count changes the combined score.
        occurred_start=datetime.now(UTC),
        occurred_end=datetime.now(UTC),
    )
    candidate = MergedCandidate(
        id=retrieval.id,
        retrieval=retrieval,
        semantic_rank=1,
        bm25_rank=1,
        rrf_score=0.1,
    )
    return ScoredResult(
        candidate=candidate,
        cross_encoder_score=ce_score,
        cross_encoder_score_normalized=ce_score,
        weight=ce_score,
    )

def test_proof_count_neutral_when_none():
    """A missing proof_count (non-observation fact) must not shift the score."""
    result = create_mock_scored_result(proof_count=None, ce_score=0.8)

    apply_combined_scoring([result], datetime.now(UTC), proof_count_alpha=0.1)

    # Recency is neutral for a just-created fact, temporal proximity is unset,
    # and proof_norm defaults to 0.5 — so every multiplicative boost is 1.0
    # and the combined score equals the raw cross-encoder score.
    assert result.combined_score == pytest.approx(0.8, rel=1e-3)

def test_proof_count_neutral_at_one():
    """A proof_count of exactly 1 must yield a neutral (1.0) multiplier."""
    result = create_mock_scored_result(proof_count=1, ce_score=0.8)

    apply_combined_scoring([result], datetime.now(UTC), proof_count_alpha=0.1)

    # log(1) == 0, so proof_norm = 0.5 + 0/10 = 0.5 → boost of exactly 1.0,
    # leaving the cross-encoder score unchanged.
    assert result.combined_score == pytest.approx(0.8, rel=1e-3)

def test_proof_count_increases_with_higher_counts():
    """Combined scores must grow strictly as the proof count grows."""
    now = datetime.now(UTC)

    # Same base CE score across the board; only proof_count differs.
    results = [create_mock_scored_result(proof_count=c, ce_score=0.8) for c in (5, 50, 100)]

    apply_combined_scoring(results, now, proof_count_alpha=0.1)

    low, mid, high = results
    # Anything above proof_count=1 should boost past the raw CE score,
    # and each larger count should strictly outrank the smaller one.
    assert low.combined_score > 0.8
    assert mid.combined_score > low.combined_score
    assert high.combined_score > mid.combined_score

def test_proof_count_no_hardcoded_cap_at_100():
    """Counts below the clamp ceiling keep scaling — there is no cap at 100."""
    now = datetime.now(UTC)

    # ln(5)/10 ≈ 0.16, ln(20)/10 ≈ 0.30, ln(100)/10 ≈ 0.46 — all under the
    # 0.5 headroom, so proof_norm stays inside (0.5, 1.0) without clamping.
    by_count = {c: create_mock_scored_result(proof_count=c, ce_score=0.8) for c in (5, 20, 100)}

    apply_combined_scoring(list(by_count.values()), now, proof_count_alpha=0.1)

    # Within the unclamped region the ordering must be strictly monotonic.
    assert by_count[20].combined_score > by_count[5].combined_score
    assert by_count[100].combined_score > by_count[20].combined_score