Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 40 additions & 28 deletions api/routers/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,12 +215,14 @@ async def send_message(req: MessageRequest, request: Request, auth_data: dict =
raise HTTPException(status_code=409, detail="Duplicate request already in progress.")
idempotency_claimed = True

is_pro = await check_is_pro(user_id)
estimated_tokens = estimate_tokens_for_text(content)
client_ip = _resolve_client_ip(request, trusted_proxies=trusted_proxies)
await enforce_request_controls(
user_id=user_id,
client_ip=client_ip,
estimated_tokens=estimated_tokens,
is_pro=is_pro,
)

supabase = get_supabase_admin()
Expand Down Expand Up @@ -283,7 +285,6 @@ async def send_message(req: MessageRequest, request: Request, auth_data: dict =
if prompt_mode not in SUPPORTED_PROMPT_MODES:
prompt_mode = normalize_prompt_level(None)

is_pro = await check_is_pro(user_id)
if selected_mode == TECHNICAL_MODE and not is_pro:
raise HTTPException(status_code=403, detail="Technical mode is a Pro feature")
request_temperature = max(0.0, min(float(req.temperature), 1.0))
Expand Down Expand Up @@ -746,38 +747,49 @@ async def close_stream(stream):
retry=bool(req.regenerate),
sampled=False,
)
if aborted:
await cache_set(
idempotency_key,
{"status": "failed", "message_id": client_message_id},
ttl=idempotency_ttl_seconds,
)
return
if full_content.strip():
if not req.regenerate and not response_truncated:
await cache_set(cache_key, {"response": full_content}, ttl=cache_ttl_seconds)
await cache_set(
idempotency_key,
{
"status": "completed",
"response": full_content,
"assistant_message_id": assistant_message_id,
"mode": selected_mode,
"prompt_mode": prompt_mode,
"partial": True,
},
ttl=idempotency_ttl_seconds,
)
mode_label = ""
if selected_mode == TECHNICAL_MODE:
mode_label = "technical "
elif selected_mode == SOCRATIC_MODE:
mode_label = "socratic "
yield emit(
"delta",
{
"delta": f"\n\n[Connection interrupted. Partial {mode_label}response delivered.]",
"assistant_message_id": assistant_message_id,
},
)
yield emit("done", "[DONE]")
return
await cache_set(
idempotency_key,
{"status": "failed", "message_id": client_message_id},
ttl=idempotency_ttl_seconds,
)
if not aborted:
if full_content.strip():
if not req.regenerate and not response_truncated:
await cache_set(cache_key, {"response": full_content}, ttl=cache_ttl_seconds)
await cache_set(
idempotency_key,
{
"status": "completed",
"response": full_content,
"assistant_message_id": assistant_message_id,
"mode": selected_mode,
"prompt_mode": prompt_mode,
"partial": True,
},
ttl=idempotency_ttl_seconds,
)
yield emit(
"delta",
{
"delta": "\n\n[Connection interrupted. Partial technical response delivered.]",
"assistant_message_id": assistant_message_id,
},
)
yield emit("done", "[DONE]")
return
yield emit("error", {"error": "Streaming failed"})
yield emit("done", "[DONE]")
yield emit("error", {"error": "Streaming failed"})
yield emit("done", "[DONE]")
finally:
total_ms = (time.perf_counter() - start_time) * 1000
avg_chunk_interval_ms = None
Expand Down
2 changes: 2 additions & 0 deletions api/routers/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ async def query_topic(
user_id=str(effective_user_id) if effective_user_id else None,
client_ip=request.client.host if request.client else "unknown",
estimated_tokens=estimated_tokens,
is_pro=is_verified_pro,
)

explanations: dict[str, str] = {}
Expand Down Expand Up @@ -322,6 +323,7 @@ async def query_topic_stream(
user_id=str(effective_user_id) if effective_user_id else None,
client_ip=request.client.host if request.client else "unknown",
estimated_tokens=estimated_tokens,
is_pro=is_verified_pro,
)

message_id = None
Expand Down
3 changes: 2 additions & 1 deletion api/services/rate_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ async def enforce_request_controls(
user_id: str | None,
client_ip: str | None,
estimated_tokens: int,
is_pro: bool = False,
) -> None:
"""Apply auth-scoped quota, distributed rate limiting, and circuit breaker checks.

Expand All @@ -209,7 +210,7 @@ async def enforce_request_controls(
is_authenticated = bool(user_id)
fail_open = is_authenticated

if is_authenticated:
if is_authenticated and not is_pro:
try:
quota_result = await check_daily_quota(user_id=str(user_id), estimated_tokens=estimated_tokens)
except Exception as exc:
Expand Down
58 changes: 41 additions & 17 deletions src/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import type {
import { LegacyStreamChunkSchema } from "./lib/sseSchemas";
import type { Session } from "@supabase/supabase-js";
import { getTracePropagationHeaders } from "./lib/monitoring";
import type { ApiError } from "./lib/httpErrors";
import { buildApiError } from "./lib/httpErrors";

const API_URL = import.meta.env.VITE_API_URL || "";
const SUPABASE_CONFIGURED =
Expand Down Expand Up @@ -67,6 +69,7 @@ async function fetchAPI<T>(
path: string,
options?: RequestInit & { responseType?: "json" | "blob" },
): Promise<T> {
let timeoutFired = false;
const session = await getSupabaseSession();
const headers: Record<string, string> = {
"Content-Type": "application/json",
Expand All @@ -85,7 +88,10 @@ async function fetchAPI<T>(
headers["x-request-id"] = createRequestId();

const controller = new AbortController();
const timeoutId = setTimeout(() => controller.abort(), 90000); // 90 seconds
const timeoutId = setTimeout(() => {
timeoutFired = true;
controller.abort();
}, 90000); // 90 seconds
const externalSignal = options?.signal;
const abortSignalAny = (
AbortSignal as unknown as {
Expand Down Expand Up @@ -123,20 +129,22 @@ async function fetchAPI<T>(
});
cleanup();

if (res.status === 429)
throw new Error(
"You are sending requests too quickly. Please wait a moment.",
);
if (!res.ok) throw new Error(`API error: ${res.status}`);
if (!res.ok) {
throw await buildApiError(res);
}

if (options?.responseType === "blob") {
return (await res.blob()) as unknown as T;
}
return await res.json();
} catch (err) {
cleanup();
if (isAbortError(err))
throw new Error("Request timed out. Please try again.");
if (isAbortError(err)) {
if (timeoutFired) {
throw new Error("Request timed out. Please try again.");
}
throw normalizeError(err);
}
throw normalizeError(err);
}
}
Expand All @@ -148,10 +156,14 @@ export async function getPinnedTopics(): Promise<PinnedTopic[]> {
export async function getHealth(): Promise<HealthResponse> {
return fetchAPI("/api/health");
}
export async function queryTopic(req: QueryRequest): Promise<QueryResponse> {
export async function queryTopic(
req: QueryRequest,
signal?: AbortSignal,
): Promise<QueryResponse> {
return fetchAPI("/api/query", {
method: "POST",
body: JSON.stringify(req),
signal,
});
}

Expand All @@ -177,12 +189,15 @@ export async function queryTopicStream(
const baseDelay = 750;

const fallbackToNonStream = async (reason: string): Promise<void> => {
if (signal?.aborted) {
return;
}
try {
console.warn(
"Streaming unavailable, falling back to non-stream response:",
reason,
);
const data = await queryTopic(req);
const data = await queryTopic(req, signal);
const preferredLevel = req.levels?.[0];
const levelKey =
preferredLevel && data.explanations?.[preferredLevel]
Expand All @@ -208,7 +223,7 @@ export async function queryTopicStream(
});

if (!response.ok) {
throw new Error(`API error: ${response.status}`);
throw await buildApiError(response);
}

// Validate SSE content type
Expand Down Expand Up @@ -241,10 +256,14 @@ export async function queryTopicStream(
);
});

const { done, value } = await Promise.race([
readPromise,
timeoutPromise,
]);
let readResult: ReadableStreamReadResult<Uint8Array>;
try {
readResult = await Promise.race([readPromise, timeoutPromise]);
} catch (e) {
reader.cancel().catch(() => {});
throw e;
}
const { done, value } = readResult;
if (timeoutId) {
clearTimeout(timeoutId);
}
Expand Down Expand Up @@ -305,12 +324,18 @@ export async function queryTopicStream(
return;
}

const error = normalizeError(err) as ApiError;
const retryAllowed = error.detail?.retry_allowed !== false;
if (!retryAllowed) {
onError(error);
return;
}

// Retry on network errors if not aborted
if (retries < maxRetries && !signal?.aborted) {
retries++;
const delay =
Math.min(8000, baseDelay * 2 ** (retries - 1)) + Math.random() * 250;
const error = normalizeError(err);
console.warn(
`Stream failed, retry ${retries}/${maxRetries} in ${Math.round(delay)}ms:`,
error.message,
Expand All @@ -319,7 +344,6 @@ export async function queryTopicStream(
return attemptStream();
}

const error = normalizeError(err);
await fallbackToNonStream(error.message || "Stream failed");
}
};
Expand Down
53 changes: 53 additions & 0 deletions src/lib/httpErrors.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/**
 * Structured `detail` payload that a failed API response may carry,
 * e.g. quota / rate-limit metadata emitted by the backend.
 */
export type ApiErrorDetail = {
  type?: string;
  retry_allowed?: boolean;
  limit?: number;
  consumed?: number;
  scope?: string;
};

/**
 * Error thrown for non-2xx API responses: a plain `Error` augmented with
 * the HTTP status code and, when the body provided one, the parsed
 * structured detail object (see `buildApiError`).
 */
export type ApiError = Error & {
  status?: number;
  detail?: ApiErrorDetail;
};

/**
 * Translate a structured error detail into user-facing copy.
 * Returns an empty string when the detail type has no dedicated message,
 * letting the caller fall back to a generic status-based message.
 */
const messageFromDetail = (detail: ApiErrorDetail): string => {
  switch (detail.type) {
    case "quota_exceeded":
      return "Daily quota exceeded. Please try again after your quota resets.";
    case "rate_limit_exceeded":
      return "You are sending requests too quickly. Please wait a moment.";
    default:
      return "";
  }
};

/**
 * Build an `ApiError` from a failed fetch `Response`.
 *
 * Tries to read a JSON body and derive a human-readable message from, in
 * order of preference: a non-empty string `detail`, a non-empty string
 * `error`, or a structured `detail` object (quota / rate-limit info, which
 * is also attached to the returned error). Falls back to a generic
 * "API error: <status>" message when the body is absent, non-JSON, or
 * carries no usable text.
 */
export const buildApiError = async (response: Response): Promise<ApiError> => {
  let message = "";
  let detail: ApiErrorDetail | undefined;

  try {
    const payload = (await response.json()) as Record<string, unknown>;
    const rawDetail = payload.detail;
    const rawError = payload.error;

    if (typeof rawDetail === "string" && rawDetail.trim()) {
      message = rawDetail.trim();
    } else if (typeof rawError === "string" && rawError.trim()) {
      message = rawError.trim();
    } else if (rawDetail && typeof rawDetail === "object") {
      detail = rawDetail as ApiErrorDetail;
      message = messageFromDetail(detail);
    }
  } catch {
    // Body was empty or not valid JSON — keep the generic fallback message.
  }

  const apiError = new Error(
    message || `API error: ${response.status}`,
  ) as ApiError;
  apiError.status = response.status;
  if (detail) {
    apiError.detail = detail;
  }
  return apiError;
};
23 changes: 3 additions & 20 deletions src/services/chatService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { getTracePropagationHeaders } from "../lib/monitoring";
import { toQueryLevel } from "../lib/chatModes";
import type { ChatMode, PromptMode } from "../types/chat";
import { API_URL, createUuid, supabaseConfigured } from "../lib/chatStoreUtils";
import { buildApiError } from "../lib/httpErrors";

interface SendChatParams {
conversationId: string;
Expand Down Expand Up @@ -40,26 +41,8 @@ const buildHeaders = async (): Promise<Record<string, string>> => {
return headers;
};

const buildHttpError = async (response: Response): Promise<Error & { status?: number }> => {
let message = "";
try {
const payload = (await response.json()) as Record<string, unknown>;
const detail = payload.detail;
const error = payload.error;
if (typeof detail === "string" && detail.trim()) {
message = detail.trim();
} else if (typeof error === "string" && error.trim()) {
message = error.trim();
}
} catch {
// ignore non-json error payloads
}

const err = new Error(
message || `Request failed with status ${response.status}`,
) as Error & { status?: number };
err.status = response.status;
return err;
// Backward-compatible alias kept for existing call sites: error
// construction now lives in the shared buildApiError helper, and the
// promise it returns is forwarded unchanged.
const buildHttpError = (response: Response) => buildApiError(response);

const handlePayload = (
Expand Down
Loading
Loading