diff --git a/api/routers/messages.py b/api/routers/messages.py index b4c55cf..14366f2 100644 --- a/api/routers/messages.py +++ b/api/routers/messages.py @@ -215,12 +215,14 @@ async def send_message(req: MessageRequest, request: Request, auth_data: dict = raise HTTPException(status_code=409, detail="Duplicate request already in progress.") idempotency_claimed = True + is_pro = await check_is_pro(user_id) estimated_tokens = estimate_tokens_for_text(content) client_ip = _resolve_client_ip(request, trusted_proxies=trusted_proxies) await enforce_request_controls( user_id=user_id, client_ip=client_ip, estimated_tokens=estimated_tokens, + is_pro=is_pro, ) supabase = get_supabase_admin() @@ -283,7 +285,6 @@ async def send_message(req: MessageRequest, request: Request, auth_data: dict = if prompt_mode not in SUPPORTED_PROMPT_MODES: prompt_mode = normalize_prompt_level(None) - is_pro = await check_is_pro(user_id) if selected_mode == TECHNICAL_MODE and not is_pro: raise HTTPException(status_code=403, detail="Technical mode is a Pro feature") request_temperature = max(0.0, min(float(req.temperature), 1.0)) @@ -746,38 +747,49 @@ async def close_stream(stream): retry=bool(req.regenerate), sampled=False, ) + if aborted: + await cache_set( + idempotency_key, + {"status": "failed", "message_id": client_message_id}, + ttl=idempotency_ttl_seconds, + ) + return + if full_content.strip(): + if not req.regenerate and not response_truncated: + await cache_set(cache_key, {"response": full_content}, ttl=cache_ttl_seconds) + await cache_set( + idempotency_key, + { + "status": "completed", + "response": full_content, + "assistant_message_id": assistant_message_id, + "mode": selected_mode, + "prompt_mode": prompt_mode, + "partial": True, + }, + ttl=idempotency_ttl_seconds, + ) + mode_label = "" + if selected_mode == TECHNICAL_MODE: + mode_label = "technical " + elif selected_mode == SOCRATIC_MODE: + mode_label = "socratic " + yield emit( + "delta", + { + "delta": f"\n\n[Connection interrupted. Partial {mode_label}response delivered.]", + "assistant_message_id": assistant_message_id, + }, + ) + yield emit("done", "[DONE]") + return await cache_set( idempotency_key, {"status": "failed", "message_id": client_message_id}, ttl=idempotency_ttl_seconds, ) - if not aborted: - if full_content.strip(): - if not req.regenerate and not response_truncated: - await cache_set(cache_key, {"response": full_content}, ttl=cache_ttl_seconds) - await cache_set( - idempotency_key, - { - "status": "completed", - "response": full_content, - "assistant_message_id": assistant_message_id, - "mode": selected_mode, - "prompt_mode": prompt_mode, - "partial": True, - }, - ttl=idempotency_ttl_seconds, - ) - yield emit( - "delta", - { - "delta": "\n\n[Connection interrupted. Partial technical response delivered.]", - "assistant_message_id": assistant_message_id, - }, - ) - yield emit("done", "[DONE]") - return - yield emit("error", {"error": "Streaming failed"}) - yield emit("done", "[DONE]") + yield emit("error", {"error": "Streaming failed"}) + yield emit("done", "[DONE]") finally: total_ms = (time.perf_counter() - start_time) * 1000 avg_chunk_interval_ms = None diff --git a/api/routers/query.py b/api/routers/query.py index a81b6e4..c4f91f9 100644 --- a/api/routers/query.py +++ b/api/routers/query.py @@ -179,6 +179,7 @@ async def query_topic( user_id=str(effective_user_id) if effective_user_id else None, client_ip=request.client.host if request.client else "unknown", estimated_tokens=estimated_tokens, + is_pro=is_verified_pro, ) explanations: dict[str, str] = {} @@ -322,6 +323,7 @@ async def query_topic_stream( user_id=str(effective_user_id) if effective_user_id else None, client_ip=request.client.host if request.client else "unknown", estimated_tokens=estimated_tokens, + is_pro=is_verified_pro, ) message_id = None diff --git a/api/services/rate_limit.py b/api/services/rate_limit.py index 27ca788..5504d2a 100644 --- a/api/services/rate_limit.py +++ b/api/services/rate_limit.py @@ -196,6 +196,7 @@ async def enforce_request_controls( user_id: str | None, client_ip: str | None, estimated_tokens: int, + is_pro: bool = False, ) -> None: """Apply auth-scoped quota, distributed rate limiting, and circuit breaker checks. @@ -209,7 +210,7 @@ async def enforce_request_controls( is_authenticated = bool(user_id) fail_open = is_authenticated - if is_authenticated: + if is_authenticated and not is_pro: try: quota_result = await check_daily_quota(user_id=str(user_id), estimated_tokens=estimated_tokens) except Exception as exc: diff --git a/src/api.ts b/src/api.ts index 05d5901..7687799 100644 --- a/src/api.ts +++ b/src/api.ts @@ -8,6 +8,8 @@ import type { import { LegacyStreamChunkSchema } from "./lib/sseSchemas"; import type { Session } from "@supabase/supabase-js"; import { getTracePropagationHeaders } from "./lib/monitoring"; +import type { ApiError } from "./lib/httpErrors"; +import { buildApiError } from "./lib/httpErrors"; const API_URL = import.meta.env.VITE_API_URL || ""; const SUPABASE_CONFIGURED = @@ -67,6 +69,7 @@ async function fetchAPI( path: string, options?: RequestInit & { responseType?: "json" | "blob" }, ): Promise { + let timeoutFired = false; const session = await getSupabaseSession(); const headers: Record = { "Content-Type": "application/json", @@ -85,7 +88,10 @@ async function fetchAPI( headers["x-request-id"] = createRequestId(); const controller = new AbortController(); - const timeoutId = setTimeout(() => controller.abort(), 90000); // 90 seconds + const timeoutId = setTimeout(() => { + timeoutFired = true; + controller.abort(); + }, 90000); // 90 seconds const externalSignal = options?.signal; const abortSignalAny = ( AbortSignal as unknown as { @@ -123,11 +129,9 @@ async function fetchAPI( }); cleanup(); - if (res.status === 429) - throw new Error( - "You are sending requests too quickly. Please wait a moment.", - ); - if (!res.ok) throw new Error(`API error: ${res.status}`); + if (!res.ok) { + throw await buildApiError(res); + } if (options?.responseType === "blob") { return (await res.blob()) as unknown as T; @@ -135,8 +139,12 @@ async function fetchAPI( return await res.json(); } catch (err) { cleanup(); - if (isAbortError(err)) - throw new Error("Request timed out. Please try again."); + if (isAbortError(err)) { + if (timeoutFired) { + throw new Error("Request timed out. Please try again."); + } + throw normalizeError(err); + } throw normalizeError(err); } } @@ -148,10 +156,14 @@ export async function getPinnedTopics(): Promise { export async function getHealth(): Promise { return fetchAPI("/api/health"); } -export async function queryTopic(req: QueryRequest): Promise { +export async function queryTopic( + req: QueryRequest, + signal?: AbortSignal, +): Promise { return fetchAPI("/api/query", { method: "POST", body: JSON.stringify(req), + signal, }); } @@ -177,12 +189,15 @@ export async function queryTopicStream( const baseDelay = 750; const fallbackToNonStream = async (reason: string): Promise => { + if (signal?.aborted) { + return; + } try { console.warn( "Streaming unavailable, falling back to non-stream response:", reason, ); - const data = await queryTopic(req); + const data = await queryTopic(req, signal); const preferredLevel = req.levels?.[0]; const levelKey = preferredLevel && data.explanations?.[preferredLevel] @@ -208,7 +223,7 @@ export async function queryTopicStream( }); if (!response.ok) { - throw new Error(`API error: ${response.status}`); + throw await buildApiError(response); } // Validate SSE content type @@ -241,10 +256,14 @@ export async function queryTopicStream( ); }); - const { done, value } = await Promise.race([ - readPromise, - timeoutPromise, - ]); + let readResult: ReadableStreamReadResult; + try { + readResult = await Promise.race([readPromise, timeoutPromise]); + } catch (e) { + reader.cancel().catch(() => {}); + throw e; + } + const { done, value } = readResult; if (timeoutId) { clearTimeout(timeoutId); } @@ -305,12 +324,18 @@ export async function queryTopicStream( return; } + const error = normalizeError(err) as ApiError; + const retryAllowed = error.detail?.retry_allowed !== false; + if (!retryAllowed) { + onError(error); + return; + } + // Retry on network errors if not aborted if (retries < maxRetries && !signal?.aborted) { retries++; const delay = Math.min(8000, baseDelay * 2 ** (retries - 1)) + Math.random() * 250; - const error = normalizeError(err); console.warn( `Stream failed, retry ${retries}/${maxRetries} in ${Math.round(delay)}ms:`, error.message, @@ -319,7 +344,6 @@ export async function queryTopicStream( return attemptStream(); } - const error = normalizeError(err); await fallbackToNonStream(error.message || "Stream failed"); } }; diff --git a/src/lib/httpErrors.ts b/src/lib/httpErrors.ts new file mode 100644 index 0000000..5987fd9 --- /dev/null +++ b/src/lib/httpErrors.ts @@ -0,0 +1,53 @@ +export type ApiErrorDetail = { + type?: string; + retry_allowed?: boolean; + limit?: number; + consumed?: number; + scope?: string; +}; + +export type ApiError = Error & { + status?: number; + detail?: ApiErrorDetail; +}; + +const messageFromDetail = (detail: ApiErrorDetail): string => { + if (detail.type === "quota_exceeded") { + return "Daily quota exceeded. Please try again after your quota resets."; + } + if (detail.type === "rate_limit_exceeded") { + return "You are sending requests too quickly. Please wait a moment."; + } + return ""; +}; + +export const buildApiError = async (response: Response): Promise => { + let message = ""; + let detail: ApiErrorDetail | undefined; + + try { + const payload = (await response.json()) as Record; + const payloadDetail = payload.detail; + const payloadError = payload.error; + + if (typeof payloadDetail === "string" && payloadDetail.trim()) { + message = payloadDetail.trim(); + } else if (typeof payloadError === "string" && payloadError.trim()) { + message = payloadError.trim(); + } else if (payloadDetail && typeof payloadDetail === "object") { + detail = payloadDetail as ApiErrorDetail; + message = messageFromDetail(detail); + } + } catch { + // ignore non-json error payloads + } + + const err = new Error( + message || `API error: ${response.status}`, + ) as ApiError; + err.status = response.status; + if (detail) { + err.detail = detail; + } + return err; +}; diff --git a/src/services/chatService.ts b/src/services/chatService.ts index f9985e1..4ce5226 100644 --- a/src/services/chatService.ts +++ b/src/services/chatService.ts @@ -5,6 +5,7 @@ import { getTracePropagationHeaders } from "../lib/monitoring"; import { toQueryLevel } from "../lib/chatModes"; import type { ChatMode, PromptMode } from "../types/chat"; import { API_URL, createUuid, supabaseConfigured } from "../lib/chatStoreUtils"; +import { buildApiError } from "../lib/httpErrors"; interface SendChatParams { conversationId: string; @@ -40,26 +41,8 @@ const buildHeaders = async (): Promise> => { return headers; }; -const buildHttpError = async (response: Response): Promise => { - let message = ""; - try { - const payload = (await response.json()) as Record; - const detail = payload.detail; - const error = payload.error; - if (typeof detail === "string" && detail.trim()) { - message = detail.trim(); - } else if (typeof error === "string" && error.trim()) { - message = error.trim(); - } - } catch { - // ignore non-json error payloads - } - - const err = new Error( - message || `Request failed with status ${response.status}`, - ) as Error & { status?: number }; - err.status = response.status; - return err; +const buildHttpError = async (response: Response) => { + return buildApiError(response); }; const handlePayload = ( diff --git a/src/stores/useChatStore.ts b/src/stores/useChatStore.ts index 5b8241e..1f81aea 100644 --- a/src/stores/useChatStore.ts +++ b/src/stores/useChatStore.ts @@ -12,6 +12,7 @@ import { trackTelemetry, } from "../lib/monitoring"; import { sendChat } from "../services/chatService"; +import type { ApiError } from "../lib/httpErrors"; import { makeLocalId, makeClientId, @@ -638,7 +639,13 @@ export const useChatStore = create((set, get) => ({ return; } + const apiError = error as ApiError; + const errorDetail = apiError.detail; + const retryAllowed = errorDetail?.retry_allowed !== false; let errorMessage = getErrorMessage(error, "Failed to send message"); + if (errorDetail?.type === "quota_exceeded") { + errorMessage = "Daily quota exceeded. Please try again after your quota resets."; + } if (/timed out/i.test(errorMessage)) errorMessage = "Streaming timed out. Retry."; if (/duplicate request already in progress/i.test(errorMessage)) { errorMessage = "Retry will send a new request."; @@ -650,15 +657,17 @@ export const useChatStore = create((set, get) => ({ ); notifyError(errorMessage); - cachePendingSync({ - id: assistantClientId, - content: trimmed, - mode: requestedMode, - promptMode: effectivePromptMode, - createdAt: new Date().toISOString(), - clientMessageId, - assistantClientId, - }); + if (retryAllowed) { + cachePendingSync({ + id: assistantClientId, + content: trimmed, + mode: requestedMode, + promptMode: effectivePromptMode, + createdAt: new Date().toISOString(), + clientMessageId, + assistantClientId, + }); + } useMessageStore.getState().updateMessageByClientId(assistantClientId, (msg) => ({ ...msg, @@ -666,14 +675,16 @@ export const useChatStore = create((set, get) => ({ isRegenerating: false, error: errorMessage, syncStatus: "failed", - retryPayload: { - content: trimmed, - mode: requestedMode, - promptMode: effectivePromptMode, - temperature: requestTemperature, - clientMessageId, - assistantClientId, - }, + retryPayload: retryAllowed + ? { + content: trimmed, + mode: requestedMode, + promptMode: effectivePromptMode, + temperature: requestTemperature, + clientMessageId, + assistantClientId, + } + : undefined, })); } finally { controller.abort();