-
Notifications
You must be signed in to change notification settings - Fork 7
[DEPRIORITIZED][AAQ-765] Retry LLM generation when AlignScore fails #399
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,13 +5,14 @@ | |
import os | ||
from typing import Tuple | ||
|
||
import backoff | ||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status | ||
from fastapi.responses import JSONResponse | ||
from sqlalchemy.exc import IntegrityError | ||
from sqlalchemy.ext.asyncio import AsyncSession | ||
|
||
from ..auth.dependencies import authenticate_key, rate_limiter | ||
from ..config import SPEECH_ENDPOINT | ||
from ..config import ALIGN_SCORE_N_RETRIES, SPEECH_ENDPOINT | ||
from ..contents.models import ( | ||
get_similar_content_async, | ||
increment_query_count, | ||
|
@@ -42,6 +43,7 @@ | |
) | ||
from .schemas import ( | ||
ContentFeedback, | ||
ErrorType, | ||
QueryBase, | ||
QueryRefined, | ||
QueryResponse, | ||
|
@@ -216,6 +218,17 @@ async def search( | |
contents=response.search_results, | ||
asession=asession, | ||
) | ||
if is_unable_to_generate_response(response): | ||
failure_reason = response.debug_info["factual_consistency"] | ||
response = await retry_search( | ||
query_refined=user_query_refined_template, | ||
response=response_template, | ||
user_id=user_db.user_id, | ||
n_similar=int(N_TOP_CONTENT), | ||
asession=asession, | ||
exclude_archived=True, | ||
) | ||
response.debug_info["past_failure"] = failure_reason | ||
|
||
if type(response) is QueryResponse: | ||
return response | ||
|
@@ -234,8 +247,8 @@ async def search( | |
@classify_safety__before | ||
@translate_question__before | ||
@paraphrase_question__before | ||
@generate_llm_response__after | ||
@check_align_score__after | ||
@generate_llm_response__after | ||
async def search_base( | ||
query_refined: QueryRefined, | ||
response: QueryResponse, | ||
|
@@ -288,6 +301,39 @@ async def search_base( | |
return response | ||
|
||
|
||
def is_unable_to_generate_response(response: QueryResponse) -> bool: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Added this function so that we retry only if that condition is met. |
||
""" | ||
Check if the response is of type QueryResponseError and caused | ||
by low alignment score. | ||
""" | ||
return ( | ||
isinstance(response, QueryResponseError) | ||
and response.error_type == ErrorType.ALIGNMENT_TOO_LOW | ||
) | ||
|
||
|
||
@backoff.on_predicate( | ||
backoff.expo, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we just have a logic that retries once, instead of adding a config (num retries) we don't know if we'll use 🤔 ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess it depends on how useful the approach is, since we haven't done any analysis to see how well it works. |
||
max_tries=int(ALIGN_SCORE_N_RETRIES), | ||
predicate=is_unable_to_generate_response, | ||
) | ||
async def retry_search( | ||
suzinyou marked this conversation as resolved.
Show resolved
Hide resolved
|
||
query_refined: QueryRefined, | ||
response: QueryResponse | QueryResponseError, | ||
user_id: int, | ||
n_similar: int, | ||
asession: AsyncSession, | ||
exclude_archived: bool = True, | ||
) -> QueryResponse | QueryResponseError: | ||
""" | ||
Retry wrapper for search_base. | ||
""" | ||
|
||
return await search_base( | ||
query_refined, response, user_id, n_similar, asession, exclude_archived | ||
) | ||
|
||
|
||
async def get_user_query_and_response( | ||
user_id: int, user_query: QueryBase, asession: AsyncSession | ||
) -> Tuple[QueryDB, QueryRefined, QueryResponse]: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,3 +19,4 @@ types-openpyxl==3.1.4.20240621 | |
redis==5.0.8 | ||
python-dateutil==2.8.2 | ||
gTTS==2.5.1 | ||
backoff==2.2.1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added that so that, if it works on the second try, we can understand why it failed the first time.