From 6ab1cd858bce9ba4a587ef7837aeefb7949ec2bd Mon Sep 17 00:00:00 2001 From: Anxhela Coba Date: Mon, 27 Apr 2026 13:37:53 -0400 Subject: [PATCH 1/3] Add vector comparison support --- pyproject.toml | 4 + src/configuration.py | 8 ++ src/constants.py | 2 + src/models/config.py | 105 +++++++++++++++ src/utils/vector_search.py | 264 +++++++++++++++++++++++++++++++++---- 5 files changed, 357 insertions(+), 26 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b16398e95..80ed3a4e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,11 @@ dependencies = [ "jinja2>=3.1.0", # To be able to fix multiple CVEs, also LCORE-1117 "requests>=2.33.0", + # Used for RAG chunk reranking (cross-encoder) + "sentence-transformers>=2.0.0", "datasets>=4.7.0", + # Used for RAG chunk reranking (cross-encoder) + "sentence-transformers>=2.0.0", # Used for error tracking and monitoring "sentry-sdk[fastapi]>=2.58.0", ] diff --git a/src/configuration.py b/src/configuration.py index 4eeca0460..c6a376327 100644 --- a/src/configuration.py +++ b/src/configuration.py @@ -27,6 +27,7 @@ OkpConfiguration, QuotaHandlersConfiguration, RagConfiguration, + RerankerConfiguration, RlsapiV1Configuration, ServiceConfiguration, SplunkConfiguration, @@ -465,6 +466,13 @@ def okp(self) -> "OkpConfiguration": raise LogicError("logic error: configuration is not loaded") return self._configuration.okp + @property + def reranker(self) -> "RerankerConfiguration": + """Return reranker configuration.""" + if self._configuration is None: + raise LogicError("logic error: configuration is not loaded") + return self._configuration.reranker + @property def rag_id_mapping(self) -> dict[str, str]: """Return mapping from vector_db_id to rag_id from BYOK and OKP RAG config. diff --git a/src/constants.py b/src/constants.py index 79bbbf0e5..ed4d79b22 100644 --- a/src/constants.py +++ b/src/constants.py @@ -192,6 +192,8 @@ # Inline RAG constants BYOK_RAG_MAX_CHUNKS: Final[int] = 10 # retrieved from BYOK RAG OKP_RAG_MAX_CHUNKS: Final[int] = 5 # retrieved from OKP RAG +# Score multiplier applied to BYOK chunks after cross-encoder reranking (Solr chunks unchanged) +BYOK_RAG_RERANK_BOOST = 1.2 # Solr OKP constants SOLR_VECTOR_SEARCH_DEFAULT_K: Final[int] = 5 diff --git a/src/models/config.py b/src/models/config.py index 2e7397bca..e540aea24 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -1812,6 +1812,63 @@ class OkpConfiguration(ConfigurationBase): ) +class RerankerConfiguration(ConfigurationBase): + """Reranker configuration for RAG chunk reranking.""" + + enabled: bool = True + model: str = Field( + default="cross-encoder/ms-marco-MiniLM-L6-v2", + title="Reranker model", + description="Cross-encoder model name for reranking RAG chunks. " + "Defaults to 'cross-encoder/ms-marco-MiniLM-L6-v2' from sentence-transformers.", + ) + model_id: str = "meta-llama/llama-cross-encoder-base" + provider_id: str = "meta-reference" + top_k_multiplier: float = 2.0 # fetch 2x, rerank, keep top_k + byok_boost: float = 1.2 + okp_boost: float = 1.0 + + # Private attribute to track if this was explicitly configured + _explicitly_configured: bool = PrivateAttr(default=False) + + @model_validator(mode="after") + def mark_as_explicitly_configured(self) -> Self: + """Mark this configuration as explicitly set when instantiated from user input.""" + # Only mark as explicitly configured if we're not using all default values + # This allows auto-enabling when user hasn't touched reranker settings + # Check if any field differs from default values + default_model = "cross-encoder/ms-marco-MiniLM-L6-v2" + default_model_id = "meta-llama/llama-cross-encoder-base" + default_provider_id = "meta-reference" + default_top_k_multiplier = 2.0 + default_byok_boost = 1.2 + default_okp_boost = 1.0 + + # Check if any setting differs from defaults (indicates explicit configuration) + current_values = [ + self.enabled, + self.model, + self.model_id, + self.provider_id, + self.top_k_multiplier, + self.byok_boost, + self.okp_boost, + ] + default_values = [ + True, + default_model, + default_model_id, + default_provider_id, + default_top_k_multiplier, + default_byok_boost, + default_okp_boost, + ] + + if current_values != default_values: + self._explicitly_configured = True + return self + + class AzureEntraIdConfiguration(ConfigurationBase): """Microsoft Entra ID authentication attributes for Azure.""" @@ -1970,6 +2027,12 @@ class Configuration(ConfigurationBase): "in rag.inline or rag.tool.", ) + reranker: RerankerConfiguration = Field( + default_factory=RerankerConfiguration, + title="Reranker configuration", + description="Configuration for neural reranking of RAG chunks using cross-encoder.", + ) + @model_validator(mode="after") def validate_mcp_auth_headers(self) -> Self: """ @@ -2072,6 +2135,48 @@ def validate_rlsapi_v1_quota_configuration(self) -> Self: return self + @model_validator(mode="after") + def validate_reranker_auto_enable(self) -> Self: + """Automatically enable reranker when both BYOK and OKP RAG are configured. + + When users have both BYOK (Bring Your Own Key) entries in byok_rag and OKP + (OpenShift Knowledge Platform) configured in the RAG strategies, automatically + enable the reranker if it's not explicitly disabled. This improves result + quality when multiple knowledge sources are available. + + Returns: + Self: The validated configuration instance with reranker potentially enabled. + """ + # Check if BYOK RAG entries are configured + has_byok = len(self.byok_rag) > 0 + + # Check if OKP is configured in either inline or tool RAG strategies + # pylint: disable=no-member + has_okp = ( + constants.OKP_RAG_ID in self.rag.inline + or constants.OKP_RAG_ID in self.rag.tool + ) + + # If both BYOK and OKP are present and reranker is using default settings, + # ensure it's enabled for optimal results + if ( + has_byok + and has_okp + and not hasattr(self.reranker, "_explicitly_configured") + ): + # pylint: disable=no-member + if not self.reranker.enabled: + logger.info( + "Automatically enabling reranker: Both BYOK RAG (%d entries) and OKP " + "are configured. Reranking improves result quality when multiple " + "knowledge sources are available.", + len(self.byok_rag), + ) + # pylint: disable=no-member + self.reranker.enabled = True + + return self + def dump(self, filename: str | Path = "configuration.json") -> None: """ Write the current Configuration model to a JSON file. diff --git a/src/utils/vector_search.py b/src/utils/vector_search.py index cfbeea3c5..e5363741e 100644 --- a/src/utils/vector_search.py +++ b/src/utils/vector_search.py @@ -25,6 +25,185 @@ logger = get_logger(__name__) +# Lazy-loaded cross-encoder models for reranking RAG chunks (CPU-bound, use in thread). +# Cache models by name to avoid reloading the same model multiple times. +# Not a constant; pylint invalid-name is disabled for this module-level singleton. +_cross_encoder_models: dict[str, Any] = {} # pylint: disable=invalid-name + + +def _get_cross_encoder(model_name: str) -> Any: + """Return the lazy-loaded cross-encoder model for reranking. + + Args: + model_name: Name of the cross-encoder model to load. + + Returns: + Loaded CrossEncoder model instance, or None if loading fails. + """ + if model_name not in _cross_encoder_models: + try: + from sentence_transformers import ( # pylint: disable=import-outside-toplevel + CrossEncoder, + ) + + _cross_encoder_models[model_name] = CrossEncoder(model_name) + logger.info("Loaded cross-encoder for RAG reranking: %s", model_name) + except Exception as e: # pylint: disable=broad-exception-caught + logger.warning( + "Could not load cross-encoder for reranking (%s): %s", model_name, e + ) + _cross_encoder_models[model_name] = None + return _cross_encoder_models[model_name] + + +async def _rerank_chunks_with_cross_encoder( + query: str, + chunks: list[RAGChunk], + top_k: int, + model_name: str = "cross-encoder/ms-marco-MiniLM-L6-v2", +) -> list[RAGChunk]: + """Rerank chunks using configurable cross-encoder model. + + Args: + query: The search query + chunks: RAG chunks to rerank + top_k: Number of top chunks to return + model_name: Cross-encoder model name to use + + Returns: + Top top_k chunks sorted by cross-encoder score (descending) + """ + if not chunks: + return [] + + try: + # Get the cached cross-encoder model + model = _get_cross_encoder(model_name) + if model is None: + raise RuntimeError(f"Failed to load cross-encoder model: {model_name}") + + logger.debug("Using cross-encoder model: %s", model_name) + + # Create query-chunk pairs for scoring + pairs = [(query, chunk.content) for chunk in chunks] + scores = model.predict(pairs) + + if hasattr(scores, "tolist"): + scores = scores.tolist() + + # Normalize scores to [0,1] range using min-max normalization + if len(scores) > 1: + min_score = min(scores) + max_score = max(scores) + score_range = max_score - min_score + if score_range > 0: + normalized_scores = [(score - min_score) / score_range for score in scores] + else: + # All scores are identical, assign 0.5 to all + normalized_scores = [0.5] * len(scores) + else: + # Single score, assign 1.0 + normalized_scores = [1.0] * len(scores) + + # Combine normalized scores with chunks and sort by score (descending) + indexed = list(zip(normalized_scores, chunks, strict=True)) + indexed.sort(key=lambda x: x[0], reverse=True) + top_indexed = indexed[:top_k] + + # Return RAGChunk list with normalized cross-encoder scores [0,1] + return [ + RAGChunk( + content=chunk.content, + source=chunk.source, + score=float(score), + attributes=chunk.attributes, + ) + for score, chunk in top_indexed + ] + + except Exception as e: # pylint: disable=broad-exception-caught + logger.warning( + "Cross-encoder reranking failed, falling back to original scoring: %s", e + ) + # Fallback: sort by original score and take top_k + sorted_chunks = sorted( + chunks, + key=lambda c: c.score if c.score is not None else float("-inf"), + reverse=True, + ) + return sorted_chunks[:top_k] + + +def _apply_byok_rerank_boost( + chunks: list[RAGChunk], boost: float = constants.BYOK_RAG_RERANK_BOOST +) -> list[RAGChunk]: + """Apply a score multiplier to BYOK chunks (source != OKP) and re-sort by score. + + Args: + chunks: RAG chunks after reranking (may be from BYOK or Solr). + boost: Multiplier applied to BYOK chunk scores. Solr chunks unchanged. + + Returns: + Same chunks with BYOK scores boosted, sorted by score descending. + """ + boosted = [] + for chunk in chunks: + score = chunk.score if chunk.score is not None else float("-inf") + if chunk.source != constants.OKP_RAG_ID: + score = score * boost + boosted.append( + RAGChunk( + content=chunk.content, + source=chunk.source, + score=score, + attributes=chunk.attributes, + ) + ) + boosted.sort( + key=lambda c: c.score if c.score is not None else float("-inf"), + reverse=True, + ) + return boosted + + +def _referenced_documents_from_rag_chunks( + rag_chunks: list[RAGChunk], +) -> list[ReferencedDocument]: + """Build referenced documents list from RAG chunks (e.g. after reranking). + + Args: + rag_chunks: RAG chunks with source and attributes (doc_url, title, etc.). + + Returns: + Deduplicated list of ReferencedDocument from chunk attributes. + """ + seen: set[str] = set() + result: list[ReferencedDocument] = [] + for chunk in rag_chunks: + attrs = chunk.attributes or {} + doc_url = ( + attrs.get("reference_url") or attrs.get("doc_url") or attrs.get("docs_url") + ) + doc_id = attrs.get("document_id") or attrs.get("doc_id") + dedup_key = doc_url or doc_id or chunk.source or "" + if not dedup_key or dedup_key in seen: + continue + seen.add(dedup_key) + parsed_url: Optional[AnyUrl] = None + if doc_url: + try: + parsed_url = AnyUrl(doc_url) + except Exception: # pylint: disable=broad-exception-caught + parsed_url = None + result.append( + ReferencedDocument( + doc_title=attrs.get("title"), + doc_url=parsed_url, + source=chunk.source, + ) + ) + return result + def _get_okp_base_url() -> AnyUrl: """Return OKP document base URL from configuration (rhokp_url), or default if unset. @@ -55,15 +234,18 @@ def _get_solr_vector_store_ids() -> list[str]: def _build_query_params( + solr: Optional[SolrVectorSearchRequest] = None, + k: Optional[int] = None, ) -> dict[str, Any]: """Build query parameters for Solr vector_io search. Args: solr: Optional structured Solr request (mode and filters from the API). + k: Optional number of results to return. If not provided, uses default. Returns: - Parameter dictionary for ``vector_io.query``. + Query parameters dict for vector_io.query. """ resolved_mode = ( solr.mode @@ -71,7 +253,7 @@ def _build_query_params( else constants.SOLR_VECTOR_SEARCH_DEFAULT_MODE ) params: dict[str, Any] = { - "k": constants.SOLR_VECTOR_SEARCH_DEFAULT_K, + "k": k if k is not None else constants.SOLR_VECTOR_SEARCH_DEFAULT_K, "score_threshold": constants.SOLR_VECTOR_SEARCH_DEFAULT_SCORE_THRESHOLD, "mode": resolved_mode, } @@ -180,6 +362,7 @@ async def _query_store_for_byok_rag( vector_store_id: str, query: str, weight: float, + max_chunks: int = constants.BYOK_RAG_MAX_CHUNKS, ) -> list[dict[str, Any]]: """Query a single vector store for BYOK RAG. @@ -188,6 +371,7 @@ async def _query_store_for_byok_rag( vector_store_id: ID of the vector store to query query: Search query string weight: Score multiplier to apply + max_chunks: Maximum number of chunks to request from this store. Returns: List of weighted result dictionaries, or empty list on error @@ -197,7 +381,7 @@ async def _query_store_for_byok_rag( vector_store_id=vector_store_id, query=query, params={ - "max_chunks": constants.BYOK_RAG_MAX_CHUNKS, + "max_chunks": max_chunks, "mode": "vector", }, ) @@ -344,26 +528,29 @@ def _process_solr_chunks_for_documents( return doc_ids_from_chunks -async def _fetch_byok_rag( +async def _fetch_byok_rag( # pylint: disable=too-many-locals client: AsyncLlamaStackClient, query: str, - vector_store_ids: Optional[list[str]] = None, # User-facing + vector_store_ids: Optional[list[str]] = None, + max_chunks: Optional[int] = None, ) -> tuple[list[RAGChunk], list[ReferencedDocument]]: """Fetch chunks and documents from BYOK RAG sources. Args: client: The AsyncLlamaStackClient to use for the request query: The search query - configuration: Application configuration vector_store_ids: Optional list of vector store IDs to query. If provided, only these stores will be queried. If None, all stores (excluding Solr) will be queried. + max_chunks: Maximum number of chunks to return. If None, uses + constants.BYOK_RAG_MAX_CHUNKS. Returns: Tuple containing: - rag_chunks: RAG chunks from BYOK RAG - referenced_documents: Documents referenced in BYOK RAG results """ + limit = max_chunks if max_chunks is not None else constants.BYOK_RAG_MAX_CHUNKS rag_chunks: list[RAGChunk] = [] referenced_documents: list[ReferencedDocument] = [] @@ -410,6 +597,7 @@ async def _fetch_byok_rag( vector_store_id, query, score_multiplier_mapping.get(vector_store_id, 1.0), + max_chunks=limit, ) for vector_store_id in vector_store_ids_to_query ] @@ -420,7 +608,7 @@ async def _fetch_byok_rag( for store_results in results_per_store: all_results.extend(store_results) all_results.sort(key=lambda x: x["weighted_score"], reverse=True) - top_results = all_results[: constants.BYOK_RAG_MAX_CHUNKS] + top_results = all_results[:limit] # Resolve source, log, and convert to RAGChunk in a single pass logger.info("Filtered top %d chunks from BYOK RAG", len(top_results)) @@ -451,10 +639,11 @@ async def _fetch_byok_rag( return rag_chunks, referenced_documents -async def _fetch_solr_rag( +async def _fetch_solr_rag( # pylint: disable=too-many-locals client: AsyncLlamaStackClient, query: str, solr: Optional[SolrVectorSearchRequest] = None, + max_chunks: Optional[int] = None, ) -> tuple[list[RAGChunk], list[ReferencedDocument]]: """Fetch chunks and documents from Solr RAG source. @@ -462,6 +651,8 @@ async def _fetch_solr_rag( client: The AsyncLlamaStackClient to use for the request query: The user's query solr: Structured Solr inline RAG request from the API (optional). + max_chunks: Maximum number of chunks to return. If None, uses + constants.OKP_RAG_MAX_CHUNKS. Returns: Tuple containing: @@ -470,6 +661,7 @@ async def _fetch_solr_rag( """ rag_chunks: list[RAGChunk] = [] referenced_documents: list[ReferencedDocument] = [] + limit = max_chunks if max_chunks is not None else constants.OKP_RAG_MAX_CHUNKS if not _is_solr_enabled(): logger.info("OKP vector IO is disabled, skipping OKP search") @@ -502,8 +694,8 @@ async def _fetch_solr_rag( ) # Limit to top N chunks - top_chunks = query_response.chunks[: constants.OKP_RAG_MAX_CHUNKS] - top_scores = retrieved_scores[: constants.OKP_RAG_MAX_CHUNKS] + top_chunks = query_response.chunks[:limit] + top_scores = retrieved_scores[:limit] # Extract referenced documents from Solr chunks referenced_documents = _process_solr_chunks_for_documents( @@ -516,7 +708,7 @@ async def _fetch_solr_rag( ) logger.debug( "Filtered top %d chunks from OKP RAG (%d were retrieved)", - constants.OKP_RAG_MAX_CHUNKS, + limit, len(rag_chunks), ) @@ -527,20 +719,22 @@ async def _fetch_solr_rag( return rag_chunks, referenced_documents -async def build_rag_context( +async def build_rag_context( # pylint: disable=too-many-locals client: AsyncLlamaStackClient, - moderation_decision: str, + moderation_decision: str, # pylint: disable=unused-argument query: str, vector_store_ids: Optional[list[str]], solr: Optional[SolrVectorSearchRequest] = None, ) -> RAGContext: """Build RAG context by fetching and merging chunks from all enabled sources. - Enabled sources can be BYOK and/or Solr OKP. + Fetches 2 * BYOK_RAG_MAX_CHUNKS from each of BYOK and Solr, merges and keeps + top 2 * BYOK_RAG_MAX_CHUNKS by score, reranks with a cross-encoder, then + keeps the top BYOK_RAG_MAX_CHUNKS for context. Enabled sources can be BYOK + and/or Solr OKP. Args: client: The AsyncLlamaStackClient to use for the request - moderation_decision: The moderation decision query: The user's query vector_store_ids: The vector store IDs to query solr: Structured Solr inline RAG request from the API (optional). @@ -548,30 +742,47 @@ async def build_rag_context( Returns: RAGContext containing formatted context text and referenced documents """ - if moderation_decision == "blocked": - return RAGContext() + pool_size = 2 * constants.BYOK_RAG_MAX_CHUNKS + top_k = constants.BYOK_RAG_MAX_CHUNKS - # Fetch from all enabled RAG sources in parallel - byok_chunks_task = _fetch_byok_rag(client, query, vector_store_ids) + # Fetch 2*BYOK_RAG_MAX_CHUNKS from each source in parallel + byok_chunks_task = _fetch_byok_rag( + client, query, vector_store_ids, max_chunks=pool_size + ) solr_chunks_task = _fetch_solr_rag(client, query, solr) - (byok_chunks, byok_docs), (solr_chunks, solr_docs) = await asyncio.gather( + (byok_chunks, _), (solr_chunks, _) = await asyncio.gather( byok_chunks_task, solr_chunks_task ) - # Merge chunks from all sources (BYOK + Solr) - context_chunks = byok_chunks + solr_chunks + # Merge: combine and sort by score, keep top 2*BYOK_RAG_MAX_CHUNKS + merged = byok_chunks + solr_chunks + merged.sort( + key=lambda c: c.score if c.score is not None else float("-inf"), + reverse=True, + ) + merged = merged[:pool_size] + + # Rerank full pool with cross-encoder if enabled; boost BYOK then take top_k + if configuration.reranker.enabled: + reranked = await _rerank_chunks_with_cross_encoder( + query, merged, pool_size, model_name=configuration.reranker.model + ) + context_chunks = _apply_byok_rerank_boost(reranked)[:top_k] + else: + # Skip reranking, just apply boost and take top_k from original scores + context_chunks = _apply_byok_rerank_boost(merged)[:top_k] context_text = _format_rag_context(context_chunks, query) logger.debug( - "Inline RAG context built: %d chunks, %d characters", + "Inline RAG context built: %d chunks (after rerank), %d characters", len(context_chunks), len(context_text), ) - # Merge referenced documents from all sources (BYOK + Solr) - top_documents = byok_docs + solr_docs + # Referenced documents from final chunks only (after reranking) + top_documents = _referenced_documents_from_rag_chunks(context_chunks) return RAGContext( context_text=context_text, @@ -602,7 +813,8 @@ def _build_document_url( Build document URL based on offline flag and available metadata. Args: - offline: Whether to use offline mode (parent_id) or online mode (reference_url) + offline: Whether to use offline + (parent_id) or online mode (reference_url) doc_id: Document ID from chunk metadata reference_url: Reference URL from chunk metadata From d664680d8517e13c031d46e6a0acfa37f3cee4f1 Mon Sep 17 00:00:00 2001 From: Anxhela Coba Date: Mon, 27 Apr 2026 14:33:08 -0400 Subject: [PATCH 2/3] remove llamastack neural reranker and add unit tests Signed-off-by: Anxhela Coba --- src/models/config.py | 8 - .../config/test_reranker_configuration.py | 75 +++ tests/unit/utils/test_vector_search.py | 434 ++++++++++++++++++ 3 files changed, 509 insertions(+), 8 deletions(-) create mode 100644 tests/unit/models/config/test_reranker_configuration.py diff --git a/src/models/config.py b/src/models/config.py index e540aea24..a95b1a5fd 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -1822,8 +1822,6 @@ class RerankerConfiguration(ConfigurationBase): description="Cross-encoder model name for reranking RAG chunks. " "Defaults to 'cross-encoder/ms-marco-MiniLM-L6-v2' from sentence-transformers.", ) - model_id: str = "meta-llama/llama-cross-encoder-base" - provider_id: str = "meta-reference" top_k_multiplier: float = 2.0 # fetch 2x, rerank, keep top_k byok_boost: float = 1.2 okp_boost: float = 1.0 @@ -1838,8 +1836,6 @@ def mark_as_explicitly_configured(self) -> Self: # This allows auto-enabling when user hasn't touched reranker settings # Check if any field differs from default values default_model = "cross-encoder/ms-marco-MiniLM-L6-v2" - default_model_id = "meta-llama/llama-cross-encoder-base" - default_provider_id = "meta-reference" default_top_k_multiplier = 2.0 default_byok_boost = 1.2 default_okp_boost = 1.0 @@ -1848,8 +1844,6 @@ def mark_as_explicitly_configured(self) -> Self: current_values = [ self.enabled, self.model, - self.model_id, - self.provider_id, self.top_k_multiplier, self.byok_boost, self.okp_boost, @@ -1857,8 +1851,6 @@ def mark_as_explicitly_configured(self) -> Self: default_values = [ True, default_model, - default_model_id, - default_provider_id, default_top_k_multiplier, default_byok_boost, default_okp_boost, diff --git a/tests/unit/models/config/test_reranker_configuration.py b/tests/unit/models/config/test_reranker_configuration.py new file mode 100644 index 000000000..ce362e3b2 --- /dev/null +++ b/tests/unit/models/config/test_reranker_configuration.py @@ -0,0 +1,75 @@ +"""Unit tests for RerankerConfiguration model.""" + +import pytest +from pydantic import ValidationError + +from models.config import RerankerConfiguration + + +class TestRerankerConfiguration: + """Tests for RerankerConfiguration model.""" + + def test_default_values(self) -> None: + """Test that RerankerConfiguration has correct default values.""" + config = RerankerConfiguration() + assert config.enabled is True + assert config.model == "cross-encoder/ms-marco-MiniLM-L6-v2" + assert config.top_k_multiplier == 2.0 + assert config.byok_boost == 1.2 + assert config.okp_boost == 1.0 + + def test_custom_model(self) -> None: + """Test configuration with custom cross-encoder model.""" + config = RerankerConfiguration( + model="cross-encoder/ms-marco-TinyBERT-L2-v2" + ) + assert config.model == "cross-encoder/ms-marco-TinyBERT-L2-v2" + assert config.enabled is True + + def test_disabled_reranker(self) -> None: + """Test configuration with reranker disabled.""" + config = RerankerConfiguration(enabled=False) + assert config.enabled is False + assert config.model == "cross-encoder/ms-marco-MiniLM-L6-v2" + + def test_custom_boost_factors(self) -> None: + """Test configuration with custom boost factors.""" + config = RerankerConfiguration( + byok_boost=1.5, + okp_boost=0.8 + ) + assert config.byok_boost == 1.5 + assert config.okp_boost == 0.8 + + def test_custom_top_k_multiplier(self) -> None: + """Test configuration with custom top_k_multiplier.""" + config = RerankerConfiguration(top_k_multiplier=3.0) + assert config.top_k_multiplier == 3.0 + + def test_all_custom_values(self) -> None: + """Test configuration with all custom values.""" + config = RerankerConfiguration( + enabled=False, + model="custom-cross-encoder", + top_k_multiplier=1.5, + byok_boost=2.0, + okp_boost=0.5 + ) + assert config.enabled is False + assert config.model == "custom-cross-encoder" + assert config.top_k_multiplier == 1.5 + assert config.byok_boost == 2.0 + assert config.okp_boost == 0.5 + + def test_explicit_configuration_detection(self) -> None: + """Test that explicitly configured values are detected.""" + # Non-default values should mark as explicitly configured + config = RerankerConfiguration(enabled=False) + assert hasattr(config, "_explicitly_configured") + # Note: The actual _explicitly_configured logic is private + # and tested through integration tests + + def test_invalid_field_rejected(self) -> None: + """Test that invalid fields are rejected due to extra='forbid'.""" + with pytest.raises(ValidationError): + RerankerConfiguration(invalid_field="value") \ No newline at end of file diff --git a/tests/unit/utils/test_vector_search.py b/tests/unit/utils/test_vector_search.py index 2aafab0a7..a78022361 100644 --- a/tests/unit/utils/test_vector_search.py +++ b/tests/unit/utils/test_vector_search.py @@ -1,5 +1,7 @@ """Unit tests for vector search utilities.""" +import sys +import unittest.mock import pytest from pydantic import AnyUrl from pytest_mock import MockerFixture @@ -9,6 +11,7 @@ from models.requests import SolrVectorSearchRequest from utils.types import RAGChunk from utils.vector_search import ( + _apply_byok_rerank_boost, _build_document_url, _build_query_params, _convert_solr_chunks_to_rag_format, @@ -17,9 +20,11 @@ _fetch_byok_rag, _fetch_solr_rag, _format_rag_context, + _get_cross_encoder, _get_okp_base_url, _get_solr_vector_store_ids, _is_solr_enabled, + _rerank_chunks_with_cross_encoder, build_rag_context, ) @@ -719,3 +724,432 @@ async def test_byok_enabled_only(self, mocker: MockerFixture) -> None: assert len(context.rag_chunks) > 0 assert "BYOK content" in context.context_text assert "file_search found" in context.context_text + + @pytest.mark.asyncio + async def test_reranker_enabled_calls_cross_encoder(self, mocker: MockerFixture) -> None: + """Test that cross-encoder is called when reranker is enabled.""" + # Mock configuration with reranker enabled + config_mock = mocker.Mock(spec=AppConfig) + byok_rag_mock = mocker.Mock() + byok_rag_mock.rag_id = "rag_1" + byok_rag_mock.vector_db_id = "vs_1" + config_mock.configuration.rag.inline = ["rag_1"] + config_mock.configuration.byok_rag = [byok_rag_mock] + config_mock.inline_solr_enabled = False + config_mock.score_multiplier_mapping = {"vs_1": 1.0} + config_mock.rag_id_mapping = {"vs_1": "rag_1"} + config_mock.reranker.enabled = True + config_mock.reranker.model = "test-model" + mocker.patch("utils.vector_search.configuration", config_mock) + + # Mock BYOK search response + chunk_mock = mocker.Mock() + chunk_mock.content = "BYOK content" + chunk_mock.chunk_id = "chunk_1" + chunk_mock.metadata = {"document_id": "doc_1"} + + search_response = mocker.Mock() + search_response.chunks = [chunk_mock] + search_response.scores = [0.9] + + client_mock = mocker.AsyncMock() + client_mock.vector_io.query.return_value = search_response + + # Mock cross-encoder reranking function + mock_rerank = mocker.patch("utils.vector_search._rerank_chunks_with_cross_encoder") + mock_rerank.return_value = [ + RAGChunk(content="BYOK content", source="rag_1", score=0.95) + ] + + context = await build_rag_context(client_mock, "passed", "test query", None) + + # Verify cross-encoder was called + mock_rerank.assert_called_once() + assert mock_rerank.call_args[0][0] == "test query" # query parameter + assert mock_rerank.call_args[1]["model_name"] == "test-model" # model_name parameter + + assert len(context.rag_chunks) > 0 + + @pytest.mark.asyncio + async def test_reranker_disabled_skips_cross_encoder(self, mocker: MockerFixture) -> None: + """Test that cross-encoder is skipped when reranker is disabled.""" + # Mock configuration with reranker disabled + config_mock = mocker.Mock(spec=AppConfig) + byok_rag_mock = mocker.Mock() + byok_rag_mock.rag_id = "rag_1" + byok_rag_mock.vector_db_id = "vs_1" + config_mock.configuration.rag.inline = ["rag_1"] + config_mock.configuration.byok_rag = [byok_rag_mock] + config_mock.inline_solr_enabled = False + config_mock.score_multiplier_mapping = {"vs_1": 1.0} + config_mock.rag_id_mapping = {"vs_1": "rag_1"} + config_mock.reranker.enabled = False + mocker.patch("utils.vector_search.configuration", config_mock) + + # Mock BYOK search response + chunk_mock = mocker.Mock() + chunk_mock.content = "BYOK content" + chunk_mock.chunk_id = "chunk_1" + chunk_mock.metadata = {"document_id": "doc_1"} + + search_response = mocker.Mock() + search_response.chunks = [chunk_mock] + search_response.scores = [0.9] + + client_mock = mocker.AsyncMock() + client_mock.vector_io.query.return_value = search_response + + # Mock cross-encoder reranking function + mock_rerank = mocker.patch("utils.vector_search._rerank_chunks_with_cross_encoder") + + context = await build_rag_context(client_mock, "passed", "test query", None) + + # Verify cross-encoder was NOT called + mock_rerank.assert_not_called() + + assert len(context.rag_chunks) > 0 + + +class TestGetCrossEncoder: + """Tests for _get_cross_encoder function.""" + + def test_loads_model_successfully(self, mocker: MockerFixture) -> None: + """Test successful model loading and caching.""" + # Clear the cache for testing + from utils.vector_search import _cross_encoder_models + _cross_encoder_models.clear() + + # Mock the CrossEncoder import and class + mock_cross_encoder_class = mocker.Mock() + mock_model_instance = mocker.Mock() + mock_cross_encoder_class.return_value = mock_model_instance + + # Mock the import of sentence_transformers + mock_sentence_transformers = mocker.Mock() + mock_sentence_transformers.CrossEncoder = mock_cross_encoder_class + + with unittest.mock.patch.dict(sys.modules, {"sentence_transformers": mock_sentence_transformers}): + model = _get_cross_encoder("test-model") + + assert model == mock_model_instance + mock_cross_encoder_class.assert_called_once_with("test-model") + + def test_caches_loaded_model(self, mocker: MockerFixture) -> None: + """Test that models are cached and not reloaded.""" + # Clear the cache for testing + from utils.vector_search import _cross_encoder_models + _cross_encoder_models.clear() + + mock_cross_encoder_class = mocker.Mock() + mock_model_instance = mocker.Mock() + mock_cross_encoder_class.return_value = mock_model_instance + + mock_sentence_transformers = mocker.Mock() + mock_sentence_transformers.CrossEncoder = mock_cross_encoder_class + + with unittest.mock.patch.dict(sys.modules, {"sentence_transformers": mock_sentence_transformers}): + # First call should load the model + model1 = _get_cross_encoder("test-model") + # Second call should return cached model + model2 = _get_cross_encoder("test-model") + + assert model1 == model2 == mock_model_instance + mock_cross_encoder_class.assert_called_once() # Only called once + + def test_handles_import_error(self, mocker: MockerFixture) -> None: + """Test graceful handling of sentence_transformers import error.""" + # Clear the cache for testing + from utils.vector_search import _cross_encoder_models + _cross_encoder_models.clear() + + # Mock the import to fail + mocker.patch("utils.vector_search.logger") # Suppress warning logs + + # Remove the module if it exists and patch import to fail + original_modules = sys.modules.copy() + if "sentence_transformers" in sys.modules: + del sys.modules["sentence_transformers"] + + def mock_import(name, *args): + if name == "sentence_transformers": + raise ImportError("Module not found") + return original_modules.get(name) + + mocker.patch("builtins.__import__", side_effect=mock_import) + model = _get_cross_encoder("test-model") + + # Restore original modules + sys.modules.update(original_modules) + + assert model is None + + def test_handles_model_loading_error(self, mocker: MockerFixture) -> None: + """Test graceful handling of model instantiation error.""" + # Clear the cache for testing + from utils.vector_search import _cross_encoder_models + _cross_encoder_models.clear() + + mock_cross_encoder_class = mocker.Mock() + mock_cross_encoder_class.side_effect = Exception("Model loading failed") + + mock_sentence_transformers = mocker.Mock() + mock_sentence_transformers.CrossEncoder = mock_cross_encoder_class + + with unittest.mock.patch.dict(sys.modules, {"sentence_transformers": mock_sentence_transformers}): + model = _get_cross_encoder("test-model") + + assert model is None + + +class TestRerankChunksWithCrossEncoder: + """Tests for _rerank_chunks_with_cross_encoder function.""" + + @pytest.mark.asyncio + async def test_empty_chunks(self) -> None: + """Test reranking with empty chunks list.""" + result = await _rerank_chunks_with_cross_encoder("test query", [], 5) + assert result == [] + + @pytest.mark.asyncio + async def test_successful_reranking(self, mocker: MockerFixture) -> None: + """Test successful reranking with score normalization.""" + # Create test chunks + chunks = [ + RAGChunk(content="Content 1", source="source_1", score=0.5), + RAGChunk(content="Content 2", source="source_2", score=0.3), + RAGChunk(content="Content 3", source="source_3", score=0.8), + ] + + # Mock cross-encoder model and prediction + mock_model = mocker.Mock() + mock_model.predict.return_value = [2.5, 1.0, 3.0] # Raw scores + + # Mock _get_cross_encoder to return our mock model + mocker.patch("utils.vector_search._get_cross_encoder", return_value=mock_model) + + result = await _rerank_chunks_with_cross_encoder("test query", chunks, 3) + + # Verify model was called with correct pairs + expected_pairs = [ + ("test query", "Content 1"), + ("test query", "Content 2"), + ("test query", "Content 3"), + ] + mock_model.predict.assert_called_once_with(expected_pairs) + + # Verify results are sorted by normalized scores (highest first) + assert len(result) == 3 + assert result[0].content == "Content 3" # Score 3.0 -> normalized to 1.0 + assert result[1].content == "Content 1" # Score 2.5 -> normalized to 0.75 + assert result[2].content == "Content 2" # Score 1.0 -> normalized to 0.0 + + # Verify scores are normalized to [0,1] + assert result[0].score == 1.0 + assert result[1].score == 0.75 + assert result[2].score == 0.0 + + @pytest.mark.asyncio + async def test_top_k_limiting(self, mocker: MockerFixture) -> None: + """Test that top_k limits the number of returned chunks.""" + chunks = [ + RAGChunk(content="Content 1", source="source_1", score=0.5), + RAGChunk(content="Content 2", source="source_2", score=0.3), + RAGChunk(content="Content 3", source="source_3", score=0.8), + ] + + mock_model = mocker.Mock() + mock_model.predict.return_value = [2.5, 1.0, 3.0] + mocker.patch("utils.vector_search._get_cross_encoder", return_value=mock_model) + + result = await _rerank_chunks_with_cross_encoder("test query", chunks, 2) + + assert len(result) == 2 # Limited to top_k=2 + assert result[0].content == "Content 3" + assert result[1].content == "Content 1" + + @pytest.mark.asyncio + async def test_identical_scores_normalization(self, mocker: MockerFixture) -> None: + """Test normalization when all scores are identical.""" + chunks = [ + RAGChunk(content="Content 1", source="source_1", score=0.5), + RAGChunk(content="Content 2", source="source_2", score=0.3), + ] + + mock_model = mocker.Mock() + mock_model.predict.return_value = [1.5, 1.5] # Identical scores + mocker.patch("utils.vector_search._get_cross_encoder", return_value=mock_model) + + result = await _rerank_chunks_with_cross_encoder("test query", chunks, 2) + + # All scores should be 0.5 when identical + assert len(result) == 2 + assert result[0].score == 0.5 + assert result[1].score == 0.5 + + @pytest.mark.asyncio + async def test_single_chunk_normalization(self, mocker: MockerFixture) -> None: + """Test normalization with single chunk.""" + chunks = [RAGChunk(content="Content 1", source="source_1", score=0.5)] + + mock_model = mocker.Mock() + mock_model.predict.return_value = [2.5] + mocker.patch("utils.vector_search._get_cross_encoder", return_value=mock_model) + + result = await _rerank_chunks_with_cross_encoder("test query", chunks, 1) + + # Single chunk should get score 1.0 + assert len(result) == 1 + assert result[0].score == 1.0 + + @pytest.mark.asyncio + async def test_model_loading_failure_fallback(self, mocker: MockerFixture) -> None: + """Test fallback to original scores when model loading fails.""" + chunks = [ + RAGChunk(content="Content 1", source="source_1", score=0.8), + RAGChunk(content="Content 2", source="source_2", score=0.6), + ] + + # Mock _get_cross_encoder to return None (loading failed) + mocker.patch("utils.vector_search._get_cross_encoder", return_value=None) + + result = await _rerank_chunks_with_cross_encoder("test query", chunks, 2) + + # Should return chunks sorted by original scores + assert len(result) == 2 + assert result[0].content == "Content 1" # Higher original score + assert result[1].content == "Content 2" + assert result[0].score == 0.8 # Original scores preserved + assert result[1].score == 0.6 + + @pytest.mark.asyncio + async def test_prediction_failure_fallback(self, mocker: MockerFixture) -> None: + """Test fallback when model.predict() raises exception.""" + chunks = [ + RAGChunk(content="Content 1", source="source_1", score=0.9), + RAGChunk(content="Content 2", source="source_2", score=0.7), + ] + + mock_model = mocker.Mock() + mock_model.predict.side_effect = Exception("Prediction failed") + mocker.patch("utils.vector_search._get_cross_encoder", return_value=mock_model) + + result = await _rerank_chunks_with_cross_encoder("test query", chunks, 2) + + # Should fallback to original scores + assert len(result) == 2 + assert result[0].content == "Content 1" + assert result[0].score == 0.9 + + @pytest.mark.asyncio + async def test_numpy_array_scores(self, mocker: MockerFixture) -> None: + """Test handling of numpy array scores from model prediction.""" + chunks = [RAGChunk(content="Content 1", source="source_1", score=0.5)] + + # Mock numpy array with tolist() method + mock_scores = mocker.Mock() + mock_scores.tolist.return_value = [2.5] + + mock_model = mocker.Mock() + mock_model.predict.return_value = mock_scores + mocker.patch("utils.vector_search._get_cross_encoder", return_value=mock_model) + + result = await _rerank_chunks_with_cross_encoder("test query", chunks, 1) + + # Should successfully handle numpy array conversion + assert len(result) == 1 + assert result[0].score == 1.0 + mock_scores.tolist.assert_called_once() + + +class TestApplyByokRerankBoost: + """Tests for _apply_byok_rerank_boost function.""" + + def test_empty_chunks(self) -> None: + """Test boost application with empty chunks list.""" + result = _apply_byok_rerank_boost([]) + assert result == [] + + def test_boost_byok_chunks_only(self) -> None: + """Test that only BYOK chunks (non-OKP) get boosted.""" + chunks = [ + RAGChunk(content="BYOK content", source="byok_store", score=0.8), + RAGChunk(content="OKP content", source=constants.OKP_RAG_ID, score=0.6), + RAGChunk(content="Another BYOK", source="another_store", score=0.7), + ] + + result = _apply_byok_rerank_boost(chunks, boost=2.0) + + assert len(result) == 3 + + # Find chunks by content for assertion + byok_chunk = next(c for c in result if c.content == "BYOK content") + okp_chunk = next(c for c in result if c.content == "OKP content") + another_byok = next(c for c in result if c.content == "Another BYOK") + + # BYOK chunks should be boosted + assert byok_chunk.score == 1.6 # 0.8 * 2.0 + assert another_byok.score == 1.4 # 0.7 * 2.0 + + # OKP chunk should remain unchanged + assert okp_chunk.score == 0.6 + + def test_sorting_by_boosted_scores(self) -> None: + """Test that chunks are sorted by boosted scores in descending order.""" + chunks = [ + RAGChunk(content="Low BYOK", source="byok_store", score=0.5), + RAGChunk(content="High OKP", source=constants.OKP_RAG_ID, score=0.9), + RAGChunk(content="Mid BYOK", source="another_store", score=0.7), + ] + + result = _apply_byok_rerank_boost(chunks, boost=2.0) + + # After boosting: Low BYOK=1.0, High OKP=0.9, Mid BYOK=1.4 + # Sorted order should be: Mid BYOK (1.4), Low BYOK (1.0), High OKP (0.9) + assert result[0].content == "Mid BYOK" + assert result[1].content == "Low BYOK" + assert result[2].content == "High OKP" + + def test_default_boost_factor(self) -> None: + """Test that default boost factor is applied correctly.""" + chunks = [RAGChunk(content="BYOK content", source="byok_store", score=0.8)] + + result = _apply_byok_rerank_boost(chunks) # Using default boost + + # Default boost should be constants.BYOK_RAG_RERANK_BOOST (1.2) + assert result[0].score == 0.8 * constants.BYOK_RAG_RERANK_BOOST + + def test_none_scores_handled(self) -> None: + """Test handling of chunks with None scores.""" + chunks = [ + RAGChunk(content="BYOK with score", source="byok_store", score=0.8), + RAGChunk(content="BYOK no score", source="byok_store", score=None), + RAGChunk(content="OKP no score", source=constants.OKP_RAG_ID, score=None), + ] + + result = _apply_byok_rerank_boost(chunks, boost=2.0) + + assert len(result) == 3 + + # Chunks with None scores should be treated as negative infinity for sorting + # but actual score calculation should handle None -> float("-inf") conversion + byok_with_score = next(c for c in result if c.content == "BYOK with score") + assert byok_with_score.score == 1.6 # 0.8 * 2.0 + + def test_preserves_chunk_attributes(self) -> None: + """Test that chunk attributes are preserved during boosting.""" + chunks = [ + RAGChunk( + content="Test content", + source="byok_store", + score=0.8, + attributes={"title": "Test Doc", "url": "http://example.com"} + ) + ] + + result = _apply_byok_rerank_boost(chunks, boost=1.5) + + assert len(result) == 1 + assert result[0].content == "Test content" + assert result[0].source == "byok_store" + assert abs(result[0].score - 1.2) < 1e-10 # 0.8 * 1.5 + assert result[0].attributes == {"title": "Test Doc", "url": "http://example.com"} From 9eb794ad91181715df4a56aa5d34b189a721f137 Mon Sep 17 00:00:00 2001 From: Anxhela Coba Date: Mon, 27 Apr 2026 15:34:02 -0400 Subject: [PATCH 3/3] update tests and clean code Signed-off-by: Anxhela Coba --- pyproject.toml | 2 - src/constants.py | 2 +- src/models/config.py | 50 ++++--------------- .../config/test_reranker_configuration.py | 28 ++--------- 4 files changed, 17 insertions(+), 65 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 80ed3a4e3..bc98a1aa7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,8 +75,6 @@ dependencies = [ # Used for RAG chunk reranking (cross-encoder) "sentence-transformers>=2.0.0", "datasets>=4.7.0", - # Used for RAG chunk reranking (cross-encoder) - "sentence-transformers>=2.0.0", # Used for error tracking and monitoring "sentry-sdk[fastapi]>=2.58.0", ] diff --git a/src/constants.py b/src/constants.py index ed4d79b22..06fbba241 100644 --- a/src/constants.py +++ b/src/constants.py @@ -193,7 +193,7 @@ BYOK_RAG_MAX_CHUNKS: Final[int] = 10 # retrieved from BYOK RAG OKP_RAG_MAX_CHUNKS: Final[int] = 5 # retrieved from OKP RAG # Score multiplier applied to BYOK chunks after cross-encoder reranking (Solr chunks unchanged) -BYOK_RAG_RERANK_BOOST = 1.2 +BYOK_RAG_RERANK_BOOST: Final[float] = 1.2 # Solr OKP constants SOLR_VECTOR_SEARCH_DEFAULT_K: Final[int] = 5 diff --git a/src/models/config.py b/src/models/config.py index a95b1a5fd..15398f11e 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -1822,9 +1822,6 @@ class RerankerConfiguration(ConfigurationBase): description="Cross-encoder model name for reranking RAG chunks. " "Defaults to 'cross-encoder/ms-marco-MiniLM-L6-v2' from sentence-transformers.", ) - top_k_multiplier: float = 2.0 # fetch 2x, rerank, keep top_k - byok_boost: float = 1.2 - okp_boost: float = 1.0 # Private attribute to track if this was explicitly configured _explicitly_configured: bool = PrivateAttr(default=False) @@ -1832,32 +1829,9 @@ class RerankerConfiguration(ConfigurationBase): @model_validator(mode="after") def mark_as_explicitly_configured(self) -> Self: """Mark this configuration as explicitly set when instantiated from user input.""" - # Only mark as explicitly configured if we're not using all default values - # This allows auto-enabling when user hasn't touched reranker settings - # Check if any field differs from default values - default_model = "cross-encoder/ms-marco-MiniLM-L6-v2" - default_top_k_multiplier = 2.0 - default_byok_boost = 1.2 - default_okp_boost = 1.0 - - # Check if any setting differs from defaults (indicates explicit configuration) - current_values = [ - self.enabled, - self.model, - self.top_k_multiplier, - self.byok_boost, - self.okp_boost, - ] - default_values = [ - True, - default_model, - default_top_k_multiplier, - default_byok_boost, - default_okp_boost, - ] - - if current_values != default_values: + if self.model_fields_set: self._explicitly_configured = True + return self @@ -2154,18 +2128,16 @@ def validate_reranker_auto_enable(self) -> Self: if ( has_byok and has_okp - and not hasattr(self.reranker, "_explicitly_configured") + and not self.reranker.enabled ): - # pylint: disable=no-member - if not self.reranker.enabled: - logger.info( - "Automatically enabling reranker: Both BYOK RAG (%d entries) and OKP " - "are configured. Reranking improves result quality when multiple " - "knowledge sources are available.", - len(self.byok_rag), - ) - # pylint: disable=no-member - self.reranker.enabled = True + + logger.info( + "Automatically enabling reranker: Both BYOK RAG (%d entries) and OKP " + "are configured. Reranking improves result quality when multiple " + "knowledge sources are available.", + len(self.byok_rag), + ) + self.reranker.enabled = True return self diff --git a/tests/unit/models/config/test_reranker_configuration.py b/tests/unit/models/config/test_reranker_configuration.py index ce362e3b2..25b7b7063 100644 --- a/tests/unit/models/config/test_reranker_configuration.py +++ b/tests/unit/models/config/test_reranker_configuration.py @@ -14,9 +14,6 @@ def test_default_values(self) -> None: config = RerankerConfiguration() assert config.enabled is True assert config.model == "cross-encoder/ms-marco-MiniLM-L6-v2" - assert config.top_k_multiplier == 2.0 - assert config.byok_boost == 1.2 - assert config.okp_boost == 1.0 def test_custom_model(self) -> None: """Test configuration with custom cross-encoder model.""" @@ -32,34 +29,19 @@ def test_disabled_reranker(self) -> None: assert config.enabled is False assert config.model == "cross-encoder/ms-marco-MiniLM-L6-v2" - def test_custom_boost_factors(self) -> None: - """Test configuration with custom boost factors.""" - config = RerankerConfiguration( - byok_boost=1.5, - okp_boost=0.8 - ) - assert config.byok_boost == 1.5 - assert config.okp_boost == 0.8 - - def test_custom_top_k_multiplier(self) -> None: - """Test configuration with custom top_k_multiplier.""" - config = RerankerConfiguration(top_k_multiplier=3.0) - assert config.top_k_multiplier == 3.0 + def test_model_fields_set_detection(self) -> None: + """Test that model_fields_set is properly detected.""" + config = RerankerConfiguration(model="custom-model") + assert config.model == "custom-model" def test_all_custom_values(self) -> None: """Test configuration with all custom values.""" config = RerankerConfiguration( enabled=False, - model="custom-cross-encoder", - top_k_multiplier=1.5, - byok_boost=2.0, - okp_boost=0.5 + model="custom-cross-encoder" ) assert config.enabled is False assert config.model == "custom-cross-encoder" - assert config.top_k_multiplier == 1.5 - assert config.byok_boost == 2.0 - assert config.okp_boost == 0.5 def test_explicit_configuration_detection(self) -> None: """Test that explicitly configured values are detected."""