Skip to content

[DEPRIORITIZED][AAQ-765] Retry LLM generation when AlignScore fails #399

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion core_backend/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
LITELLM_MODEL_GENERATION = os.environ.get(
"LITELLM_MODEL_GENERATION",
"openai/generate-gemini-response",
# "LITELLM_MODEL_GENERATION", "openai/generate-response"
# "LITELLM_MODEL_GENERATION",
# "openai/generate-response",
)
LITELLM_MODEL_LANGUAGE_DETECT = os.environ.get(
"LITELLM_MODEL_LANGUAGE_DETECT", "openai/detect-language"
Expand Down Expand Up @@ -64,6 +65,7 @@
ALIGN_SCORE_METHOD = os.environ.get("ALIGN_SCORE_METHOD", "LLM")
# if AlignScore, set ALIGN_SCORE_API. If LLM, set LITELLM_MODEL_ALIGNSCORE above.
ALIGN_SCORE_API = os.environ.get("ALIGN_SCORE_API", "")
ALIGN_SCORE_N_RETRIES = os.environ.get("ALIGN_SCORE_N_RETRIES", 1)

# Backend paths
BACKEND_ROOT_PATH = os.environ.get("BACKEND_ROOT_PATH", "")
Expand Down
2 changes: 1 addition & 1 deletion core_backend/app/llm_call/process_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ async def wrapper(

response = await func(query_refined, response, *args, **kwargs)

if not kwargs.get("generate_llm_response", False):
if not query_refined.generate_llm_response:
return response

metadata = create_langfuse_metadata(
Expand Down
50 changes: 48 additions & 2 deletions core_backend/app/question_answer/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import os
from typing import Tuple

import backoff
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
from fastapi.responses import JSONResponse
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from ..auth.dependencies import authenticate_key, rate_limiter
from ..config import SPEECH_ENDPOINT
from ..config import ALIGN_SCORE_N_RETRIES, SPEECH_ENDPOINT
from ..contents.models import (
get_similar_content_async,
increment_query_count,
Expand Down Expand Up @@ -42,6 +43,7 @@
)
from .schemas import (
ContentFeedback,
ErrorType,
QueryBase,
QueryRefined,
QueryResponse,
Expand Down Expand Up @@ -216,6 +218,17 @@ async def search(
contents=response.search_results,
asession=asession,
)
if is_unable_to_generate_response(response):
failure_reason = response.debug_info["factual_consistency"]
response = await retry_search(
query_refined=user_query_refined_template,
response=response_template,
user_id=user_db.user_id,
n_similar=int(N_TOP_CONTENT),
asession=asession,
exclude_archived=True,
)
response.debug_info["past_failure"] = failure_reason
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added that in case it works after the second try to understand why it failed first.


if type(response) is QueryResponse:
return response
Expand All @@ -234,8 +247,8 @@ async def search(
@classify_safety__before
@translate_question__before
@paraphrase_question__before
@generate_llm_response__after
@check_align_score__after
@generate_llm_response__after
async def search_base(
query_refined: QueryRefined,
response: QueryResponse,
Expand Down Expand Up @@ -288,6 +301,39 @@ async def search_base(
return response


def is_unable_to_generate_response(response: QueryResponse) -> bool:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this function retry only if that condition is met.

"""
Check if the response is of type QueryResponseError and caused
by low alignment score.
"""
return (
isinstance(response, QueryResponseError)
and response.error_type == ErrorType.ALIGNMENT_TOO_LOW
)


@backoff.on_predicate(
backoff.expo,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What backoff.expo does is basically waiting a little more everytime the function is reran in an exponential way just to handle the load better.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we just have a logic that retries once, instead of adding a config (num retries) we don't know if we'll use 🤔 ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it depends on how useful the approach is, since we haven't done any analysis to see how well it works.
But personnally I think since it doesn't add a dependency (backoff being used by litellm), and since the only code we would change if we retry just once is the decorator and the config variable, the cost is pretty low, so we can just keep it.

max_tries=int(ALIGN_SCORE_N_RETRIES),
predicate=is_unable_to_generate_response,
)
async def retry_search(
query_refined: QueryRefined,
response: QueryResponse | QueryResponseError,
user_id: int,
n_similar: int,
asession: AsyncSession,
exclude_archived: bool = True,
) -> QueryResponse | QueryResponseError:
"""
Retry wrapper for search_base.
"""

return await search_base(
query_refined, response, user_id, n_similar, asession, exclude_archived
)


async def get_user_query_and_response(
user_id: int, user_query: QueryBase, asession: AsyncSession
) -> Tuple[QueryDB, QueryRefined, QueryResponse]:
Expand Down
1 change: 1 addition & 0 deletions core_backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ types-openpyxl==3.1.4.20240621
redis==5.0.8
python-dateutil==2.8.2
gTTS==2.5.1
backoff==2.2.1
Loading