diff --git a/.gitignore b/.gitignore index 9de3e2574..7da9d1140 100644 --- a/.gitignore +++ b/.gitignore @@ -46,3 +46,4 @@ documents/warmup_ocr.pdf documents/openrag-documentation.pdf documents/ibm_anthropic.pdf documents/docling.pdf +/opensearch-data-new-lf diff --git a/Makefile b/Makefile index ac37c4a9a..f5d391e8f 100644 --- a/Makefile +++ b/Makefile @@ -665,9 +665,11 @@ test-unit: ## Run unit tests only @echo "$(PURPLE)Unit tests complete.$(NC)" test-integration: ## Run integration tests (requires infrastructure) - @echo "$(YELLOW)Running integration tests (requires infrastructure)...$(NC)" - @echo "$(CYAN)Make sure to run 'make dev-local' first!$(NC)" - uv run pytest tests/integration/ -v + @echo "$(CYAN)════════════════════════════════════════$(NC)" + @echo "$(PURPLE) Core Integration Tests$(NC)" + @echo "$(CYAN)════════════════════════════════════════$(NC)" + @echo "$(YELLOW)Make sure to run 'make dev-local' first!$(NC)" + uv run pytest tests/integration/core/ -v test-ci: ## Start infra, run integration + SDK tests, tear down (uses DockerHub images) @set -e; \ @@ -681,12 +683,15 @@ test-ci: ## Start infra, run integration + SDK tests, tear down (uses DockerHub chmod 600 keys/private_key.pem 2>/dev/null || true; \ chmod 644 keys/public_key.pem 2>/dev/null || true; \ fi; \ + echo "::group::Cleanup, Pull & Build Images"; \ echo "$(YELLOW)Cleaning up old containers and volumes...$(NC)"; \ $(COMPOSE_CMD) down -v 2>/dev/null || true; \ echo "$(YELLOW)Pulling latest images...$(NC)"; \ $(COMPOSE_CMD) pull; \ echo "$(YELLOW)Building OpenSearch image override...$(NC)"; \ $(CONTAINER_RUNTIME) build --no-cache -t langflowai/openrag-opensearch:latest -f Dockerfile .; \ + echo "::endgroup::"; \ + echo "::group::Start Infrastructure"; \ echo "$(YELLOW)Starting infra (OpenSearch + Dashboards + Langflow + Backend + Frontend) with CPU containers$(NC)"; \ OPENSEARCH_HOST=opensearch $(COMPOSE_CMD) up -d opensearch dashboards langflow openrag-backend openrag-frontend; \ echo 
"$(CYAN)Architecture: $$(uname -m), Platform: $$(uname -s)$(NC)"; \ @@ -752,31 +757,42 @@ test-ci: ## Start infra, run integration + SDK tests, tear down (uses DockerHub $(COMPOSE_CMD) down -v 2>/dev/null || true; \ exit 1; \ fi; \ - echo "$(PURPLE)Running integration tests$(NC)"; \ + echo "::endgroup::"; \ + echo "::group::Core Integration Tests"; \ + echo "$(CYAN)════════════════════════════════════════$(NC)"; \ + echo "$(PURPLE) Core Integration Tests$(NC)"; \ + echo "$(CYAN)════════════════════════════════════════$(NC)"; \ LOG_LEVEL=$${LOG_LEVEL:-DEBUG} \ GOOGLE_OAUTH_CLIENT_ID="" \ GOOGLE_OAUTH_CLIENT_SECRET="" \ OPENSEARCH_HOST=localhost OPENSEARCH_PORT=9200 \ OPENSEARCH_USERNAME=admin OPENSEARCH_PASSWORD=$${OPENSEARCH_PASSWORD} \ DISABLE_STARTUP_INGEST=$${DISABLE_STARTUP_INGEST:-true} \ - uv run pytest tests/integration -vv -s -o log_cli=true --log-cli-level=DEBUG; \ + uv run pytest tests/integration/core -vv -s -o log_cli=true --log-cli-level=DEBUG; \ TEST_RESULT=$$?; \ + echo "::endgroup::"; \ echo ""; \ echo "$(YELLOW)Waiting for frontend at http://localhost:3000...$(NC)"; \ for i in $$(seq 1 60); do \ curl -s http://localhost:3000/ >/dev/null 2>&1 && break || sleep 2; \ done; \ - echo "$(PURPLE)Running Python SDK integration tests$(NC)"; \ - cd sdks/python && \ - uv sync --extra dev && \ - OPENRAG_URL=http://localhost:3000 uv run pytest tests/test_integration.py -vv -s || TEST_RESULT=1; \ - cd ../..; \ - echo "$(PURPLE)Running TypeScript SDK integration tests$(NC)"; \ + echo "::group::SDK Integration Tests (Python)"; \ + echo "$(CYAN)════════════════════════════════════════$(NC)"; \ + echo "$(PURPLE) SDK Integration Tests (Python)$(NC)"; \ + echo "$(CYAN)════════════════════════════════════════$(NC)"; \ + uv pip install -e sdks/python; \ + SDK_TESTS_ONLY=true OPENRAG_URL=http://localhost:3000 uv run pytest tests/integration/sdk/ -vv -s || TEST_RESULT=1; \ + echo "::endgroup::"; \ + echo "::group::SDK Integration Tests (TypeScript)"; \ + echo 
"$(CYAN)════════════════════════════════════════$(NC)"; \ + echo "$(PURPLE) SDK Integration Tests (TypeScript)$(NC)"; \ + echo "$(CYAN)════════════════════════════════════════$(NC)"; \ cd sdks/typescript && \ npm install && npm run build && \ OPENRAG_URL=http://localhost:3000 npm test || TEST_RESULT=1; \ cd ../..; \ - echo "$(CYAN)=================================$(NC)"; \ + echo "::endgroup::"; \ + echo "$(CYAN)════════════════════════════════════════$(NC)"; \ echo ""; \ ($(call test_jwt_opensearch)) || TEST_RESULT=1; \ echo "$(YELLOW)Tearing down infra$(NC)"; \ @@ -796,6 +812,7 @@ test-ci-local: ## Same as test-ci but builds all images locally chmod 600 keys/private_key.pem 2>/dev/null || true; \ chmod 644 keys/public_key.pem 2>/dev/null || true; \ fi; \ + echo "::group::Cleanup & Build Images"; \ echo "$(YELLOW)Cleaning up old containers and volumes...$(NC)"; \ $(COMPOSE_CMD) down -v 2>/dev/null || true; \ echo "$(YELLOW)Building all images locally...$(NC)"; \ @@ -803,6 +820,8 @@ test-ci-local: ## Same as test-ci but builds all images locally $(CONTAINER_RUNTIME) build -t langflowai/openrag-backend:latest -f Dockerfile.backend .; \ $(CONTAINER_RUNTIME) build -t langflowai/openrag-frontend:latest -f Dockerfile.frontend .; \ $(CONTAINER_RUNTIME) build -t langflowai/openrag-langflow:latest -f Dockerfile.langflow .; \ + echo "::endgroup::"; \ + echo "::group::Start Infrastructure"; \ echo "$(YELLOW)Starting infra (OpenSearch + Dashboards + Langflow + Backend + Frontend) with CPU containers$(NC)"; \ echo "$(CYAN)Architecture: $$(uname -m), Platform: $$(uname -s)$(NC)"; \ OPENSEARCH_HOST=opensearch $(COMPOSE_CMD) up -d opensearch dashboards langflow openrag-backend openrag-frontend; \ @@ -868,31 +887,42 @@ test-ci-local: ## Same as test-ci but builds all images locally $(COMPOSE_CMD) down -v 2>/dev/null || true; \ exit 1; \ fi; \ - echo "$(PURPLE)Running integration tests$(NC)"; \ + echo "::endgroup::"; \ + echo "::group::Core Integration Tests"; \ + echo 
"$(CYAN)════════════════════════════════════════$(NC)"; \ + echo "$(PURPLE) Core Integration Tests$(NC)"; \ + echo "$(CYAN)════════════════════════════════════════$(NC)"; \ LOG_LEVEL=$${LOG_LEVEL:-DEBUG} \ GOOGLE_OAUTH_CLIENT_ID="" \ GOOGLE_OAUTH_CLIENT_SECRET="" \ OPENSEARCH_HOST=localhost OPENSEARCH_PORT=9200 \ OPENSEARCH_USERNAME=admin OPENSEARCH_PASSWORD=$${OPENSEARCH_PASSWORD} \ DISABLE_STARTUP_INGEST=$${DISABLE_STARTUP_INGEST:-true} \ - uv run pytest tests/integration -vv -s -o log_cli=true --log-cli-level=DEBUG; \ + uv run pytest tests/integration/core -vv -s -o log_cli=true --log-cli-level=DEBUG; \ TEST_RESULT=$$?; \ + echo "::endgroup::"; \ echo ""; \ echo "$(YELLOW)Waiting for frontend at http://localhost:3000...$(NC)"; \ for i in $$(seq 1 60); do \ curl -s http://localhost:3000/ >/dev/null 2>&1 && break || sleep 2; \ done; \ - echo "$(PURPLE)Running Python SDK integration tests$(NC)"; \ - cd sdks/python && \ - uv sync --extra dev && \ - OPENRAG_URL=http://localhost:3000 uv run pytest tests/test_integration.py -vv -s || TEST_RESULT=1; \ - cd ../..; \ - echo "$(PURPLE)Running TypeScript SDK integration tests$(NC)"; \ + echo "::group::SDK Integration Tests (Python)"; \ + echo "$(CYAN)════════════════════════════════════════$(NC)"; \ + echo "$(PURPLE) SDK Integration Tests (Python)$(NC)"; \ + echo "$(CYAN)════════════════════════════════════════$(NC)"; \ + uv pip install -e sdks/python; \ + SDK_TESTS_ONLY=true OPENRAG_URL=http://localhost:3000 uv run pytest tests/integration/sdk/ -vv -s || TEST_RESULT=1; \ + echo "::endgroup::"; \ + echo "::group::SDK Integration Tests (TypeScript)"; \ + echo "$(CYAN)════════════════════════════════════════$(NC)"; \ + echo "$(PURPLE) SDK Integration Tests (TypeScript)$(NC)"; \ + echo "$(CYAN)════════════════════════════════════════$(NC)"; \ cd sdks/typescript && \ npm install && npm run build && \ OPENRAG_URL=http://localhost:3000 npm test || TEST_RESULT=1; \ cd ../..; \ - echo 
"$(CYAN)=================================$(NC)"; \ + echo "::endgroup::"; \ + echo "$(CYAN)════════════════════════════════════════$(NC)"; \ echo ""; \ if [ $$TEST_RESULT -ne 0 ]; then \ echo "$(RED)=== Tests failed, dumping container logs ===$(NC)"; \ @@ -914,11 +944,12 @@ test-os-jwt: ## Test JWT authentication against OpenSearch @$(call test_jwt_opensearch) test-sdk: ## Run SDK integration tests (requires running OpenRAG at localhost:3000) - @echo "$(YELLOW)Running SDK integration tests...$(NC)" - @echo "$(CYAN)Make sure OpenRAG is running at localhost:3000 (make dev)$(NC)" - @echo "" - @echo "$(PURPLE)Running Python SDK tests...$(NC)" - cd sdks/python && uv sync --extra dev && OPENRAG_URL=http://localhost:3000 uv run pytest tests/test_integration.py -vv -s + @echo "$(CYAN)════════════════════════════════════════$(NC)" + @echo "$(PURPLE) SDK Integration Tests (Python)$(NC)" + @echo "$(CYAN)════════════════════════════════════════$(NC)" + @echo "$(YELLOW)Make sure OpenRAG is running at localhost:3000 (make dev)$(NC)" + uv pip install -e sdks/python + SDK_TESTS_ONLY=true OPENRAG_URL=http://localhost:3000 uv run pytest tests/integration/sdk/ -vv -s @echo "" @echo "$(PURPLE)Running TypeScript SDK tests...$(NC)" cd sdks/typescript && npm install && npm run build && OPENRAG_URL=http://localhost:3000 npm test diff --git a/docker-compose.yml b/docker-compose.yml index 1dcd35bd0..a4faf3692 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -80,6 +80,14 @@ services: - WEBHOOK_BASE_URL=${WEBHOOK_BASE_URL} - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID} - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY} + - AWS_S3_ENDPOINT=${AWS_S3_ENDPOINT} + - AWS_REGION=${AWS_REGION} + - IBM_COS_API_KEY=${IBM_COS_API_KEY} + - IBM_COS_SERVICE_INSTANCE_ID=${IBM_COS_SERVICE_INSTANCE_ID} + - IBM_COS_ENDPOINT=${IBM_COS_ENDPOINT} + - IBM_COS_HMAC_ACCESS_KEY_ID=${IBM_COS_HMAC_ACCESS_KEY_ID} + - IBM_COS_HMAC_SECRET_ACCESS_KEY=${IBM_COS_HMAC_SECRET_ACCESS_KEY} + - 
IBM_COS_AUTH_ENDPOINT=${IBM_COS_AUTH_ENDPOINT} - OPENSEARCH_INDEX_NAME=${OPENSEARCH_INDEX_NAME:-documents} - LANGFLOW_KEY=${LANGFLOW_KEY} - LANGFLOW_KEY_RETRIES=${LANGFLOW_KEY_RETRIES:-15} diff --git a/frontend/app/api/mutations/useConnectConnectorMutation.ts b/frontend/app/api/mutations/useConnectConnectorMutation.ts index 893a90086..6e0faf4b8 100644 --- a/frontend/app/api/mutations/useConnectConnectorMutation.ts +++ b/frontend/app/api/mutations/useConnectConnectorMutation.ts @@ -81,6 +81,11 @@ export const useConnectConnectorMutation = () => { `state=${result.connection_id}`; window.location.href = authUrl; + } else { + // Direct-auth connector (e.g. IBM COS) — credentials already verified, + // no OAuth redirect needed. Refresh connector status. + queryClient.invalidateQueries({ queryKey: ["connectors"] }); + toast.success(`${connector.name} connected successfully`); } }, }); diff --git a/frontend/app/api/mutations/useIBMCOSConfigureMutation.ts b/frontend/app/api/mutations/useIBMCOSConfigureMutation.ts new file mode 100644 index 000000000..89e0e23f4 --- /dev/null +++ b/frontend/app/api/mutations/useIBMCOSConfigureMutation.ts @@ -0,0 +1,40 @@ +import { useMutation, useQueryClient } from "@tanstack/react-query"; + +export interface IBMCOSConfigurePayload { + auth_mode: "iam" | "hmac"; + endpoint: string; + // IAM + api_key?: string; + service_instance_id?: string; + auth_endpoint?: string; + // HMAC + hmac_access_key?: string; + hmac_secret_key?: string; + // Bucket selection + bucket_names?: string[]; + // Updating an existing connection + connection_id?: string; +} + +async function configureIBMCOS(payload: IBMCOSConfigurePayload) { + const res = await fetch("/api/connectors/ibm_cos/configure", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(payload), + }); + const data = await res.json(); + if (!res.ok) throw new Error(data.error || "Failed to configure IBM COS"); + return data as { connection_id: string; status: 
string }; +} + +export function useIBMCOSConfigureMutation() { + const queryClient = useQueryClient(); + + return useMutation({ + mutationFn: configureIBMCOS, + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: ["connectors"] }); + queryClient.invalidateQueries({ queryKey: ["ibm-cos-defaults"] }); + }, + }); +} diff --git a/frontend/app/api/mutations/useS3ConfigureMutation.ts b/frontend/app/api/mutations/useS3ConfigureMutation.ts new file mode 100644 index 000000000..1e4e16a5e --- /dev/null +++ b/frontend/app/api/mutations/useS3ConfigureMutation.ts @@ -0,0 +1,33 @@ +import { useMutation, useQueryClient } from "@tanstack/react-query"; + +export interface S3ConfigurePayload { + access_key?: string; + secret_key?: string; + endpoint_url?: string; + region?: string; + bucket_names?: string[]; + connection_id?: string; +} + +async function configureS3(payload: S3ConfigurePayload) { + const res = await fetch("/api/connectors/aws_s3/configure", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(payload), + }); + const data = await res.json(); + if (!res.ok) throw new Error(data.error || "Failed to configure S3"); + return data as { connection_id: string; status: string }; +} + +export function useS3ConfigureMutation() { + const queryClient = useQueryClient(); + + return useMutation({ + mutationFn: configureS3, + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: ["connectors"] }); + queryClient.invalidateQueries({ queryKey: ["s3-defaults"] }); + }, + }); +} diff --git a/frontend/app/api/mutations/useSyncConnector.ts b/frontend/app/api/mutations/useSyncConnector.ts index 9ff22e47e..fe2c7ef2e 100644 --- a/frontend/app/api/mutations/useSyncConnector.ts +++ b/frontend/app/api/mutations/useSyncConnector.ts @@ -47,6 +47,10 @@ const syncConnector = async ({ size?: number; }>; settings?: any; + /** When true, ingest all files from the connector (bypasses the re-sync gate). 
*/ + sync_all?: boolean; + /** Restrict ingest to these bucket names (IBM COS). */ + bucket_filter?: string[]; }; }): Promise => { const response = await fetch(`/api/connectors/${connectorType}/sync`, { diff --git a/frontend/app/api/queries/useIBMCOSBucketStatusQuery.ts b/frontend/app/api/queries/useIBMCOSBucketStatusQuery.ts new file mode 100644 index 000000000..2a4b81b89 --- /dev/null +++ b/frontend/app/api/queries/useIBMCOSBucketStatusQuery.ts @@ -0,0 +1,34 @@ +import { useQuery } from "@tanstack/react-query"; + +export interface IBMCOSBucketStatus { + name: string; + ingested_count: number; + is_synced: boolean; +} + +async function fetchIBMCOSBucketStatus( + connectionId: string, +): Promise { + const res = await fetch( + `/api/connectors/ibm_cos/${connectionId}/bucket-status`, + ); + if (!res.ok) { + const err = await res.json().catch(() => ({})); + throw new Error(err.error || "Failed to fetch bucket status"); + } + const data = await res.json(); + return data.buckets as IBMCOSBucketStatus[]; +} + +export function useIBMCOSBucketStatusQuery( + connectionId: string | null | undefined, + options?: { enabled?: boolean }, +) { + return useQuery({ + queryKey: ["ibm-cos-bucket-status", connectionId], + queryFn: () => fetchIBMCOSBucketStatus(connectionId!), + enabled: (options?.enabled ?? 
true) && !!connectionId, + staleTime: 0, + refetchOnMount: "always", + }); +} diff --git a/frontend/app/api/queries/useIBMCOSBucketsQuery.ts b/frontend/app/api/queries/useIBMCOSBucketsQuery.ts new file mode 100644 index 000000000..a7daa619f --- /dev/null +++ b/frontend/app/api/queries/useIBMCOSBucketsQuery.ts @@ -0,0 +1,23 @@ +import { useQuery } from "@tanstack/react-query"; + +async function fetchIBMCOSBuckets(connectionId: string): Promise { + const res = await fetch(`/api/connectors/ibm_cos/${connectionId}/buckets`); + if (!res.ok) { + const err = await res.json().catch(() => ({})); + throw new Error(err.error || "Failed to list buckets"); + } + const data = await res.json(); + return data.buckets as string[]; +} + +export function useIBMCOSBucketsQuery( + connectionId: string | null | undefined, + options?: { enabled?: boolean }, +) { + return useQuery({ + queryKey: ["ibm-cos-buckets", connectionId], + queryFn: () => fetchIBMCOSBuckets(connectionId!), + enabled: (options?.enabled ?? 
true) && !!connectionId, + staleTime: 30_000, + }); +} diff --git a/frontend/app/api/queries/useIBMCOSDefaultsQuery.ts b/frontend/app/api/queries/useIBMCOSDefaultsQuery.ts new file mode 100644 index 000000000..b44ee4c5b --- /dev/null +++ b/frontend/app/api/queries/useIBMCOSDefaultsQuery.ts @@ -0,0 +1,28 @@ +import { useQuery } from "@tanstack/react-query"; + +export interface IBMCOSDefaults { + api_key_set: boolean; + service_instance_id: string; + endpoint: string; + hmac_access_key_set: boolean; + hmac_secret_key_set: boolean; + auth_mode: "iam" | "hmac"; + bucket_names: string[]; + connection_id: string | null; + disable_iam: boolean; +} + +async function fetchIBMCOSDefaults(): Promise { + const res = await fetch("/api/connectors/ibm_cos/defaults"); + if (!res.ok) throw new Error("Failed to fetch IBM COS defaults"); + return res.json(); +} + +export function useIBMCOSDefaultsQuery(options?: { enabled?: boolean }) { + return useQuery({ + queryKey: ["ibm-cos-defaults"], + queryFn: fetchIBMCOSDefaults, + enabled: options?.enabled ?? 
true, + staleTime: 0, + }); +} diff --git a/frontend/app/api/queries/useS3BucketStatusQuery.ts b/frontend/app/api/queries/useS3BucketStatusQuery.ts new file mode 100644 index 000000000..595920933 --- /dev/null +++ b/frontend/app/api/queries/useS3BucketStatusQuery.ts @@ -0,0 +1,30 @@ +import { useQuery } from "@tanstack/react-query"; + +export interface S3BucketStatus { + name: string; + ingested_count: number; + is_synced: boolean; +} + +async function fetchS3BucketStatus(connectionId: string): Promise { + const res = await fetch(`/api/connectors/aws_s3/${connectionId}/bucket-status`); + if (!res.ok) { + const err = await res.json().catch(() => ({})); + throw new Error(err.error || "Failed to fetch bucket status"); + } + const data = await res.json(); + return data.buckets as S3BucketStatus[]; +} + +export function useS3BucketStatusQuery( + connectionId: string | null | undefined, + options?: { enabled?: boolean }, +) { + return useQuery({ + queryKey: ["s3-bucket-status", connectionId], + queryFn: () => fetchS3BucketStatus(connectionId!), + enabled: (options?.enabled ?? 
true) && !!connectionId, + staleTime: 0, + refetchOnMount: "always", + }); +} diff --git a/frontend/app/api/queries/useS3DefaultsQuery.ts b/frontend/app/api/queries/useS3DefaultsQuery.ts new file mode 100644 index 000000000..9c7d24560 --- /dev/null +++ b/frontend/app/api/queries/useS3DefaultsQuery.ts @@ -0,0 +1,25 @@ +import { useQuery } from "@tanstack/react-query"; + +export interface S3Defaults { + access_key_set: boolean; + secret_key_set: boolean; + endpoint: string; + region: string; + bucket_names: string[]; + connection_id: string | null; +} + +async function fetchS3Defaults(): Promise { + const res = await fetch("/api/connectors/aws_s3/defaults"); + if (!res.ok) throw new Error("Failed to fetch S3 defaults"); + return res.json(); +} + +export function useS3DefaultsQuery(options?: { enabled?: boolean }) { + return useQuery({ + queryKey: ["s3-defaults"], + queryFn: fetchS3Defaults, + enabled: options?.enabled ?? true, + staleTime: 0, + }); +} diff --git a/frontend/app/knowledge/page.tsx b/frontend/app/knowledge/page.tsx index c88a996a1..9b88e855f 100644 --- a/frontend/app/knowledge/page.tsx +++ b/frontend/app/knowledge/page.tsx @@ -34,7 +34,10 @@ import { DeleteConfirmationDialog, formatFilesToDelete, } from "../../components/delete-confirmation-dialog"; +import AwsLogo from "../../components/icons/aws-logo"; import GoogleDriveIcon from "../../components/icons/google-drive-logo"; +import IBMCOSIcon from "../../components/icons/ibm-cos-icon"; +import IBMLogo from "../../components/icons/ibm-logo"; import OneDriveIcon from "../../components/icons/one-drive-logo"; import SharePointIcon from "../../components/icons/share-point-logo"; import { useDeleteDocument } from "../api/mutations/useDeleteDocument"; @@ -59,6 +62,10 @@ function getSourceIcon(connectorType?: string) { return ; case "s3": return ; + case "ibm_cos": + return ; + case "aws_s3": + return ; default: return ( diff --git a/frontend/app/settings/_components/anthropic-settings-dialog.tsx 
b/frontend/app/settings/_components/anthropic-settings-dialog.tsx index 4652a26e6..f120ad36c 100644 --- a/frontend/app/settings/_components/anthropic-settings-dialog.tsx +++ b/frontend/app/settings/_components/anthropic-settings-dialog.tsx @@ -138,7 +138,13 @@ const AnthropicSettingsDialog = ({ }; return ( - { setShowRemoveConfirm(false); setOpen(o); }}> + { + setShowRemoveConfirm(false); + setOpen(o); + }} + >
diff --git a/frontend/app/settings/_components/card-icon.tsx b/frontend/app/settings/_components/card-icon.tsx index 6f532b29d..ee9e86fbd 100644 --- a/frontend/app/settings/_components/card-icon.tsx +++ b/frontend/app/settings/_components/card-icon.tsx @@ -16,7 +16,7 @@ export default function CardIcon({ className={cn( "w-8 h-8 rounded flex items-center justify-center border", isActive - ? activeBgColor + ? `${activeBgColor} text-black` : "bg-muted grayscale group-hover:bg-background", )} > diff --git a/frontend/app/settings/_components/connector-card.tsx b/frontend/app/settings/_components/connector-card.tsx index fc7a41cfd..c3e2ab42e 100644 --- a/frontend/app/settings/_components/connector-card.tsx +++ b/frontend/app/settings/_components/connector-card.tsx @@ -1,6 +1,6 @@ "use client"; -import { Loader2, Plus, RefreshCcw, Trash2 } from "lucide-react"; +import { Loader2, Plus, RefreshCcw, Settings2, Trash2 } from "lucide-react"; import Link from "next/link"; import { Button } from "@/components/ui/button"; import { @@ -29,6 +29,8 @@ interface ConnectorCardProps { onConnect: (connector: Connector) => void; onDisconnect: (connector: Connector) => void; onNavigateToKnowledge: (connector: Connector) => void; + /** Optional: open a connector-specific settings/edit dialog */ + onConfigure?: (connector: Connector) => void; } export default function ConnectorCard({ @@ -38,11 +40,11 @@ export default function ConnectorCard({ onConnect, onDisconnect, onNavigateToKnowledge, + onConfigure, }: ConnectorCardProps) { console.log(connector); const isConnected = connector.status === "connected" && connector.connectionId; - connector?.status === "connected" && connector?.connectionId; return ( @@ -50,10 +52,7 @@ export default function ConnectorCard({
- + {connector.icon} {isConnected ? ( @@ -68,7 +67,7 @@ export default function ConnectorCard({ {connector.name} - {connector?.available + {isConnected || connector?.available ? `${connector.name} is configured.` : "Not configured."} @@ -93,7 +92,9 @@ export default function ConnectorCard({ + + + + + +
+ ); +} diff --git a/frontend/app/settings/_components/ibm-cos-settings-form.tsx b/frontend/app/settings/_components/ibm-cos-settings-form.tsx new file mode 100644 index 000000000..481aad35b --- /dev/null +++ b/frontend/app/settings/_components/ibm-cos-settings-form.tsx @@ -0,0 +1,305 @@ +"use client"; + +import { useFormContext, Controller } from "react-hook-form"; +import { LabelWrapper } from "@/components/label-wrapper"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { Tabs, TabsList, TabsTrigger, TabsContent } from "@/components/ui/tabs"; +import { Loader2, RefreshCcw } from "lucide-react"; +import { Button } from "@/components/ui/button"; + +export interface IBMCOSFormData { + auth_mode: "iam" | "hmac"; + endpoint: string; + // IAM + api_key: string; + service_instance_id: string; + // HMAC + hmac_access_key: string; + hmac_secret_key: string; +} + +interface IBMCOSSettingsFormProps { + /** Available buckets after a successful test — null means not yet tested */ + buckets: string[] | null; + selectedBuckets: string[]; + onSelectedBucketsChange: (buckets: string[]) => void; + isFetchingBuckets: boolean; + bucketsError: string | null; + onTestConnection: () => void; + apiKeySet?: boolean; + hmacAccessKeySet?: boolean; + hmacSecretKeySet?: boolean; + formError?: string | null; + /** When true, IAM tab is greyed out and HMAC is the only selectable option */ + disableIam?: boolean; +} + +export function IBMCOSSettingsForm({ + buckets, + selectedBuckets, + onSelectedBucketsChange, + isFetchingBuckets, + bucketsError, + onTestConnection, + apiKeySet, + hmacAccessKeySet, + hmacSecretKeySet, + formError, + disableIam = false, +}: IBMCOSSettingsFormProps) { + const { + register, + control, + formState: { errors }, + } = useFormContext(); + + const toggleBucket = (name: string, checked: boolean) => { + if (checked) { + onSelectedBucketsChange([...selectedBuckets, name]); + } else { + 
onSelectedBucketsChange(selectedBuckets.filter((b) => b !== name)); + } + }; + + const toggleAll = (checked: boolean) => { + onSelectedBucketsChange(checked ? (buckets ?? []) : []); + }; + + return ( +
+ {/* Auth mode selector using Tabs */} +
+ + ( + { + if (disableIam && v === "iam") return; + field.onChange(v); + }} + > + + + HMAC + + Access Key + Secret Key + + + + IAM + + API Key + Resource Instance ID + + + + + {/* HMAC fields — first tab */} + +
+
+ + v?.trim() })} + id="ibm-cos-hmac-key" + type="password" + placeholder={ + hmacAccessKeySet + ? "•••••••• (loaded from env)" + : "cos_hmac_keys.access_key_id" + } + autoComplete="off" + /> + +
+
+ + v?.trim() })} + id="ibm-cos-hmac-secret" + type="password" + placeholder={ + hmacSecretKeySet + ? "•••••••• (loaded from env)" + : "cos_hmac_keys.secret_access_key" + } + autoComplete="off" + /> + +
+
+
+ + {/* IAM fields — second tab */} + +
+
+ + v?.trim() })} + id="ibm-cos-api-key" + type="password" + placeholder={ + apiKeySet + ? "•••••••• (loaded from env)" + : 'apikey value from Service Credentials' + } + autoComplete="off" + /> + +
+
+ + v?.trim() })} + id="ibm-cos-svc-id" + placeholder="crn:v1:bluemix:public:cloud-object-storage:..." + /> + +
+
+
+
+ )} + /> +
+ + {/* Endpoint — shared by both auth modes */} +
+ + v?.trim() })} + id="ibm-cos-endpoint" + placeholder="https://s3.us-south.cloud-object-storage.appdomain.cloud" + className={errors.endpoint ? "!border-destructive" : ""} + /> + + {errors.endpoint && ( +

{errors.endpoint.message}

+ )} +
+ + {/* Test connection */} + + + {bucketsError && ( +

+ {bucketsError} +

+ )} + + {formError && ( +

+ {formError} +

+ )} + + {/* Bucket selector — native checkboxes styled with Tailwind */} + {buckets !== null && ( +
+
+ + {buckets.length > 1 && ( + + )} +
+ + {buckets.length === 0 ? ( +

+ No buckets found for this account. +

+ ) : ( +
+ {buckets.map((bucket) => ( + + ))} +
+ )} +
+ )} +
+ ); +} diff --git a/frontend/app/settings/_components/ollama-settings-dialog.tsx b/frontend/app/settings/_components/ollama-settings-dialog.tsx index b60665bf3..e911cb2ab 100644 --- a/frontend/app/settings/_components/ollama-settings-dialog.tsx +++ b/frontend/app/settings/_components/ollama-settings-dialog.tsx @@ -17,11 +17,11 @@ import { DialogTitle, } from "@/components/ui/dialog"; import { useAuth } from "@/contexts/auth-context"; +import ModelProviderDialogFooter from "./model-provider-dialog-footer"; import { OllamaSettingsForm, type OllamaSettingsFormData, } from "./ollama-settings-form"; -import ModelProviderDialogFooter from "./model-provider-dialog-footer"; const OllamaSettingsDialog = ({ open, @@ -131,7 +131,13 @@ const OllamaSettingsDialog = ({ }; return ( - { setShowRemoveConfirm(false); setOpen(o); }}> + { + setShowRemoveConfirm(false); + setOpen(o); + }} + >
diff --git a/frontend/app/settings/_components/openai-settings-dialog.tsx b/frontend/app/settings/_components/openai-settings-dialog.tsx index 0eb08f9e4..eeb3e83f7 100644 --- a/frontend/app/settings/_components/openai-settings-dialog.tsx +++ b/frontend/app/settings/_components/openai-settings-dialog.tsx @@ -17,11 +17,11 @@ import { DialogTitle, } from "@/components/ui/dialog"; import { useAuth } from "@/contexts/auth-context"; +import ModelProviderDialogFooter from "./model-provider-dialog-footer"; import { OpenAISettingsForm, type OpenAISettingsFormData, } from "./openai-settings-form"; -import ModelProviderDialogFooter from "./model-provider-dialog-footer"; const OpenAISettingsDialog = ({ open, @@ -138,7 +138,13 @@ const OpenAISettingsDialog = ({ }; return ( - { setShowRemoveConfirm(false); setOpen(o); }}> + { + setShowRemoveConfirm(false); + setOpen(o); + }} + > diff --git a/frontend/app/settings/_components/s3-settings-dialog.tsx b/frontend/app/settings/_components/s3-settings-dialog.tsx new file mode 100644 index 000000000..1f32dd0e5 --- /dev/null +++ b/frontend/app/settings/_components/s3-settings-dialog.tsx @@ -0,0 +1,190 @@ +"use client"; + +import { useState, useEffect } from "react"; +import { FormProvider, useForm } from "react-hook-form"; +import { toast } from "sonner"; +import { useQueryClient } from "@tanstack/react-query"; +import AwsLogo from "@/components/icons/aws-logo"; +import { Button } from "@/components/ui/button"; +import { + Dialog, + DialogContent, + DialogFooter, + DialogHeader, + DialogTitle, +} from "@/components/ui/dialog"; +import { S3SettingsForm, type S3FormData } from "./s3-settings-form"; +import { useS3DefaultsQuery } from "@/app/api/queries/useS3DefaultsQuery"; +import { useS3ConfigureMutation } from "@/app/api/mutations/useS3ConfigureMutation"; + +interface S3SettingsDialogProps { + open: boolean; + setOpen: (open: boolean) => void; +} + +export default function S3SettingsDialog({ + open, + setOpen, +}: S3SettingsDialogProps) 
{ + const queryClient = useQueryClient(); + + const { data: defaults } = useS3DefaultsQuery({ enabled: open }); + + const methods = useForm({ + mode: "onSubmit", + values: { + access_key: "", + secret_key: "", + endpoint_url: defaults?.endpoint ?? "", + region: defaults?.region ?? "", + }, + }); + + const { handleSubmit } = methods; + + const [buckets, setBuckets] = useState( + defaults?.bucket_names?.length ? defaults.bucket_names : null, + ); + const [selectedBuckets, setSelectedBuckets] = useState( + defaults?.bucket_names ?? [], + ); + + // Sync bucket state when defaults load asynchronously after dialog mount + useEffect(() => { + if (defaults?.bucket_names?.length) { + setBuckets(defaults.bucket_names); + setSelectedBuckets((prev) => + prev.length ? prev : defaults.bucket_names, + ); + } + }, [defaults?.bucket_names?.join(",")]); + + const [isFetchingBuckets, setIsFetchingBuckets] = useState(false); + const [bucketsError, setBucketsError] = useState(null); + const [formError, setFormError] = useState(null); + + const configureMutation = useS3ConfigureMutation(); + + const handleTestConnection = handleSubmit(async (data) => { + setIsFetchingBuckets(true); + setBucketsError(null); + setFormError(null); + + try { + const result = await configureMutation.mutateAsync({ + access_key: data.access_key || undefined, + secret_key: data.secret_key || undefined, + endpoint_url: data.endpoint_url || undefined, + region: data.region || undefined, + connection_id: defaults?.connection_id ?? 
undefined, + }); + + const res = await fetch( + `/api/connectors/aws_s3/${result.connection_id}/buckets`, + ); + const json = await res.json(); + if (!res.ok) throw new Error(json.error || "Failed to list buckets"); + + const fetchedBuckets: string[] = json.buckets; + setBuckets(fetchedBuckets); + + setSelectedBuckets((prev) => + prev.filter((b) => fetchedBuckets.includes(b)), + ); + + queryClient.invalidateQueries({ queryKey: ["s3-defaults"] }); + } catch (err: any) { + setBucketsError(err.message ?? "Connection failed"); + } finally { + setIsFetchingBuckets(false); + } + }); + + const onSubmit = handleSubmit(async (data) => { + setFormError(null); + if (buckets === null) { + setFormError("Test the connection first to validate credentials."); + return; + } + + try { + const latestDefaults = await queryClient.fetchQuery({ + queryKey: ["s3-defaults"], + queryFn: async () => { + const res = await fetch("/api/connectors/aws_s3/defaults"); + return res.json(); + }, + staleTime: 0, + }); + + await configureMutation.mutateAsync({ + access_key: data.access_key || undefined, + secret_key: data.secret_key || undefined, + endpoint_url: data.endpoint_url || undefined, + region: data.region || undefined, + bucket_names: selectedBuckets, + connection_id: latestDefaults?.connection_id ?? defaults?.connection_id ?? undefined, + }); + + toast.success("Amazon S3 configured", { + description: + selectedBuckets.length > 0 + ? `Will ingest from: ${selectedBuckets.join(", ")}` + : "Will auto-discover and ingest all accessible buckets.", + icon: , + }); + + queryClient.invalidateQueries({ queryKey: ["connectors"] }); + setOpen(false); + } catch (err: any) { + setFormError(err.message ?? "Failed to save configuration"); + } + }); + + return ( + + + + + + +
+ +
+ Amazon S3 Setup +
+
+ + void} + accessKeySet={defaults?.access_key_set} + secretKeySet={defaults?.secret_key_set} + formError={formError} + /> + + + + + + +
+
+
+ ); +} diff --git a/frontend/app/settings/_components/s3-settings-form.tsx b/frontend/app/settings/_components/s3-settings-form.tsx new file mode 100644 index 000000000..22d36ac50 --- /dev/null +++ b/frontend/app/settings/_components/s3-settings-form.tsx @@ -0,0 +1,223 @@ +"use client"; + +import { useFormContext } from "react-hook-form"; +import { LabelWrapper } from "@/components/label-wrapper"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { Loader2, RefreshCcw } from "lucide-react"; +import { Button } from "@/components/ui/button"; + +export interface S3FormData { + access_key: string; + secret_key: string; + endpoint_url: string; + region: string; +} + +interface S3SettingsFormProps { + /** Available buckets after a successful test — null means not yet tested */ + buckets: string[] | null; + selectedBuckets: string[]; + onSelectedBucketsChange: (buckets: string[]) => void; + isFetchingBuckets: boolean; + bucketsError: string | null; + onTestConnection: () => void; + accessKeySet?: boolean; + secretKeySet?: boolean; + formError?: string | null; +} + +export function S3SettingsForm({ + buckets, + selectedBuckets, + onSelectedBucketsChange, + isFetchingBuckets, + bucketsError, + onTestConnection, + accessKeySet, + secretKeySet, + formError, +}: S3SettingsFormProps) { + const { register } = useFormContext(); + + const toggleBucket = (name: string, checked: boolean) => { + if (checked) { + onSelectedBucketsChange([...selectedBuckets, name]); + } else { + onSelectedBucketsChange(selectedBuckets.filter((b) => b !== name)); + } + }; + + const toggleAll = (checked: boolean) => { + onSelectedBucketsChange(checked ? (buckets ?? []) : []); + }; + + return ( +
+ {/* Access Key ID */} +
+ + v?.trim() })} + id="s3-access-key" + type="password" + placeholder={ + accessKeySet + ? "•••••••• (loaded from env)" + : "AKIAIOSFODNN7EXAMPLE" + } + autoComplete="off" + /> + +
+ + {/* Secret Access Key */} +
+ + v?.trim() })} + id="s3-secret-key" + type="password" + placeholder={ + secretKeySet + ? "•••••••• (loaded from env)" + : "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" + } + autoComplete="off" + /> + +
+ + {/* Endpoint URL (optional) */} +
+ + v?.trim() })} + id="s3-endpoint" + placeholder="https://your-minio.example.com" + autoComplete="off" + /> + +
+ + {/* Region (optional) */} +
+ + v?.trim() })} + id="s3-region" + placeholder="us-east-1" + autoComplete="off" + /> + +
+ + {/* Test connection */} + + + {bucketsError && ( +

+ {bucketsError} +

+ )} + + {formError && ( +

+ {formError} +

+ )} + + {/* Bucket selector */} + {buckets !== null && ( +
+
+ + {buckets.length > 1 && ( + + )} +
+ + {buckets.length === 0 ? ( +

+ No buckets found for this account. +

+ ) : ( +
+ {buckets.map((bucket) => ( + + ))} +
+ )} +
+ )} +
+ ); +} diff --git a/frontend/app/settings/_components/watsonx-settings-dialog.tsx b/frontend/app/settings/_components/watsonx-settings-dialog.tsx index d6b1d9334..36ed50cc8 100644 --- a/frontend/app/settings/_components/watsonx-settings-dialog.tsx +++ b/frontend/app/settings/_components/watsonx-settings-dialog.tsx @@ -17,11 +17,11 @@ import { DialogTitle, } from "@/components/ui/dialog"; import { useAuth } from "@/contexts/auth-context"; +import ModelProviderDialogFooter from "./model-provider-dialog-footer"; import { WatsonxSettingsForm, type WatsonxSettingsFormData, } from "./watsonx-settings-form"; -import ModelProviderDialogFooter from "./model-provider-dialog-footer"; const WatsonxSettingsDialog = ({ open, @@ -147,7 +147,13 @@ const WatsonxSettingsDialog = ({ }; return ( - { setShowRemoveConfirm(false); setOpen(o); }}> + { + setShowRemoveConfirm(false); + setOpen(o); + }} + >
diff --git a/frontend/app/upload/[provider]/page.tsx b/frontend/app/upload/[provider]/page.tsx index bf96214f7..832d313c7 100644 --- a/frontend/app/upload/[provider]/page.tsx +++ b/frontend/app/upload/[provider]/page.tsx @@ -1,13 +1,18 @@ "use client"; -import { AlertCircle, ArrowLeft } from "lucide-react"; +import { useQueryClient } from "@tanstack/react-query"; +import { AlertCircle, ArrowLeft, FolderOpen, RefreshCw } from "lucide-react"; import { useParams, useRouter } from "next/navigation"; -import { useEffect, useState } from "react"; +import { useState } from "react"; +import { toast } from "sonner"; import { useSyncConnector } from "@/app/api/mutations/useSyncConnector"; import { useGetConnectorsQuery } from "@/app/api/queries/useGetConnectorsQuery"; import { useGetConnectorTokenQuery } from "@/app/api/queries/useGetConnectorTokenQuery"; +import { useIBMCOSBucketStatusQuery } from "@/app/api/queries/useIBMCOSBucketStatusQuery"; +import { useS3BucketStatusQuery } from "@/app/api/queries/useS3BucketStatusQuery"; import { type CloudFile, UnifiedCloudPicker } from "@/components/cloud-picker"; -import type { IngestSettings } from "@/components/cloud-picker/types"; +import { IngestSettings } from "@/components/cloud-picker/ingest-settings"; +import type { IngestSettings as IngestSettingsType } from "@/components/cloud-picker/types"; import { Button } from "@/components/ui/button"; import { Tooltip, @@ -16,13 +21,328 @@ import { } from "@/components/ui/tooltip"; import { useTask } from "@/contexts/task-context"; +// Connectors that sync entire buckets/repositories without a file picker +const DIRECT_SYNC_PROVIDERS = ["ibm_cos", "aws_s3"]; + +// --------------------------------------------------------------------------- +// Shared bucket view — used by both IBM COS and S3 +// --------------------------------------------------------------------------- + +function BucketView({ + connector, + buckets, + isLoading, + bucketsError, + onRefetch, + invalidateQueryKey, + 
syncMutation, + addTask, + onBack, + onDone, +}: { + connector: any; + buckets: Array<{ name: string; ingested_count: number }> | undefined; + isLoading: boolean; + bucketsError?: Error | null; + onRefetch: () => void; + invalidateQueryKey: readonly unknown[]; + syncMutation: ReturnType; + addTask: (id: string) => void; + onBack: () => void; + onDone: () => void; +}) { + const queryClient = useQueryClient(); + const [selectedBuckets, setSelectedBuckets] = useState>( + new Set(), + ); + const [ingestSettings, setIngestSettings] = useState({ + chunkSize: 1000, + chunkOverlap: 200, + ocr: false, + pictureDescriptions: false, + embeddingModel: "text-embedding-3-small", + }); + const [isSettingsOpen, setIsSettingsOpen] = useState(false); + + const invalidate = () => { + queryClient.invalidateQueries({ queryKey: invalidateQueryKey }); + }; + + const toggleBucket = (bucketName: string) => { + setSelectedBuckets((prev) => { + const next = new Set(prev); + if (next.has(bucketName)) { + next.delete(bucketName); + } else { + next.add(bucketName); + } + return next; + }); + }; + + const ingestSelected = () => { + syncMutation.mutate( + { + connectorType: connector.type, + body: { + connection_id: connector.connectionId!, + selected_files: [], + bucket_filter: Array.from(selectedBuckets), + settings: ingestSettings, + }, + }, + { + onSuccess: (result) => { + invalidate(); + if (result.task_ids?.length) { + addTask(result.task_ids[0]); + onDone(); + } else { + toast.info("No files found in the selected buckets."); + } + }, + onError: (err) => { + toast.error(err instanceof Error ? err.message : "Sync failed"); + }, + }, + ); + }; + + return ( + <> +
+ +

+ Add from {connector.name} +

+
+ +
+
+

+ Select buckets to ingest. +

+
+ {selectedBuckets.size > 0 && ( + + )} + + +
+
+ + {isLoading ? ( +
+
+
+ ) : bucketsError ? ( +
+ {bucketsError.message || + "Failed to load buckets. Check your credentials and endpoint."} +
+ ) : !buckets?.length ? ( +
+ No buckets found. Check your credentials and endpoint. +
+ ) : ( +
+ {buckets.map((bucket) => { + const isSelected = selectedBuckets.has(bucket.name); + return ( +
toggleBucket(bucket.name)} + > +
+ {isSelected && ( + + + + )} +
+
+
+ +
+
+

+ {bucket.name} +

+ {bucket.ingested_count > 0 && ( +

+ {bucket.ingested_count} document + {bucket.ingested_count !== 1 ? "s" : ""} ingested +

+ )} +
+
+
+ ); + })} +
+ )} + + +
+ +
+
+ + +
+
+ + ); +} + +// --------------------------------------------------------------------------- +// IBM COS wrapper +// --------------------------------------------------------------------------- + +function IBMCOSBucketView({ + connector, + syncMutation, + addTask, + onBack, + onDone, +}: { + connector: any; + syncMutation: ReturnType; + addTask: (id: string) => void; + onBack: () => void; + onDone: () => void; +}) { + const { + data: buckets, + isLoading, + refetch, + } = useIBMCOSBucketStatusQuery(connector.connectionId, { enabled: true }); + return ( + + ); +} + +// --------------------------------------------------------------------------- +// Amazon S3 wrapper +// --------------------------------------------------------------------------- + +function S3BucketView({ + connector, + syncMutation, + addTask, + onBack, + onDone, +}: { + connector: any; + syncMutation: ReturnType; + addTask: (id: string) => void; + onBack: () => void; + onDone: () => void; +}) { + const { + data: buckets, + isLoading, + error: bucketsError, + refetch, + } = useS3BucketStatusQuery(connector.connectionId, { enabled: true }); + return ( + + ); +} + // CloudFile interface is now imported from the unified cloud picker export default function UploadProviderPage() { const params = useParams(); const router = useRouter(); const provider = params.provider as string; - const { addTask, tasks } = useTask(); + const { addTask } = useTask(); const { data: connectors = [], @@ -31,6 +351,8 @@ export default function UploadProviderPage() { } = useGetConnectorsQuery(); const connector = connectors.find((c) => c.type === provider); + const isDirectSyncProvider = DIRECT_SYNC_PROVIDERS.includes(provider); + const { data: tokenData, isLoading: tokenLoading } = useGetConnectorTokenQuery( { @@ -42,17 +364,18 @@ export default function UploadProviderPage() { : undefined, }, { - enabled: !!connector && connector.status === "connected", + // Direct-sync providers (e.g. 
IBM COS) don't use OAuth tokens + enabled: + !!connector && + connector.status === "connected" && + !isDirectSyncProvider, }, ); const syncMutation = useSyncConnector(); const [selectedFiles, setSelectedFiles] = useState([]); - const [currentSyncTaskId, setCurrentSyncTaskId] = useState( - null, - ); - const [ingestSettings, setIngestSettings] = useState({ + const [ingestSettings, setIngestSettings] = useState({ chunkSize: 1000, chunkOverlap: 200, ocr: false, @@ -61,7 +384,8 @@ export default function UploadProviderPage() { }); const accessToken = tokenData?.access_token || null; - const isLoading = connectorsLoading || tokenLoading; + const isLoading = + connectorsLoading || (!isDirectSyncProvider && tokenLoading); const isIngesting = syncMutation.isPending; // Error handling @@ -101,7 +425,6 @@ export default function UploadProviderPage() { if (taskIds && taskIds.length > 0) { const taskId = taskIds[0]; // Use the first task ID addTask(taskId); - setCurrentSyncTaskId(taskId); // Redirect to knowledge page already to show the syncing document router.push("/knowledge"); } @@ -193,6 +516,30 @@ export default function UploadProviderPage() { ); } + // Direct-sync providers show a bucket list with sync status. 
+ if (isDirectSyncProvider && connector.status === "connected") { + if (provider === "aws_s3") { + return ( + router.back()} + onDone={() => router.push("/knowledge")} + /> + ); + } + return ( + router.back()} + onDone={() => router.push("/knowledge")} + /> + ); + } + if (!accessToken) { return ( <> diff --git a/frontend/components/cloud-picker/file-item.tsx b/frontend/components/cloud-picker/file-item.tsx index 617e54433..b71792685 100644 --- a/frontend/components/cloud-picker/file-item.tsx +++ b/frontend/components/cloud-picker/file-item.tsx @@ -2,8 +2,10 @@ import { FileText, Folder, Trash2 } from "lucide-react"; import GoogleDriveIcon from "@/components/icons/google-drive-logo"; +import IBMCOSIcon from "@/components/icons/ibm-cos-icon"; import OneDriveIcon from "@/components/icons/one-drive-logo"; import SharePointIcon from "@/components/icons/share-point-logo"; +import AwsLogo from "@/components/icons/aws-logo"; import { Button } from "@/components/ui/button"; import type { CloudFile } from "./types"; @@ -54,6 +56,10 @@ const getProviderIcon = (provider: string) => { return ; case "sharepoint": return ; + case "ibm_cos": + return ; + case "aws_s3": + return ; default: return ; } diff --git a/frontend/components/icons/ibm-cos-icon.tsx b/frontend/components/icons/ibm-cos-icon.tsx new file mode 100644 index 000000000..1637afbf3 --- /dev/null +++ b/frontend/components/icons/ibm-cos-icon.tsx @@ -0,0 +1,64 @@ +export default function IBMCOSIcon(props: React.SVGProps) { + return ( + + + + + + + + + + + + + + + + + + + + + + + ); +} diff --git a/frontend/components/knowledge-dropdown.tsx b/frontend/components/knowledge-dropdown.tsx index 1eba42fba..0895f9ba1 100644 --- a/frontend/components/knowledge-dropdown.tsx +++ b/frontend/components/knowledge-dropdown.tsx @@ -3,7 +3,6 @@ import { useQueryClient } from "@tanstack/react-query"; import { ChevronDown, - Cloud, File as FileIcon, Folder, FolderOpen, @@ -18,6 +17,7 @@ import { useGetTasksQuery } from 
"@/app/api/queries/useGetTasksQuery"; import { DuplicateHandlingDialog } from "@/components/duplicate-handling-dialog"; import AwsIcon from "@/components/icons/aws-logo"; import GoogleDriveIcon from "@/components/icons/google-drive-logo"; +import IBMCOSIcon from "@/components/icons/ibm-cos-icon"; import OneDriveIcon from "@/components/icons/one-drive-logo"; import SharePointIcon from "@/components/icons/share-point-logo"; import { Button } from "@/components/ui/button"; @@ -36,6 +36,7 @@ import { } from "@/components/ui/dropdown-menu"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; +import { useAuth } from "@/contexts/auth-context"; import { useTask } from "@/contexts/task-context"; import { duplicateCheck, @@ -81,6 +82,7 @@ const FolderIconWithColor = ({ className }: { className?: string }) => ( ); export function KnowledgeDropdown() { + const { isIbmAuthMode } = useAuth(); const { addTask } = useTask(); const { refetch: refetchTasks } = useGetTasksQuery(); const queryClient = useQueryClient(); @@ -88,16 +90,14 @@ export function KnowledgeDropdown() { const [mounted, setMounted] = useState(false); const [isMenuOpen, setIsMenuOpen] = useState(false); const [showFolderDialog, setShowFolderDialog] = useState(false); - const [showS3Dialog, setShowS3Dialog] = useState(false); const [showDuplicateDialog, setShowDuplicateDialog] = useState(false); - const [awsEnabled, setAwsEnabled] = useState(false); const [uploadBatchSize, setUploadBatchSize] = useState(25); const [folderPath, setFolderPath] = useState(""); - const [bucketUrl, setBucketUrl] = useState("s3://"); const [folderLoading, setFolderLoading] = useState(false); - const [s3Loading, setS3Loading] = useState(false); const [fileUploading, setFileUploading] = useState(false); const [isNavigatingToCloud, setIsNavigatingToCloud] = useState(false); + const [ibmCosConfigured, setIbmCosConfigured] = useState(false); + const [s3Configured, setS3Configured] = useState(false); 
const [pendingFile, setPendingFile] = useState(null); const [duplicateFilename, setDuplicateFilename] = useState(""); const [cloudConnectors, setCloudConnectors] = useState<{ @@ -115,19 +115,41 @@ export function KnowledgeDropdown() { useEffect(() => { const checkAvailability = async () => { try { - // Check AWS and upload batch size - const awsRes = await fetch("/api/upload_options"); - if (awsRes.ok) { - const awsData = await awsRes.json(); - setAwsEnabled(Boolean(awsData.aws)); + // Check upload batch size and bucket connector availability in parallel + const [uploadOptionsRes, ibmCosRes, s3Res] = await Promise.all([ + fetch("/api/upload_options"), + fetch("/api/connectors/ibm_cos/defaults"), + fetch("/api/connectors/aws_s3/defaults"), + ]); + + if (uploadOptionsRes.ok) { + const uploadOptionsData = await uploadOptionsRes.json(); if ( - typeof awsData.upload_batch_size === "number" && - awsData.upload_batch_size > 0 + typeof uploadOptionsData.upload_batch_size === "number" && + uploadOptionsData.upload_batch_size > 0 ) { - setUploadBatchSize(awsData.upload_batch_size); + setUploadBatchSize(uploadOptionsData.upload_batch_size); } } + if (ibmCosRes.ok) { + const ibmCosData = await ibmCosRes.json(); + setIbmCosConfigured( + Boolean( + ibmCosData.connection_id || + ibmCosData.api_key_set || + ibmCosData.hmac_access_key_set, + ), + ); + } + + if (s3Res.ok) { + const s3Data = await s3Res.json(); + setS3Configured( + Boolean(s3Data.connection_id || s3Data.access_key_set), + ); + } + // Check cloud connectors const connectorsRes = await fetch("/api/connectors"); if (connectorsRes.ok) { @@ -461,49 +483,6 @@ export function KnowledgeDropdown() { } }; - const handleS3Upload = async () => { - if (!bucketUrl.trim()) return; - - setS3Loading(true); - setShowS3Dialog(false); - - try { - const response = await fetch("/api/upload_bucket", { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ s3_url: bucketUrl }), - }); - - const 
result = await response.json(); - - if (response.status === 201) { - const taskId = result.task_id || result.id; - - if (!taskId) { - throw new Error("No task ID received from server"); - } - - addTask(taskId); - setBucketUrl("s3://"); - // Refetch tasks to show the new task - refetchTasks(); - } else { - console.error("S3 upload failed:", result.error); - if (response.status === 400) { - toast.error("Upload failed", { - description: result.error || "Bad request", - }); - } - } - } catch (error) { - console.error("S3 upload error:", error); - } finally { - setS3Loading(false); - } - }; - // Icon mapping for cloud connectors const connectorIconMap = { google_drive: GoogleDriveIcon, @@ -544,12 +523,21 @@ export function KnowledgeDropdown() { icon: FolderIconWithColor, onClick: () => folderInputRef.current?.click(), }, - ...(awsEnabled + ...(isIbmAuthMode && s3Configured ? [ { label: "Amazon S3", icon: AwsIcon, - onClick: () => setShowS3Dialog(true), + onClick: () => router.push("/upload/aws_s3"), + }, + ] + : []), + ...(isIbmAuthMode && ibmCosConfigured + ? [ + { + label: "IBM Cloud Object Storage", + icon: IBMCOSIcon, + onClick: () => router.push("/upload/ibm_cos"), }, ] : []), @@ -557,8 +545,7 @@ export function KnowledgeDropdown() { ]; // Comprehensive loading state - const isLoading = - fileUploading || folderLoading || s3Loading || isNavigatingToCloud; + const isLoading = fileUploading || folderLoading || isNavigatingToCloud; if (!mounted) { return ( @@ -581,11 +568,9 @@ export function KnowledgeDropdown() { ? "Uploading..." : folderLoading ? "Processing Folder..." - : s3Loading - ? "Processing S3..." - : isNavigatingToCloud - ? "Loading..." - : "Processing..." + : isNavigatingToCloud + ? "Loading..." + : "Processing..." : "Add Knowledge"} {!isLoading && ( @@ -673,45 +658,6 @@ export function KnowledgeDropdown() {
- {/* Process S3 Bucket Dialog */} - - - - - - Process S3 Bucket - - - Process all documents from an S3 bucket. AWS credentials must be - configured. - - -
-
- - setBucketUrl(e.target.value)} - /> -
-
- - -
-
-
-
- {/* Duplicate Handling Dialog */} void; logout: () => Promise; refreshAuth: () => Promise; @@ -46,6 +47,7 @@ export function AuthProvider({ children }: AuthProviderProps) { const [user, setUser] = useState(null); const [isLoading, setIsLoading] = useState(true); const [isNoAuthMode, setIsNoAuthMode] = useState(false); + const [isIbmAuthMode, setIsIbmAuthMode] = useState(false); const checkAuth = useCallback(async () => { try { @@ -60,6 +62,8 @@ export function AuthProvider({ children }: AuthProviderProps) { const data = await response.json(); + setIsIbmAuthMode(!!data.ibm_auth_mode); + // Check if we're in no-auth mode if (data.no_auth_mode) { setIsNoAuthMode(true); @@ -176,6 +180,7 @@ export function AuthProvider({ children }: AuthProviderProps) { isLoading, isAuthenticated: !!user, isNoAuthMode, + isIbmAuthMode, login, logout, refreshAuth, diff --git a/pyproject.toml b/pyproject.toml index 261c5439c..d6a869206 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "fastapi>=0.115.0", "uvicorn>=0.35.0", "boto3>=1.35.0", + "ibm-cos-sdk>=2.13.0", "psutil>=7.0.0", "rich>=13.0.0", "textual>=0.45.0", diff --git a/sdks/python/openrag_sdk/chat.py b/sdks/python/openrag_sdk/chat.py index ef33bc243..749ca54ae 100644 --- a/sdks/python/openrag_sdk/chat.py +++ b/sdks/python/openrag_sdk/chat.py @@ -490,11 +490,16 @@ async def delete(self, chat_id: str) -> bool: chat_id: The ID of the conversation to delete. Returns: - True if deletion was successful. + True if deletion was successful, False if the conversation was not found. 
""" - response = await self._client._request("DELETE", f"/api/v1/chat/{chat_id}") - data = response.json() - return data.get("success", False) + from .exceptions import NotFoundError + + try: + response = await self._client._request("DELETE", f"/api/v1/chat/{chat_id}") + data = response.json() + return data.get("success", False) + except NotFoundError: + return False # Import Literal for type hints diff --git a/sdks/python/tests/test_integration.py b/sdks/python/tests/test_integration.py deleted file mode 100644 index c1a202ba5..000000000 --- a/sdks/python/tests/test_integration.py +++ /dev/null @@ -1,467 +0,0 @@ -""" -Integration tests for OpenRAG Python SDK. - -These tests run against a real OpenRAG instance. -Requires: OPENRAG_URL environment variable (defaults to http://localhost:3000) - -Run with: pytest sdks/python/tests/test_integration.py -v -""" - -import os -from pathlib import Path - -import httpx -import pytest - -# Skip all tests if no OpenRAG instance is available -pytestmark = pytest.mark.skipif( - os.environ.get("SKIP_SDK_INTEGRATION_TESTS") == "true", - reason="SDK integration tests skipped", -) - -# Module-level cache for API key (created once, reused) -_cached_api_key: str | None = None -_base_url = os.environ.get("OPENRAG_URL", "http://localhost:3000") -_onboarding_done = False - - -@pytest.fixture(scope="session", autouse=True) -def ensure_onboarding(): - """Ensure the OpenRAG instance is onboarded before running tests. - - This marks the config as 'edited' so that settings updates are allowed. 
- """ - global _onboarding_done - if _onboarding_done: - return - - onboarding_payload = { - "llm_provider": "openai", - "embedding_provider": "openai", - "embedding_model": "text-embedding-3-small", - "llm_model": "gpt-4o-mini", - } - - try: - response = httpx.post( - f"{_base_url}/api/onboarding", - json=onboarding_payload, - timeout=30.0, - ) - if response.status_code in (200, 204): - print(f"[SDK Tests] Onboarding completed successfully") - else: - # May already be onboarded, which is fine - print(f"[SDK Tests] Onboarding returned {response.status_code}: {response.text[:200]}") - except Exception as e: - print(f"[SDK Tests] Onboarding request failed: {e}") - - _onboarding_done = True - - -def get_api_key() -> str: - """Get or create an API key for testing.""" - global _cached_api_key - if _cached_api_key is None: - # Use /api/keys to go through frontend proxy (frontend at :3000 proxies /api/* to backend) - response = httpx.post( - f"{_base_url}/api/keys", - json={"name": "SDK Integration Test"}, - timeout=30.0, - ) - if response.status_code == 401: - pytest.skip("Cannot create API key - authentication required") - assert response.status_code == 200, f"Failed to create API key: {response.text}" - _cached_api_key = response.json()["api_key"] - return _cached_api_key - - -@pytest.fixture -def client(): - """Create an OpenRAG client for each test.""" - from openrag_sdk import OpenRAGClient - - return OpenRAGClient(api_key=get_api_key(), base_url=_base_url) - - -@pytest.fixture -def test_file(tmp_path) -> Path: - """Create a test file for ingestion with unique content.""" - import uuid - # Use .md extension - Langflow has issues with .txt files - file_path = tmp_path / f"sdk_test_doc_{uuid.uuid4().hex[:8]}.md" - file_path.write_text( - f"# SDK Integration Test Document\n\n" - f"ID: {uuid.uuid4()}\n\n" - "This document tests the OpenRAG Python SDK.\n\n" - "It contains unique content about purple elephants dancing.\n" - ) - return file_path - - -class TestSettings: - 
"""Test settings endpoint.""" - - @pytest.mark.asyncio - async def test_get_settings(self, client): - """Test getting settings.""" - settings = await client.settings.get() - - assert settings.agent is not None - assert settings.knowledge is not None - - @pytest.mark.asyncio - async def test_update_settings(self, client): - """Test updating settings.""" - # Get current settings first - current_settings = await client.settings.get() - current_chunk_size = current_settings.knowledge.chunk_size or 1000 - - # Update with the same value (safe for tests) - result = await client.settings.update({"chunk_size": current_chunk_size}) - - assert result.message is not None - - # Verify the setting persisted - updated_settings = await client.settings.get() - assert updated_settings.knowledge.chunk_size == current_chunk_size - - -class TestModels: - """Test models endpoint.""" - - @pytest.mark.asyncio - async def test_list_models(self, client): - """Test listing models for a provider.""" - # This tests both the API key auth and the minted JWT - # since models_service often needs credentials/JWT - models = await client.models.list("openai") - - assert models.language_models is not None - assert isinstance(models.language_models, list) - assert models.embedding_models is not None - assert isinstance(models.embedding_models, list) - - -class TestKnowledgeFilters: - """Test knowledge filter operations.""" - - @pytest.mark.asyncio - async def test_knowledge_filter_crud(self, client): - """Test create, read, update, delete for knowledge filters.""" - # Create - create_result = await client.knowledge_filters.create({ - "name": "Python SDK Test Filter", - "description": "Filter created by Python SDK integration tests", - "queryData": { - "query": "test documents", - "limit": 10, - "scoreThreshold": 0.5, - }, - }) - - assert create_result.success is True - assert create_result.id is not None - filter_id = create_result.id - - # Search - filters = await 
client.knowledge_filters.search("Python SDK Test") - assert isinstance(filters, list) - found = any(f.name == "Python SDK Test Filter" for f in filters) - assert found is True - - # Get - filter_obj = await client.knowledge_filters.get(filter_id) - assert filter_obj is not None - assert filter_obj.id == filter_id - assert filter_obj.name == "Python SDK Test Filter" - - # Update - update_success = await client.knowledge_filters.update( - filter_id, - {"description": "Updated description from Python SDK test"}, - ) - assert update_success is True - - # Verify update - updated_filter = await client.knowledge_filters.get(filter_id) - assert updated_filter.description == "Updated description from Python SDK test" - - # Delete - delete_success = await client.knowledge_filters.delete(filter_id) - assert delete_success is True - - # Verify deletion - deleted_filter = await client.knowledge_filters.get(filter_id) - assert deleted_filter is None - - @pytest.mark.asyncio - async def test_filter_id_in_chat(self, client): - """Test using filter_id in chat.""" - # Create a filter first - create_result = await client.knowledge_filters.create({ - "name": "Chat Test Filter Python", - "description": "Filter for testing chat with filter_id", - "queryData": { - "query": "test", - "limit": 5, - }, - }) - assert create_result.success is True - filter_id = create_result.id - - try: - # Use filter in chat - response = await client.chat.create( - message="Hello with filter", - filter_id=filter_id, - ) - assert response.response is not None - finally: - # Cleanup - await client.knowledge_filters.delete(filter_id) - - @pytest.mark.asyncio - async def test_filter_id_in_search(self, client): - """Test using filter_id in search.""" - # Create a filter first - create_result = await client.knowledge_filters.create({ - "name": "Search Test Filter Python", - "description": "Filter for testing search with filter_id", - "queryData": { - "query": "test", - "limit": 5, - }, - }) - assert 
create_result.success is True - filter_id = create_result.id - - try: - # Use filter in search - results = await client.search.query("test query", filter_id=filter_id) - assert results.results is not None - finally: - # Cleanup - await client.knowledge_filters.delete(filter_id) - - -class TestDocuments: - """Test document operations.""" - - @pytest.mark.asyncio - async def test_ingest_document_no_wait(self, client, test_file: Path): - """Test document ingestion without waiting.""" - # wait=False returns immediately with task_id - result = await client.documents.ingest(file_path=str(test_file), wait=False) - - assert result.task_id is not None - - # Can poll manually - final_status = await client.documents.wait_for_task(result.task_id) - # TODO: Fix Langflow ingestion - status may be "failed" due to flow issues - assert final_status.status is not None - assert final_status.successful_files >= 0 - - @pytest.mark.asyncio - async def test_ingest_document(self, client, test_file: Path): - """Test document ingestion.""" - # wait=True (default) polls until completion - result = await client.documents.ingest(file_path=str(test_file)) - - # TODO: Fix Langflow ingestion - status may be "failed" due to flow issues - assert result.status is not None - assert result.successful_files >= 0 - - - @pytest.mark.asyncio - async def test_delete_document(self, client, test_file: Path): - """Test document deletion.""" - # First ingest (wait for completion) - ingest_result = await client.documents.ingest(file_path=str(test_file)) - - # Then delete - result = await client.documents.delete(test_file.name) - - # If ingestion produced indexed chunks, delete should succeed. - # In unstable flow environments, ingestion can complete with zero successful files. 
- if ingest_result.successful_files > 0: - assert result.success is True - assert result.deleted_chunks > 0 - else: - assert result.success is False - assert result.deleted_chunks == 0 - - @pytest.mark.asyncio - async def test_delete_missing_document_is_idempotent(self, client): - """Deleting a never-ingested filename should not raise.""" - import uuid - - missing_filename = f"never_ingested_{uuid.uuid4().hex}.pdf" - result = await client.documents.delete(missing_filename) - - assert result.success is False - assert result.deleted_chunks == 0 - assert result.filename == missing_filename - assert result.error is not None - - -class TestSearch: - """Test search operations.""" - - @pytest.mark.asyncio - async def test_search_query(self, client, test_file: Path): - """Test search query.""" - # Ensure document is ingested - await client.documents.ingest(file_path=str(test_file)) - - # Wait a bit for indexing - import asyncio - await asyncio.sleep(2) - - # Search for unique content - results = await client.search.query("purple elephants dancing") - - assert results.results is not None - # Note: might be empty if indexing is slow, that's ok for CI - - -class TestChat: - """Test chat operations.""" - - @pytest.mark.asyncio - async def test_chat_non_streaming(self, client): - """Test non-streaming chat.""" - response = await client.chat.create( - message="Say hello in exactly 3 words." 
- ) - - assert response.response is not None - assert isinstance(response.response, str) - assert len(response.response) > 0 - - @pytest.mark.asyncio - async def test_chat_streaming_create(self, client): - """Test streaming chat with create(stream=True).""" - collected_text = "" - - async for event in await client.chat.create( - message="Say 'test' and nothing else.", - stream=True, - ): - if event.type == "content": - collected_text += event.delta - - assert len(collected_text) > 0 - - @pytest.mark.asyncio - async def test_chat_streaming_context_manager(self, client): - """Test streaming chat with stream() context manager.""" - async with client.chat.stream( - message="Say 'hello' and nothing else." - ) as stream: - async for _ in stream: - pass - - # Check aggregated properties - assert len(stream.text) > 0 - - @pytest.mark.asyncio - async def test_chat_text_stream(self, client): - """Test text_stream helper.""" - collected = "" - - async with client.chat.stream( - message="Say 'world' and nothing else." - ) as stream: - async for text in stream.text_stream: - collected += text - - assert len(collected) > 0 - - @pytest.mark.asyncio - async def test_chat_final_text(self, client): - """Test final_text() helper.""" - async with client.chat.stream( - message="Say 'done' and nothing else." - ) as stream: - text = await stream.final_text() - - assert len(text) > 0 - - @pytest.mark.asyncio - async def test_chat_conversation_continuation(self, client): - """Test continuing a conversation.""" - # First message - response1 = await client.chat.create( - message="Remember the number 42." 
- ) - assert response1.chat_id is not None - - # Continue conversation - response2 = await client.chat.create( - message="What number did I ask you to remember?", - chat_id=response1.chat_id, - ) - assert response2.response is not None - - @pytest.mark.asyncio - async def test_list_conversations(self, client): - """Test listing conversations.""" - # Create a conversation first - await client.chat.create(message="Test message for listing.") - - # List conversations - result = await client.chat.list() - - assert result.conversations is not None - assert isinstance(result.conversations, list) - - @pytest.mark.asyncio - async def test_get_conversation(self, client): - """Test getting a specific conversation.""" - # Create a conversation first - response = await client.chat.create(message="Test message for get.") - assert response.chat_id is not None - - # Get the conversation - conversation = await client.chat.get(response.chat_id) - - assert conversation.chat_id == response.chat_id - assert conversation.messages is not None - assert isinstance(conversation.messages, list) - assert len(conversation.messages) >= 1 - - @pytest.mark.asyncio - async def test_delete_conversation(self, client): - """Test deleting a conversation.""" - # Create a conversation first - response = await client.chat.create(message="Test message for delete.") - assert response.chat_id is not None - - # Delete the conversation - result = await client.chat.delete(response.chat_id) - - assert result is True - - @pytest.mark.asyncio - async def test_chat_with_sources(self, client, test_file: Path): - """Test chat uses embedded knowledge (RAG), not just pure LLM.""" - # 1. Ingest document - result = await client.documents.ingest(file_path=str(test_file)) - if result.status == "failed" or result.successful_files == 0: - pytest.skip("Document ingestion failed — cannot test RAG sources") - - # 2. Wait for indexing - import asyncio - await asyncio.sleep(3) - - # 3. 
Chat about document content - response = await client.chat.create( - message="What is the color of the dancing animals mentioned in my documents?" - ) - - # 4. Verify sources — proves RAG retrieval worked - assert response.sources is not None - assert len(response.sources) > 0 - source_filenames = [s.filename for s in response.sources] - assert any(test_file.name in name for name in source_filenames) diff --git a/src/agent.py b/src/agent.py index 172bf23df..a59fe84f6 100644 --- a/src/agent.py +++ b/src/agent.py @@ -815,31 +815,36 @@ async def async_langflow_chat_stream( async def delete_user_conversation(user_id: str, response_id: str) -> bool: - """Delete a conversation for a user from both memory and persistent storage (async, non-blocking)""" + """Delete a conversation for a user from both memory and persistent storage. + + Returns: + True — conversation was found and deleted from at least one store. + False — conversation did not exist in any store (confirmed not-found). + + Raises: + Exception — on unexpected storage errors so callers can distinguish + a confirmed "not found" from a backend failure. 
+ """ deleted = False - try: - # Delete from in-memory storage - if user_id in active_conversations and response_id in active_conversations[user_id]: - del active_conversations[user_id][response_id] - logger.debug(f"Deleted conversation {response_id} from memory for user {user_id}") - deleted = True - - # Delete from persistent storage - conversation_deleted = await conversation_persistence.delete_conversation_thread(user_id, response_id) - if conversation_deleted: - logger.debug(f"Deleted conversation {response_id} from persistent storage for user {user_id}") - deleted = True - - # Release session ownership - try: - from services.session_ownership_service import session_ownership_service - session_ownership_service.release_session(user_id, response_id) - logger.debug(f"Released session ownership for {response_id} for user {user_id}") - except Exception as e: - logger.warning(f"Failed to release session ownership: {e}") + # Delete from in-memory storage (cannot raise) + if user_id in active_conversations and response_id in active_conversations[user_id]: + del active_conversations[user_id][response_id] + logger.debug(f"Deleted conversation {response_id} from memory for user {user_id}") + deleted = True + + # Delete from persistent storage — let real errors propagate to the caller + conversation_deleted = await conversation_persistence.delete_conversation_thread(user_id, response_id) + if conversation_deleted: + logger.debug(f"Deleted conversation {response_id} from persistent storage for user {user_id}") + deleted = True - return deleted + # Release session ownership (best-effort; never masks storage errors above) + try: + from services.session_ownership_service import session_ownership_service + session_ownership_service.release_session(user_id, response_id) + logger.debug(f"Released session ownership for {response_id} for user {user_id}") except Exception as e: - logger.error(f"Error deleting conversation {response_id} for user {user_id}: {e}") - return False + 
logger.warning(f"Failed to release session ownership: {e}") + + return deleted diff --git a/src/api/connectors.py b/src/api/connectors.py index 73c096d44..ea2e50eaa 100644 --- a/src/api/connectors.py +++ b/src/api/connectors.py @@ -88,6 +88,11 @@ async def get_synced_file_ids_for_connector( class ConnectorSyncBody(BaseModel): max_files: Optional[int] = None selected_files: Optional[List[Any]] = None + # When True, ingest ALL files from the connector (bypasses the existing-files gate). + # Used by direct-sync providers like IBM COS on initial ingest. + sync_all: bool = False + # When set, only ingest files from these buckets (IBM COS specific). + bucket_filter: Optional[List[str]] = None async def list_connectors( @@ -96,8 +101,8 @@ async def list_connectors( ): """List available connector types with metadata""" try: - connector_types = ( - connector_service.connection_manager.get_available_connector_types() + connector_types = connector_service.connection_manager.get_available_connector_types( + user_id=user.user_id ) return JSONResponse({"connectors": connector_types}) except Exception as e: @@ -200,6 +205,51 @@ async def connector_sync( jwt_token=jwt_token, file_infos=file_infos, ) + elif body.sync_all or body.bucket_filter: + # Full ingest: discover and ingest all files (or files from specific buckets). + # Used by direct-sync providers (IBM COS) on initial ingest or per-bucket sync. 
+ logger.info( + "Full connector ingest requested", + connector_type=connector_type, + bucket_filter=body.bucket_filter, + ) + connector = await connector_service.get_connector(working_connection.connection_id) + if body.bucket_filter: + # List only files from the requested buckets, then sync_specific_files + original_buckets = connector.bucket_names + connector.bucket_names = body.bucket_filter + try: + all_file_ids = [] + page_token = None + while True: + result = await connector.list_files(page_token=page_token) + for f in result.get("files", []): + all_file_ids.append(f["id"]) + page_token = result.get("next_page_token") + if not page_token: + break + finally: + connector.bucket_names = original_buckets + + if not all_file_ids: + return JSONResponse( + {"status": "no_files", "message": "No files found in the selected buckets."}, + status_code=200, + ) + task_id = await connector_service.sync_specific_files( + working_connection.connection_id, + user.user_id, + all_file_ids, + jwt_token=jwt_token, + ) + else: + # sync_all: ingest everything the connector can see + task_id = await connector_service.sync_connector_files( + working_connection.connection_id, + user.user_id, + max_files=max_files, + jwt_token=jwt_token, + ) else: # No files specified - sync only files already in OpenSearch for this connector # This ensures deleted files stay deleted @@ -209,7 +259,7 @@ async def connector_sync( session_manager=session_manager, jwt_token=jwt_token, ) - + if not existing_file_ids and not existing_filenames: return JSONResponse( { @@ -218,7 +268,7 @@ async def connector_sync( }, status_code=200, ) - + # If we have document_ids (connector file IDs), use sync_specific_files # Otherwise, use filename filtering with sync_connector_files if existing_file_ids: @@ -602,6 +652,8 @@ async def connector_disconnect( ) +# --------------------------------------------------------------------------- + async def sync_all_connectors( connector_service=Depends(get_connector_service), 
session_manager=Depends(get_session_manager), @@ -615,7 +667,7 @@ async def sync_all_connectors( jwt_token = user.jwt_token # Cloud connector types to sync - cloud_connector_types = ["google_drive", "onedrive", "sharepoint"] + cloud_connector_types = ["google_drive", "onedrive", "sharepoint", "ibm_cos", "aws_s3"] all_task_ids = [] synced_connectors = [] diff --git a/src/api/v1/chat.py b/src/api/v1/chat.py index b87507afe..1a27804cd 100644 --- a/src/api/v1/chat.py +++ b/src/api/v1/chat.py @@ -226,6 +226,8 @@ async def chat_delete_endpoint( """Delete a conversation. DELETE /v1/chat/{chat_id}""" try: result = await chat_service.delete_session(user.user_id, chat_id) + if result.get("not_found"): + return JSONResponse({"error": "Conversation not found"}, status_code=404) if result.get("success"): return JSONResponse({"success": True}) else: diff --git a/src/config/settings.py b/src/config/settings.py index b1df97074..b31fe2679 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -51,6 +51,8 @@ GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET") DOCLING_OCR_ENGINE = os.getenv("DOCLING_OCR_ENGINE") +IBM_AUTH_ENABLED = os.getenv("IBM_AUTH_ENABLED", "false").lower() in ("true", "1", "yes") + # Ingestion configuration DISABLE_INGEST_WITH_LANGFLOW = os.getenv( "DISABLE_INGEST_WITH_LANGFLOW", "false" diff --git a/src/connectors/__init__.py b/src/connectors/__init__.py index 83e77b8cd..323a76285 100644 --- a/src/connectors/__init__.py +++ b/src/connectors/__init__.py @@ -2,10 +2,14 @@ from .google_drive import GoogleDriveConnector from .sharepoint import SharePointConnector from .onedrive import OneDriveConnector +from .ibm_cos import IBMCOSConnector +from .aws_s3 import S3Connector __all__ = [ "BaseConnector", "GoogleDriveConnector", "SharePointConnector", "OneDriveConnector", + "IBMCOSConnector", + "S3Connector", ] diff --git a/src/connectors/aws_s3/__init__.py b/src/connectors/aws_s3/__init__.py new file mode 100644 index 000000000..32c63a878 
--- /dev/null +++ b/src/connectors/aws_s3/__init__.py @@ -0,0 +1,19 @@ +"""Amazon S3 / S3-compatible connector for OpenRAG.""" + +from .connector import S3Connector +from .models import S3ConfigureBody +from .api import ( + s3_defaults, + s3_configure, + s3_list_buckets, + s3_bucket_status, +) + +__all__ = [ + "S3Connector", + "S3ConfigureBody", + "s3_defaults", + "s3_configure", + "s3_list_buckets", + "s3_bucket_status", +] diff --git a/src/connectors/aws_s3/api.py b/src/connectors/aws_s3/api.py new file mode 100644 index 000000000..178c78ab5 --- /dev/null +++ b/src/connectors/aws_s3/api.py @@ -0,0 +1,175 @@ +"""FastAPI route handlers for AWS S3-specific endpoints.""" + +import os + +from fastapi import Depends +from fastapi.responses import JSONResponse + +from config.settings import get_index_name +from dependencies import get_connector_service, get_session_manager, get_current_user +from session_manager import User +from utils.logging_config import get_logger + +from .auth import create_s3_resource +from .models import S3ConfigureBody +from .support import build_s3_config + +logger = get_logger(__name__) + + +async def s3_defaults( + connector_service=Depends(get_connector_service), + user: User = Depends(get_current_user), +): + """Return current S3 env-var defaults for pre-filling the config dialog. + + Sensitive values (secret key) are masked — only whether they are set is returned. 
+ """ + access_key = os.getenv("AWS_ACCESS_KEY_ID", "") + secret_key = os.getenv("AWS_SECRET_ACCESS_KEY", "") + endpoint_url = os.getenv("AWS_S3_ENDPOINT", "") + region = os.getenv("AWS_REGION", "") + + connections = await connector_service.connection_manager.list_connections( + user_id=user.user_id, connector_type="aws_s3" + ) + conn_config = connections[0].config or {} if connections else {} + + def _pick(conn_key, env_val): + return conn_config.get(conn_key) or env_val + + return JSONResponse({ + "access_key_set": bool(access_key or conn_config.get("access_key")), + "secret_key_set": bool(secret_key or conn_config.get("secret_key")), + "endpoint": _pick("endpoint_url", endpoint_url), + "region": _pick("region", region), + "bucket_names": conn_config.get("bucket_names", []), + "connection_id": connections[0].connection_id if connections else None, + }) + + +async def s3_configure( + body: S3ConfigureBody, + connector_service=Depends(get_connector_service), + user: User = Depends(get_current_user), +): + """Create or update an S3 connection with explicit credentials. + + Tests the credentials by listing buckets, then persists the connection. 
+ """ + existing_connections = await connector_service.connection_manager.list_connections( + user_id=user.user_id, connector_type="aws_s3" + ) + existing_config = existing_connections[0].config if existing_connections else {} + + conn_config, error = build_s3_config(body, existing_config) + if error: + return JSONResponse({"error": error}, status_code=400) + + # Test credentials + try: + s3 = create_s3_resource(conn_config) + list(s3.buckets.all()) + except Exception: + logger.exception("Failed to connect to S3 during credential test.") + return JSONResponse( + {"error": "Could not connect to S3 with the provided configuration."}, + status_code=400, + ) + + # Persist: update existing connection or create a new one + if body.connection_id: + existing = await connector_service.connection_manager.get_connection(body.connection_id) + if existing and existing.user_id == user.user_id: + await connector_service.connection_manager.update_connection( + connection_id=body.connection_id, + config=conn_config, + ) + connector_service.connection_manager.active_connectors.pop(body.connection_id, None) + return JSONResponse({"connection_id": body.connection_id, "status": "connected"}) + + connection_id = await connector_service.connection_manager.create_connection( + connector_type="aws_s3", + name="Amazon S3", + config=conn_config, + user_id=user.user_id, + ) + return JSONResponse({"connection_id": connection_id, "status": "connected"}) + + +async def s3_list_buckets( + connection_id: str, + connector_service=Depends(get_connector_service), + user: User = Depends(get_current_user), +): + """List all buckets accessible with the stored S3 credentials.""" + connection = await connector_service.connection_manager.get_connection(connection_id) + if not connection or connection.user_id != user.user_id: + return JSONResponse({"error": "Connection not found"}, status_code=404) + if connection.connector_type != "aws_s3": + return JSONResponse({"error": "Not an S3 connection"}, 
status_code=400) + + try: + s3 = create_s3_resource(connection.config) + buckets = [b.name for b in s3.buckets.all()] + return JSONResponse({"buckets": buckets}) + except Exception: + logger.exception("Failed to list S3 buckets for connection %s", connection_id) + return JSONResponse({"error": "Failed to list buckets"}, status_code=500) + + +async def s3_bucket_status( + connection_id: str, + connector_service=Depends(get_connector_service), + session_manager=Depends(get_session_manager), + user: User = Depends(get_current_user), +): + """Return all buckets for an S3 connection with their ingestion status.""" + connection = await connector_service.connection_manager.get_connection(connection_id) + if not connection or connection.user_id != user.user_id: + return JSONResponse({"error": "Connection not found"}, status_code=404) + if connection.connector_type != "aws_s3": + return JSONResponse({"error": "Not an S3 connection"}, status_code=400) + + # 1. List all buckets from S3 + try: + s3 = create_s3_resource(connection.config) + all_buckets = [b.name for b in s3.buckets.all()] + except Exception as exc: + logger.exception("Failed to list buckets from S3 for connection %s", connection_id) + return JSONResponse({"error": "Failed to list buckets"}, status_code=500) + + # 2. 
Count indexed documents per bucket from OpenSearch + ingested_counts: dict = {} + try: + opensearch_client = session_manager.get_user_opensearch_client( + user.user_id, user.jwt_token + ) + query_body = { + "size": 0, + "query": {"term": {"connector_type": "aws_s3"}}, + "aggs": { + "doc_ids": { + "terms": {"field": "document_id", "size": 50000} + } + }, + } + index_name = get_index_name() + os_resp = opensearch_client.search(index=index_name, body=query_body) + for bucket_entry in os_resp.get("aggregations", {}).get("doc_ids", {}).get("buckets", []): + doc_id = bucket_entry["key"] + if "::" in doc_id: + bucket_name = doc_id.split("::")[0] + ingested_counts[bucket_name] = ingested_counts.get(bucket_name, 0) + 1 + except Exception: + pass # OpenSearch unavailable — show zero counts + + result = [ + { + "name": bucket, + "ingested_count": ingested_counts.get(bucket, 0), + "is_synced": ingested_counts.get(bucket, 0) > 0, + } + for bucket in all_buckets + ] + return JSONResponse({"buckets": result}) diff --git a/src/connectors/aws_s3/auth.py b/src/connectors/aws_s3/auth.py new file mode 100644 index 000000000..78894eca5 --- /dev/null +++ b/src/connectors/aws_s3/auth.py @@ -0,0 +1,90 @@ +"""Amazon S3 / S3-compatible storage authentication and client factory.""" + +import os +from typing import Any, Dict, Optional + +from utils.logging_config import get_logger + +logger = get_logger(__name__) + +_DEFAULT_REGION = "us-east-1" + + +def _resolve_credentials(config: Dict[str, Any]) -> Dict[str, Any]: + """Resolve S3 credentials from config dict with environment variable fallback. + + Resolution order for each value: config dict → environment variable → default. + + Raises: + ValueError: If access_key or secret_key cannot be resolved. 
+ """ + access_key: Optional[str] = config.get("access_key") or os.getenv("AWS_ACCESS_KEY_ID") + secret_key: Optional[str] = config.get("secret_key") or os.getenv("AWS_SECRET_ACCESS_KEY") + + if not access_key or not secret_key: + raise ValueError( + "S3 credentials are required. Provide 'access_key' and 'secret_key' in the " + "connector config, or set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY env vars." + ) + + # endpoint_url is optional — only inject when non-empty (real AWS users don't set it) + endpoint_url: Optional[str] = config.get("endpoint_url") or os.getenv("AWS_S3_ENDPOINT") or None + + region: str = config.get("region") or os.getenv("AWS_REGION") or _DEFAULT_REGION + + return { + "access_key": access_key, + "secret_key": secret_key, + "endpoint_url": endpoint_url, + "region": region, + } + + +def _build_boto3_kwargs(creds: Dict[str, Any]) -> Dict[str, Any]: + """Build the keyword arguments for boto3.resource / boto3.client.""" + kwargs: Dict[str, Any] = { + "aws_access_key_id": creds["access_key"], + "aws_secret_access_key": creds["secret_key"], + "region_name": creds["region"], + } + if creds["endpoint_url"]: + kwargs["endpoint_url"] = creds["endpoint_url"] + return kwargs + + +def create_s3_resource(config: Dict[str, Any]): + """Return a boto3 S3 resource (high-level API) for bucket/object access. + + Works with AWS S3, MinIO, Cloudflare R2, and any S3-compatible service. + """ + try: + import boto3 + except ImportError as exc: + raise ImportError( + "boto3 is required for the S3 connector. " + "Install it with: pip install boto3" + ) from exc + + creds = _resolve_credentials(config) + kwargs = _build_boto3_kwargs(creds) + logger.debug("Creating S3 resource with HMAC authentication (boto3)") + return boto3.resource("s3", **kwargs) + + +def create_s3_client(config: Dict[str, Any]): + """Return a boto3 S3 low-level client. + + Used for operations such as list_buckets() and get_object_acl(). 
+ """ + try: + import boto3 + except ImportError as exc: + raise ImportError( + "boto3 is required for the S3 connector. " + "Install it with: pip install boto3" + ) from exc + + creds = _resolve_credentials(config) + kwargs = _build_boto3_kwargs(creds) + logger.debug("Creating S3 client with HMAC authentication (boto3)") + return boto3.client("s3", **kwargs) diff --git a/src/connectors/aws_s3/connector.py b/src/connectors/aws_s3/connector.py new file mode 100644 index 000000000..13929e41b --- /dev/null +++ b/src/connectors/aws_s3/connector.py @@ -0,0 +1,277 @@ +"""Amazon S3 / S3-compatible storage connector for OpenRAG.""" + +import mimetypes +import os +from datetime import datetime, timezone +from posixpath import basename +from typing import Any, Dict, List, Optional + +from connectors.base import BaseConnector, ConnectorDocument, DocumentACL +from utils.logging_config import get_logger + +from .auth import create_s3_client, create_s3_resource + +logger = get_logger(__name__) + +# Separator used in composite file IDs: "::" +_ID_SEPARATOR = "::" + + +def _make_file_id(bucket: str, key: str) -> str: + return f"{bucket}{_ID_SEPARATOR}{key}" + + +def _split_file_id(file_id: str): + """Split a composite file ID into (bucket, key). Raises ValueError if invalid.""" + if _ID_SEPARATOR not in file_id: + raise ValueError(f"Invalid S3 file ID (missing separator): {file_id!r}") + bucket, key = file_id.split(_ID_SEPARATOR, 1) + return bucket, key + + +class S3Connector(BaseConnector): + """Connector for Amazon S3 and S3-compatible object storage. + + Uses HMAC (Access Key + Secret Key) authentication. Supports AWS S3, + MinIO, Cloudflare R2, and any service that speaks the S3 API. + + Config dict keys: + access_key (str): Overrides AWS_ACCESS_KEY_ID. + secret_key (str): Overrides AWS_SECRET_ACCESS_KEY. + endpoint_url (str): Optional; overrides AWS_S3_ENDPOINT. Leave empty for AWS S3. + region (str): Optional; overrides AWS_REGION. Default: us-east-1. 
+ bucket_names (list[str]): Buckets to ingest from. If empty, all accessible buckets are used. + connection_id (str): Connection identifier used for logging. + """ + + CONNECTOR_NAME = "Amazon S3" + CONNECTOR_DESCRIPTION = "Add knowledge from Amazon S3 or any S3-compatible storage" + CONNECTOR_ICON = "aws-s3" + + CLIENT_ID_ENV_VAR = "AWS_ACCESS_KEY_ID" + CLIENT_SECRET_ENV_VAR = "AWS_SECRET_ACCESS_KEY" + + def get_client_id(self) -> str: + """Return access key from config dict, or AWS_ACCESS_KEY_ID env var as fallback.""" + val = self.config.get("access_key") or os.getenv("AWS_ACCESS_KEY_ID") + if val: + return val + raise ValueError( + "S3 credentials not set. Provide 'access_key' in the connector config " + "or set the AWS_ACCESS_KEY_ID environment variable." + ) + + def get_client_secret(self) -> str: + """Return secret key from config dict, or AWS_SECRET_ACCESS_KEY env var as fallback.""" + val = self.config.get("secret_key") or os.getenv("AWS_SECRET_ACCESS_KEY") + if val: + return val + raise ValueError( + "S3 credentials not set. Provide 'secret_key' in the connector config " + "or set the AWS_SECRET_ACCESS_KEY environment variable." 
+ ) + + def __init__(self, config: Dict[str, Any]): + if config is None: + config = {} + super().__init__(config) + + self.bucket_names: List[str] = config.get("bucket_names") or [] + self.prefix: str = config.get("prefix", "") + self.connection_id: str = config.get("connection_id", "default") + + self._resource = None # Lazy-initialised on first use + self._client = None + + def _get_resource(self): + if self._resource is None: + self._resource = create_s3_resource(self.config) + return self._resource + + def _get_client(self): + if self._client is None: + self._client = create_s3_client(self.config) + return self._client + + # ------------------------------------------------------------------ + # BaseConnector abstract method implementations + # ------------------------------------------------------------------ + + async def authenticate(self) -> bool: + """Validate credentials by listing accessible buckets.""" + try: + resource = self._get_resource() + list(resource.buckets.all()) + self._authenticated = True + logger.debug(f"S3 authenticated for connection {self.connection_id}") + return True + except Exception as exc: + logger.warning(f"S3 authentication failed: {exc}") + self._authenticated = False + return False + + def _resolve_bucket_names(self) -> List[str]: + """Return configured bucket names, or auto-discover all accessible buckets.""" + if self.bucket_names: + return self.bucket_names + try: + resource = self._get_resource() + buckets = [b.name for b in resource.buckets.all()] + logger.debug("S3 auto-discovered %d bucket(s)", len(buckets)) + return buckets + except Exception as exc: + logger.warning(f"S3 could not auto-discover buckets: {exc}") + return [] + + async def list_files( + self, + page_token: Optional[str] = None, + max_files: Optional[int] = None, + **kwargs, + ) -> Dict[str, Any]: + """List objects across all configured (or auto-discovered) buckets. 
+ + Uses the boto3 resource API: Bucket.objects.all() handles pagination + internally so all objects are returned without manual continuation tokens. + + Returns: + dict with keys: + "files": list of file dicts (id, name, bucket, size, modified_time) + "next_page_token": always None (SDK handles pagination internally) + """ + resource = self._get_resource() + files: List[Dict[str, Any]] = [] + bucket_names = self._resolve_bucket_names() + + for bucket_name in bucket_names: + try: + bucket = resource.Bucket(bucket_name) + objects = ( + bucket.objects.filter(Prefix=self.prefix) + if self.prefix + else bucket.objects.all() + ) + for obj in objects: + if obj.key.endswith("/"): + continue + files.append( + { + "id": _make_file_id(bucket_name, obj.key), + "name": basename(obj.key) or obj.key, + "bucket": bucket_name, + "key": obj.key, + "size": obj.size, + "modified_time": obj.last_modified.isoformat() + if obj.last_modified + else None, + } + ) + if max_files and len(files) >= max_files: + return {"files": files, "next_page_token": None} + except Exception as exc: + logger.error("Failed to list objects in S3 bucket: %s", exc) + continue + + return {"files": files, "next_page_token": None} + + async def get_file_content(self, file_id: str) -> ConnectorDocument: + """Download an object from S3 and return a ConnectorDocument. + + Args: + file_id: Composite ID in the form "::". + + Returns: + ConnectorDocument with content bytes, ACL, and metadata. 
+ """ + bucket_name, key = _split_file_id(file_id) + resource = self._get_resource() + + response = resource.Object(bucket_name, key).get() + content: bytes = response["Body"].read() + + last_modified: datetime = response.get("LastModified") or datetime.now(timezone.utc) + size: int = response.get("ContentLength", len(content)) + + # Prefer filename extension over generic S3 content-type (often application/octet-stream) + raw_content_type = response.get("ContentType", "") + if raw_content_type and raw_content_type != "application/octet-stream": + mime_type: str = raw_content_type + else: + mime_type = mimetypes.guess_type(key)[0] or "application/octet-stream" + + filename = basename(key) or key + acl = await self._extract_acl(bucket_name, key) + + return ConnectorDocument( + id=file_id, + filename=filename, + mimetype=mime_type, + content=content, + source_url=f"s3://{bucket_name}/{key}", + acl=acl, + modified_time=last_modified, + created_time=last_modified, # S3 does not expose creation time + metadata={ + "s3_bucket": bucket_name, + "s3_key": key, + "size": size, + }, + ) + + async def _extract_acl(self, bucket: str, key: str) -> DocumentACL: + """Fetch object ACL from S3 and map it to DocumentACL. + + Falls back to a minimal ACL on failure (e.g. ACLs disabled on the bucket). 
+ """ + try: + client = self._get_client() + acl_response = client.get_object_acl(Bucket=bucket, Key=key) + + owner_id: str = ( + acl_response.get("Owner", {}).get("DisplayName") + or acl_response.get("Owner", {}).get("ID") + or "" + ) + + allowed_users: List[str] = [] + for grant in acl_response.get("Grants", []): + grantee = grant.get("Grantee", {}) + permission = grant.get("Permission", "") + if permission in ("FULL_CONTROL", "READ"): + user_id = ( + grantee.get("DisplayName") + or grantee.get("ID") + or grantee.get("EmailAddress") + ) + if user_id and user_id not in allowed_users: + allowed_users.append(user_id) + + return DocumentACL( + owner=owner_id or None, + allowed_users=allowed_users, + allowed_groups=[], + ) + except Exception as exc: + logger.warning("Could not fetch S3 object ACL, using fallback: %s", exc) + return DocumentACL(owner=None, allowed_users=[], allowed_groups=[]) + + # ------------------------------------------------------------------ + # Webhook / subscription stubs (S3 event notifications are out of scope) + # ------------------------------------------------------------------ + + async def setup_subscription(self) -> str: + """No-op: S3 event notifications are out of scope for this connector.""" + return "" + + async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]: + """No-op: webhooks are not supported in this connector version.""" + return [] + + def extract_webhook_channel_id( + self, payload: Dict[str, Any], headers: Dict[str, str] + ) -> Optional[str]: + return None + + async def cleanup_subscription(self, subscription_id: str) -> bool: + """No-op: no subscription to clean up.""" + return True diff --git a/src/connectors/aws_s3/models.py b/src/connectors/aws_s3/models.py new file mode 100644 index 000000000..582111bd2 --- /dev/null +++ b/src/connectors/aws_s3/models.py @@ -0,0 +1,13 @@ +"""Pydantic request/response models for AWS S3 API endpoints.""" + +from typing import List, Optional +from pydantic import BaseModel 
+ + +class S3ConfigureBody(BaseModel): + access_key: Optional[str] = None + secret_key: Optional[str] = None + endpoint_url: Optional[str] = None + region: Optional[str] = None + bucket_names: Optional[List[str]] = None + connection_id: Optional[str] = None diff --git a/src/connectors/aws_s3/support.py b/src/connectors/aws_s3/support.py new file mode 100644 index 000000000..ccaa069b7 --- /dev/null +++ b/src/connectors/aws_s3/support.py @@ -0,0 +1,51 @@ +"""Support helpers for AWS S3 API endpoints. + +Contains pure (non-async) business logic for credential resolution and +config dict construction, keeping the route handlers thin. +""" + +import os +from typing import Dict, Optional, Tuple + +from .models import S3ConfigureBody + + +def build_s3_config( + body: S3ConfigureBody, + existing_config: Dict, +) -> Tuple[Dict, Optional[str]]: + """Resolve S3 credentials and build the connection config dict. + + Resolution order for each credential: request body → environment variable + → existing connection config. 
+ + Returns: + (conn_config, None) on success + ({}, error_message) on validation failure + """ + access_key = ( + body.access_key + or os.getenv("AWS_ACCESS_KEY_ID") + or existing_config.get("access_key") + ) + secret_key = ( + body.secret_key + or os.getenv("AWS_SECRET_ACCESS_KEY") + or existing_config.get("secret_key") + ) + + if not access_key or not secret_key: + return {}, "access_key and secret_key are required" + + conn_config: Dict = { + "access_key": access_key.strip(), + "secret_key": secret_key.strip(), + } + if body.endpoint_url: + conn_config["endpoint_url"] = body.endpoint_url.strip() + if body.region: + conn_config["region"] = body.region.strip() + if body.bucket_names is not None: + conn_config["bucket_names"] = body.bucket_names + + return conn_config, None diff --git a/src/connectors/connection_manager.py b/src/connectors/connection_manager.py index fd207be6e..65674bda6 100644 --- a/src/connectors/connection_manager.py +++ b/src/connectors/connection_manager.py @@ -1,4 +1,5 @@ import json +import os import uuid import aiofiles from typing import Dict, List, Any, Optional @@ -13,6 +14,8 @@ from .google_drive import GoogleDriveConnector from .sharepoint import SharePointConnector from .onedrive import OneDriveConnector +from .ibm_cos import IBMCOSConnector +from .aws_s3 import S3Connector @dataclass @@ -330,31 +333,71 @@ async def get_connector(self, connection_id: str) -> Optional[BaseConnector]: logger.warning(f"Authentication failed for {connection_id}") return None - def get_available_connector_types(self) -> Dict[str, Dict[str, Any]]: - """Get available connector types with their metadata""" + def get_available_connector_types( + self, user_id: Optional[str] = None + ) -> Dict[str, Dict[str, Any]]: + """Get available connector types with their metadata. 
+ + Availability is user-scoped when ``user_id`` is provided: + a connector is considered available if either: + 1) its required env credentials are present, or + 2) the user has an active saved connection with usable credentials. + """ return { "google_drive": { "name": GoogleDriveConnector.CONNECTOR_NAME, "description": GoogleDriveConnector.CONNECTOR_DESCRIPTION, "icon": GoogleDriveConnector.CONNECTOR_ICON, - "available": self._is_connector_available("google_drive"), + "available": self._is_connector_available("google_drive", user_id), }, "sharepoint": { "name": SharePointConnector.CONNECTOR_NAME, "description": SharePointConnector.CONNECTOR_DESCRIPTION, "icon": SharePointConnector.CONNECTOR_ICON, - "available": self._is_connector_available("sharepoint"), + "available": self._is_connector_available("sharepoint", user_id), }, "onedrive": { "name": OneDriveConnector.CONNECTOR_NAME, "description": OneDriveConnector.CONNECTOR_DESCRIPTION, "icon": OneDriveConnector.CONNECTOR_ICON, - "available": self._is_connector_available("onedrive"), + "available": self._is_connector_available("onedrive", user_id), + }, + "ibm_cos": { + "name": IBMCOSConnector.CONNECTOR_NAME, + "description": IBMCOSConnector.CONNECTOR_DESCRIPTION, + "icon": IBMCOSConnector.CONNECTOR_ICON, + "available": os.environ.get("IBM_AUTH_ENABLED", "").lower() in ("1", "true", "yes"), + }, + "aws_s3": { + "name": S3Connector.CONNECTOR_NAME, + "description": S3Connector.CONNECTOR_DESCRIPTION, + "icon": S3Connector.CONNECTOR_ICON, + "available": os.environ.get("IBM_AUTH_ENABLED", "").lower() in ("1", "true", "yes"), }, } - def _is_connector_available(self, connector_type: str) -> bool: - """Check if a connector type is available (has required env vars)""" + def _has_saved_credentials_for_user( + self, connector_type: str, user_id: Optional[str] + ) -> bool: + """Check if user has an active saved connection with usable credentials.""" + for connection in self.connections.values(): + if connection.connector_type 
!= connector_type or not connection.is_active: + continue + if user_id is not None and connection.user_id != user_id: + continue + try: + connector = self._create_connector(connection) + connector.get_client_id() + connector.get_client_secret() + return True + except (ValueError, NotImplementedError, RuntimeError): + continue + return False + + def _is_connector_available( + self, connector_type: str, user_id: Optional[str] = None + ) -> bool: + """Check whether connector is available for use by the given user.""" try: temp_config = ConnectionConfig( connection_id="temp", @@ -367,8 +410,9 @@ def _is_connector_available(self, connector_type: str) -> bool: connector.get_client_id() connector.get_client_secret() return True - except (ValueError, NotImplementedError): - return False + except (ValueError, NotImplementedError, RuntimeError): + # Fallback: saved per-user connection config (e.g. aws_s3 / ibm_cos) + return self._has_saved_credentials_for_user(connector_type, user_id) def _create_connector(self, config: ConnectionConfig) -> BaseConnector: """Factory method to create connector instances""" @@ -379,6 +423,10 @@ def _create_connector(self, config: ConnectionConfig) -> BaseConnector: return SharePointConnector(config.config) elif config.connector_type == "onedrive": return OneDriveConnector(config.config) + elif config.connector_type == "ibm_cos": + return IBMCOSConnector(config.config) + elif config.connector_type == "aws_s3": + return S3Connector(config.config) elif config.connector_type == "box": raise NotImplementedError("Box connector not implemented yet") elif config.connector_type == "dropbox": diff --git a/src/connectors/ibm_cos/__init__.py b/src/connectors/ibm_cos/__init__.py new file mode 100644 index 000000000..5bf01ad43 --- /dev/null +++ b/src/connectors/ibm_cos/__init__.py @@ -0,0 +1,17 @@ +from .connector import IBMCOSConnector +from .models import IBMCOSConfigureBody +from .api import ( + ibm_cos_defaults, + ibm_cos_configure, + 
# ---------------------------------------------------------------------------
# src/connectors/ibm_cos/__init__.py  (new file)
# ---------------------------------------------------------------------------
from .connector import IBMCOSConnector
from .models import IBMCOSConfigureBody
from .api import (
    ibm_cos_defaults,
    ibm_cos_configure,
    ibm_cos_list_buckets,
    ibm_cos_bucket_status,
)

__all__ = [
    "IBMCOSConnector",
    "IBMCOSConfigureBody",
    "ibm_cos_defaults",
    "ibm_cos_configure",
    "ibm_cos_list_buckets",
    "ibm_cos_bucket_status",
]


# ---------------------------------------------------------------------------
# src/connectors/ibm_cos/api.py  (new file) — defaults & configure handlers
# ---------------------------------------------------------------------------
"""FastAPI route handlers for IBM COS-specific endpoints."""

import os

from fastapi import Depends
from fastapi.responses import JSONResponse

from config.settings import get_index_name
from dependencies import get_connector_service, get_session_manager, get_current_user
from session_manager import User
from utils.logging_config import get_logger

from .auth import create_ibm_cos_client, create_ibm_cos_resource
from .models import IBMCOSConfigureBody
from .support import build_ibm_cos_config

logger = get_logger(__name__)


async def ibm_cos_defaults(
    connector_service=Depends(get_connector_service),
    user: User = Depends(get_current_user),
):
    """Return current IBM COS env-var defaults for pre-filling the config dialog.

    Sensitive values (API key, HMAC secret) are masked — only whether they
    are set is returned, never the values themselves.
    """
    env_api_key = os.getenv("IBM_COS_API_KEY", "")
    env_instance = os.getenv("IBM_COS_SERVICE_INSTANCE_ID", "")
    env_endpoint = os.getenv("IBM_COS_ENDPOINT", "")
    env_hmac_key = os.getenv("IBM_COS_HMAC_ACCESS_KEY_ID", "")
    env_hmac_secret = os.getenv("IBM_COS_HMAC_SECRET_ACCESS_KEY", "")
    # IAM UI is opt-in: hidden unless OPENRAG_IBM_COS_IAM_UI is truthy.
    disable_iam = os.getenv("OPENRAG_IBM_COS_IAM_UI", "").lower() not in ("1", "true", "yes")

    saved = await connector_service.connection_manager.list_connections(
        user_id=user.user_id, connector_type="ibm_cos"
    )
    saved_cfg = (saved[0].config or {}) if saved else {}

    def _prefer_saved(cfg_key, env_value):
        # Saved connection config wins over the environment variable.
        return saved_cfg.get(cfg_key) or env_value

    # Default auth mode: IAM only when an API key exists AND the IAM UI is enabled.
    iam_possible = bool(env_api_key or saved_cfg.get("api_key"))
    default_mode = "iam" if (iam_possible and not disable_iam) else "hmac"

    payload = {
        "api_key_set": iam_possible,
        "service_instance_id": _prefer_saved("service_instance_id", env_instance),
        "endpoint": _prefer_saved("endpoint_url", env_endpoint),
        "hmac_access_key_set": bool(env_hmac_key or saved_cfg.get("hmac_access_key")),
        "hmac_secret_key_set": bool(env_hmac_secret or saved_cfg.get("hmac_secret_key")),
        "auth_mode": saved_cfg.get("auth_mode", default_mode),
        "disable_iam": disable_iam,
        "bucket_names": saved_cfg.get("bucket_names", []),
        "connection_id": saved[0].connection_id if saved else None,
    }
    return JSONResponse(payload)


async def ibm_cos_configure(
    body: IBMCOSConfigureBody,
    connector_service=Depends(get_connector_service),
    user: User = Depends(get_current_user),
):
    """Create or update an IBM COS connection with explicit credentials.

    Tests the credentials by listing buckets, then persists the connection.
    Credentials are stored in the connection config dict (not env vars) so
    the connector works even without system-level env vars.
    """
    manager = connector_service.connection_manager

    prior = await manager.list_connections(
        user_id=user.user_id, connector_type="ibm_cos"
    )
    prior_cfg = prior[0].config if prior else {}

    conn_config, error = build_ibm_cos_config(body, prior_cfg)
    if error:
        return JSONResponse({"error": error}, status_code=400)

    # Probe the credentials. HMAC → resource (pure S3, MinIO-friendly);
    # IAM → client (sidesteps an ibm_botocore discovery-call bug).
    try:
        if conn_config.get("auth_mode", "iam") == "hmac":
            list(create_ibm_cos_resource(conn_config).buckets.all())
        else:
            create_ibm_cos_client(conn_config).list_buckets()
    except Exception:
        logger.exception("Failed to connect to IBM COS during credential test.")
        return JSONResponse(
            {"error": "Could not connect to IBM COS with the provided configuration."},
            status_code=400,
        )

    # Update path: only when the referenced connection exists AND is owned by
    # the caller; otherwise we fall through and create a fresh connection.
    if body.connection_id:
        existing = await manager.get_connection(body.connection_id)
        if existing and existing.user_id == user.user_id:
            await manager.update_connection(
                connection_id=body.connection_id,
                config=conn_config,
            )
            # Drop any cached connector so the new credentials take effect.
            manager.active_connectors.pop(body.connection_id, None)
            return JSONResponse(
                {"connection_id": body.connection_id, "status": "connected"}
            )

    connection_id = await manager.create_connection(
        connector_type="ibm_cos",
        name="IBM Cloud Object Storage",
        config=conn_config,
        user_id=user.user_id,
    )
    return JSONResponse({"connection_id": connection_id, "status": "connected"})
async def ibm_cos_list_buckets(
    connection_id: str,
    connector_service=Depends(get_connector_service),
    user: User = Depends(get_current_user),
):
    """List all buckets accessible with the stored IBM COS credentials.

    Returns:
        200 {"buckets": [...]} on success; 404 for unknown/foreign
        connections; 400 for non-IBM-COS connections; 500 on COS errors.
    """
    connection = await connector_service.connection_manager.get_connection(connection_id)
    if not connection or connection.user_id != user.user_id:
        return JSONResponse({"error": "Connection not found"}, status_code=404)
    if connection.connector_type != "ibm_cos":
        return JSONResponse({"error": "Not an IBM COS connection"}, status_code=400)

    try:
        cfg = connection.config
        # HMAC → resource (pure S3); IAM → client (avoids ibm_botocore
        # service-discovery bug against the real IBM COS API).
        if cfg.get("auth_mode", "iam") == "hmac":
            cos = create_ibm_cos_resource(cfg)
            buckets = [b.name for b in cos.buckets.all()]
        else:
            cos = create_ibm_cos_client(cfg)
            buckets = [b["Name"] for b in cos.list_buckets().get("Buckets", [])]
        return JSONResponse({"buckets": buckets})
    except Exception:
        logger.exception("Failed to list IBM COS buckets for connection %s", connection_id)
        return JSONResponse({"error": "Failed to list buckets"}, status_code=500)


async def ibm_cos_bucket_status(
    connection_id: str,
    connector_service=Depends(get_connector_service),
    session_manager=Depends(get_session_manager),
    user: User = Depends(get_current_user),
):
    """Return all buckets for an IBM COS connection with their ingestion status.

    Each entry includes the bucket name, whether it has been ingested
    (is_synced), and the count of indexed documents from that bucket.
    """
    connection = await connector_service.connection_manager.get_connection(connection_id)
    if not connection or connection.user_id != user.user_id:
        return JSONResponse({"error": "Connection not found"}, status_code=404)
    if connection.connector_type != "ibm_cos":
        return JSONResponse({"error": "Not an IBM COS connection"}, status_code=400)

    # 1. List all buckets from COS.
    # FIX: previously this always used create_ibm_cos_resource(), which in IAM
    # mode triggers the ibm_botocore service-discovery bug that every other
    # call site (configure, list_buckets, the connector) deliberately avoids
    # by using the low-level client. Branch on auth_mode like they do.
    try:
        cfg = connection.config
        if cfg.get("auth_mode", "iam") == "hmac":
            cos = create_ibm_cos_resource(cfg)
            all_buckets = [b.name for b in cos.buckets.all()]
        else:
            cos = create_ibm_cos_client(cfg)
            all_buckets = [b["Name"] for b in cos.list_buckets().get("Buckets", [])]
    except Exception:
        logger.exception("Failed to list IBM COS buckets for connection %s", connection_id)
        return JSONResponse({"error": "Failed to list buckets"}, status_code=500)

    # 2. Count indexed documents per bucket from OpenSearch. Document IDs are
    # composite "<bucket>::<key>" strings, so the bucket is the prefix.
    ingested_counts: dict = {}
    try:
        opensearch_client = session_manager.get_user_opensearch_client(
            user.user_id, user.jwt_token
        )
        query_body = {
            "size": 0,
            "query": {"term": {"connector_type": "ibm_cos"}},
            "aggs": {
                "doc_ids": {
                    # NOTE(review): terms agg is capped at 50000 unique doc IDs;
                    # counts undercount beyond that — confirm acceptable.
                    "terms": {"field": "document_id", "size": 50000}
                }
            },
        }
        index_name = get_index_name()
        os_resp = opensearch_client.search(index=index_name, body=query_body)
        for bucket_entry in os_resp.get("aggregations", {}).get("doc_ids", {}).get("buckets", []):
            doc_id = bucket_entry["key"]
            if "::" in doc_id:
                bucket_name = doc_id.split("::")[0]
                ingested_counts[bucket_name] = ingested_counts.get(bucket_name, 0) + 1
    except Exception:
        # Best-effort: OpenSearch unavailable — show zero counts, but leave a
        # trace instead of swallowing the error silently.
        logger.debug(
            "Could not fetch ingestion counts for connection %s; showing zeros",
            connection_id,
            exc_info=True,
        )

    result = [
        {
            "name": bucket,
            "ingested_count": ingested_counts.get(bucket, 0),
            "is_synced": ingested_counts.get(bucket, 0) > 0,
        }
        for bucket in all_buckets
    ]
    return JSONResponse({"buckets": result})
# ---------------------------------------------------------------------------
# src/connectors/ibm_cos/auth.py  (new file)
# ---------------------------------------------------------------------------
"""IBM Cloud Object Storage authentication and client factory."""

import os
from typing import Any, Dict

from utils.logging_config import get_logger

logger = get_logger(__name__)

# IAM auth endpoint default
_DEFAULT_AUTH_ENDPOINT = "https://iam.cloud.ibm.com/identity/token"


def _resolve_credentials(config: Dict[str, Any]):
    """Resolve IBM COS credentials: config dict first, env-var fallback.

    Returns a dict with the resolved values needed to build a boto3
    client/resource. Raises ValueError only when no endpoint URL is
    available; per-mode credential checks happen in the builders.
    """
    endpoint_url = config.get("endpoint_url") or os.getenv("IBM_COS_ENDPOINT")
    if not endpoint_url:
        raise ValueError(
            "IBM COS endpoint URL is required. Set IBM_COS_ENDPOINT or provide "
            "'endpoint_url' in the connector config."
        )

    return {
        "endpoint_url": endpoint_url,
        "api_key": config.get("api_key") or os.getenv("IBM_COS_API_KEY"),
        "service_instance_id": (
            config.get("service_instance_id")
            or os.getenv("IBM_COS_SERVICE_INSTANCE_ID")
        ),
        "hmac_access_key": (
            config.get("hmac_access_key") or os.getenv("IBM_COS_HMAC_ACCESS_KEY_ID")
        ),
        "hmac_secret_key": (
            config.get("hmac_secret_key") or os.getenv("IBM_COS_HMAC_SECRET_ACCESS_KEY")
        ),
        "auth_endpoint": (
            config.get("auth_endpoint")
            or os.getenv("IBM_COS_AUTH_ENDPOINT")
            or _DEFAULT_AUTH_ENDPOINT
        ),
    }


def _build_handle(config: Dict[str, Any], creds: Dict[str, Any], kind: str):
    """Shared factory behind _build_resource / _build_client.

    ``kind`` is "resource" or "client" and selects boto3.resource/boto3.client
    (HMAC) or ibm_boto3.resource/ibm_boto3.client (IAM). The two auth paths
    were previously duplicated verbatim per kind; behaviour, error messages
    and check ordering are unchanged.

    HMAC mode uses standard boto3 (no IBM-specific calls, pure S3 protocol).
    IAM mode uses ibm_boto3 with OAuth signature.
    """
    auth_mode = config.get("auth_mode", "iam")

    if auth_mode == "hmac":
        # Validate credentials before importing, as the originals did.
        if not (creds["hmac_access_key"] and creds["hmac_secret_key"]):
            raise ValueError(
                "HMAC mode requires hmac_access_key and hmac_secret_key."
            )
        try:
            import boto3
        except ImportError as exc:
            raise ImportError(
                "boto3 is required for IBM COS HMAC mode. "
                "Install it with: pip install boto3"
            ) from exc
        logger.debug("Creating IBM COS %s with HMAC authentication (boto3)", kind)
        factory = getattr(boto3, kind)
        return factory(
            "s3",
            aws_access_key_id=creds["hmac_access_key"],
            aws_secret_access_key=creds["hmac_secret_key"],
            endpoint_url=creds["endpoint_url"],
        )

    # IAM mode (default) — requires ibm_boto3 for OAuth token handling
    try:
        import ibm_boto3
        from ibm_botocore.client import Config
    except ImportError as exc:
        raise ImportError(
            "ibm-cos-sdk is required for IBM COS IAM mode. "
            "Install it with: pip install ibm-cos-sdk"
        ) from exc
    if not (creds["api_key"] and creds["service_instance_id"]):
        raise ValueError(
            "IAM mode requires api_key and service_instance_id."
        )
    logger.debug("Creating IBM COS %s with IAM authentication (ibm_boto3)", kind)
    factory = getattr(ibm_boto3, kind)
    return factory(
        "s3",
        ibm_api_key_id=creds["api_key"],
        ibm_service_instance_id=creds["service_instance_id"],
        ibm_auth_endpoint=creds["auth_endpoint"],
        config=Config(signature_version="oauth"),
        endpoint_url=creds["endpoint_url"],
    )


def _build_resource(config: Dict[str, Any], creds: Dict[str, Any]):
    """Build an S3-compatible resource (high-level API) from resolved creds."""
    return _build_handle(config, creds, "resource")


def _build_client(config: Dict[str, Any], creds: Dict[str, Any]):
    """Build an S3-compatible low-level client from resolved creds."""
    return _build_handle(config, creds, "client")


def create_ibm_cos_resource(config: Dict[str, Any]):
    """Return an S3 resource handle (high-level API).

    HMAC mode returns a standard boto3.resource (pure S3, no IBM discovery
    calls). IAM mode returns an ibm_boto3.resource (OAuth token handling).

    Auth mode is determined by config["auth_mode"]:
      - "iam" (default): IBM_COS_API_KEY + IBM_COS_SERVICE_INSTANCE_ID
      - "hmac": IBM_COS_HMAC_ACCESS_KEY_ID + IBM_COS_HMAC_SECRET_ACCESS_KEY

    Resolution order for each credential: config dict → environment variable.
    """
    creds = _resolve_credentials(config)
    return _build_resource(config, creds)


def create_ibm_cos_client(config: Dict[str, Any]):
    """Return an S3 low-level client.

    HMAC mode returns a standard boto3.client (pure S3, no IBM discovery
    calls). IAM mode returns an ibm_boto3.client (OAuth token handling).

    Used by API endpoints that need raw client operations (e.g.
    get_object_acl). For bucket/object listing and download, prefer
    create_ibm_cos_resource().
    """
    creds = _resolve_credentials(config)
    return _build_client(config, creds)
+ """ + + CONNECTOR_NAME = "IBM Cloud Object Storage" + CONNECTOR_DESCRIPTION = "Add knowledge from IBM Cloud Object Storage" + CONNECTOR_ICON = "ibm-cos" + + # BaseConnector uses these to check env-var availability for IAM mode. + # HMAC-only setups will show as "unavailable" in the UI but can still be + # used when credentials are supplied in the config dict directly. + CLIENT_ID_ENV_VAR = "IBM_COS_API_KEY" + CLIENT_SECRET_ENV_VAR = "IBM_COS_SERVICE_INSTANCE_ID" + + def get_client_id(self) -> str: + """Return IAM API key, or HMAC access key ID as fallback.""" + val = ( + self.config.get("api_key") + or self.config.get("hmac_access_key") + or os.getenv("IBM_COS_API_KEY") + or os.getenv("IBM_COS_HMAC_ACCESS_KEY_ID") + ) + if val: + return val + raise ValueError( + "IBM COS credentials not set. Provide IBM_COS_API_KEY (IAM) " + "or IBM_COS_HMAC_ACCESS_KEY_ID (HMAC)." + ) + + def get_client_secret(self) -> str: + """Return IAM service instance ID, or HMAC secret key as fallback.""" + val = ( + self.config.get("service_instance_id") + or self.config.get("hmac_secret_key") + or os.getenv("IBM_COS_SERVICE_INSTANCE_ID") + or os.getenv("IBM_COS_HMAC_SECRET_ACCESS_KEY") + ) + if val: + return val + raise ValueError( + "IBM COS credentials not set. Provide IBM_COS_SERVICE_INSTANCE_ID (IAM) " + "or IBM_COS_HMAC_SECRET_ACCESS_KEY (HMAC)." 
+ ) + + def __init__(self, config: Dict[str, Any]): + if config is None: + config = {} + super().__init__(config) + + self.bucket_names: List[str] = config.get("bucket_names") or [] + self.prefix: str = config.get("prefix", "") + self.connection_id: str = config.get("connection_id", "default") + + # Resolved service instance ID used as ACL owner fallback + self._service_instance_id: str = ( + config.get("service_instance_id") + or os.getenv("IBM_COS_SERVICE_INSTANCE_ID", "") + ) + + self._handle = None # Lazy-initialised on first use + # IAM mode uses ibm_boto3.client to avoid internal service-instance + # discovery calls that cause XML-parse errors against the real IBM COS API. + # HMAC mode uses ibm_boto3.resource (confirmed working with MinIO and S3). + self._is_hmac: bool = (config.get("auth_mode", "iam") == "hmac") + + def _get_handle(self): + """Return (and cache) the appropriate boto3 handle for the configured auth mode. + + - HMAC → ibm_boto3.resource (S3-compatible, works with MinIO) + - IAM → ibm_boto3.client (avoids ibm_botocore service-discovery calls + that break against the real IBM COS API) + """ + if self._handle is None: + if self._is_hmac: + self._handle = create_ibm_cos_resource(self.config) + else: + self._handle = create_ibm_cos_client(self.config) + return self._handle + + # ------------------------------------------------------------------ + # BaseConnector abstract method implementations + # ------------------------------------------------------------------ + + async def authenticate(self) -> bool: + """Validate credentials by listing buckets on the COS service.""" + try: + handle = self._get_handle() + if self._is_hmac: + list(handle.buckets.all()) # resource API + else: + handle.list_buckets() # client API + self._authenticated = True + logger.debug(f"IBM COS authenticated for connection {self.connection_id}") + return True + except Exception as exc: + logger.warning(f"IBM COS authentication failed: {exc}") + self._authenticated = False + 
return False + + def _resolve_bucket_names(self) -> List[str]: + """Return configured bucket names, or auto-discover all accessible buckets.""" + if self.bucket_names: + return self.bucket_names + try: + handle = self._get_handle() + if self._is_hmac: + buckets = [b.name for b in handle.buckets.all()] + else: + resp = handle.list_buckets() + buckets = [b["Name"] for b in resp.get("Buckets", [])] + logger.debug("IBM COS auto-discovered %d bucket(s)", len(buckets)) + return buckets + except Exception as exc: + logger.warning(f"IBM COS could not auto-discover buckets: {exc}") + return [] + + async def list_files( + self, + page_token: Optional[str] = None, + max_files: Optional[int] = None, + **kwargs, + ) -> Dict[str, Any]: + """List objects across all configured (or auto-discovered) buckets. + + Uses the ibm_boto3 resource API: Bucket.objects.all() handles pagination + internally so all objects are returned without manual continuation tokens. + + If no bucket_names are configured, all accessible buckets are used. 
+ + Returns: + dict with keys: + "files": list of file dicts (id, name, bucket, size, modified_time) + "next_page_token": always None (SDK handles pagination internally) + """ + handle = self._get_handle() + files: List[Dict[str, Any]] = [] + bucket_names = self._resolve_bucket_names() + + for bucket_name in bucket_names: + try: + if self._is_hmac: + # resource API: Bucket.objects.all() handles pagination internally + bucket = handle.Bucket(bucket_name) + objects = ( + bucket.objects.filter(Prefix=self.prefix) + if self.prefix + else bucket.objects.all() + ) + for obj in objects: + if obj.key.endswith("/"): + continue + files.append( + { + "id": _make_file_id(bucket_name, obj.key), + "name": basename(obj.key) or obj.key, + "bucket": bucket_name, + "key": obj.key, + "size": obj.size, + "modified_time": obj.last_modified.isoformat() + if obj.last_modified + else None, + } + ) + if max_files and len(files) >= max_files: + return {"files": files, "next_page_token": None} + else: + # client API: list_objects_v2 with manual pagination + kwargs: Dict[str, Any] = {"Bucket": bucket_name} + if self.prefix: + kwargs["Prefix"] = self.prefix + while True: + resp = handle.list_objects_v2(**kwargs) + for obj in resp.get("Contents", []): + key = obj["Key"] + if key.endswith("/"): + continue + files.append( + { + "id": _make_file_id(bucket_name, key), + "name": basename(key) or key, + "bucket": bucket_name, + "key": key, + "size": obj.get("Size", 0), + "modified_time": obj["LastModified"].isoformat() + if obj.get("LastModified") + else None, + } + ) + if max_files and len(files) >= max_files: + return {"files": files, "next_page_token": None} + if resp.get("IsTruncated"): + kwargs["ContinuationToken"] = resp["NextContinuationToken"] + else: + break + + except Exception as exc: + logger.error("Failed to list objects in IBM COS bucket: %s", exc) + continue + + return {"files": files, "next_page_token": None} + + async def get_file_content(self, file_id: str) -> ConnectorDocument: + 
"""Download an object from IBM COS and return a ConnectorDocument. + + Uses the ibm_boto3 resource API: Object.get() downloads content and + returns all metadata (ContentType, ContentLength, LastModified) in one call. + + Args: + file_id: Composite ID in the form "::". + + Returns: + ConnectorDocument with content bytes, ACL, and metadata. + """ + bucket_name, key = _split_file_id(file_id) + handle = self._get_handle() + + # Both client.get_object() and resource.Object().get() return the same + # response dict: Body stream + ContentType, ContentLength, LastModified. + if self._is_hmac: + response = handle.Object(bucket_name, key).get() # resource + else: + response = handle.get_object(Bucket=bucket_name, Key=key) # client + content: bytes = response["Body"].read() + + last_modified: datetime = response.get("LastModified") or datetime.now(timezone.utc) + size: int = response.get("ContentLength", len(content)) + + # MIME type detection: prefer filename extension over generic S3 content-type. + # IBM COS often stores "application/octet-stream" for all objects regardless + # of their real type, so we treat that as "unknown" and fall back to the + # extension-based guess which is more reliable for named files. 
+ raw_content_type = response.get("ContentType", "") + if raw_content_type and raw_content_type != "application/octet-stream": + mime_type: str = raw_content_type + else: + mime_type = mimetypes.guess_type(key)[0] or "application/octet-stream" + + filename = basename(key) or key + + acl = await self._extract_acl(bucket_name, key) + + return ConnectorDocument( + id=file_id, + filename=filename, + mimetype=mime_type, + content=content, + source_url=f"cos://{bucket_name}/{key}", + acl=acl, + modified_time=last_modified, + created_time=last_modified, # IBM COS does not expose creation time + metadata={ + "ibm_cos_bucket": bucket_name, + "ibm_cos_key": key, + "size": size, + }, + ) + + async def _extract_acl(self, bucket: str, key: str) -> DocumentACL: + """Fetch object ACL from IBM COS and map it to DocumentACL. + + Falls back to a minimal ACL (owner = service instance ID) on failure. + """ + try: + handle = self._get_handle() + # For resource (HMAC), access the underlying client via meta.client. + # For client (IAM), call directly. 
+ client = handle.meta.client if self._is_hmac else handle + acl_response = client.get_object_acl(Bucket=bucket, Key=key) + + owner_id: str = ( + acl_response.get("Owner", {}).get("DisplayName") + or acl_response.get("Owner", {}).get("ID") + or self._service_instance_id + ) + + allowed_users: List[str] = [] + for grant in acl_response.get("Grants", []): + grantee = grant.get("Grantee", {}) + permission = grant.get("Permission", "") + if permission in ("FULL_CONTROL", "READ"): + user_id = ( + grantee.get("DisplayName") + or grantee.get("ID") + or grantee.get("EmailAddress") + ) + if user_id and user_id not in allowed_users: + allowed_users.append(user_id) + + return DocumentACL( + owner=owner_id, + allowed_users=allowed_users, + allowed_groups=[], + ) + except Exception as exc: + logger.warning("Could not fetch IBM COS object ACL, using fallback: %s", exc) + return DocumentACL( + owner=self._service_instance_id or None, + allowed_users=[], + allowed_groups=[], + ) + + # ------------------------------------------------------------------ + # Webhook / subscription (stub — IBM COS events require IBM Event + # Notifications service; not in scope for this connector version) + # ------------------------------------------------------------------ + + async def setup_subscription(self) -> str: + """No-op: IBM COS event notifications are out of scope for this connector.""" + return "" + + async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]: + """No-op: webhooks are not supported in this connector version.""" + return [] + + def extract_webhook_channel_id( + self, payload: Dict[str, Any], headers: Dict[str, str] + ) -> Optional[str]: + return None + + async def cleanup_subscription(self, subscription_id: str) -> bool: + """No-op: no subscription to clean up.""" + return True + + + + +if __name__ == "__main__": + connector = IBMCOSConnector({}) + print(connector.authenticate()) + print(connector.list_files()) + # print(connector.get_file_content("test_cos.py")) 
# ---------------------------------------------------------------------------
# src/connectors/ibm_cos/models.py  (new file)
# ---------------------------------------------------------------------------
"""Pydantic request/response models for IBM COS API endpoints."""

from typing import List, Optional
from pydantic import BaseModel


class IBMCOSConfigureBody(BaseModel):
    """Request body for the IBM COS configure endpoint."""

    auth_mode: str  # "iam" or "hmac"
    endpoint: str  # COS endpoint URL (required)
    # IAM fields
    api_key: Optional[str] = None
    service_instance_id: Optional[str] = None
    auth_endpoint: Optional[str] = None
    # HMAC fields
    hmac_access_key: Optional[str] = None
    hmac_secret_key: Optional[str] = None
    # Optional bucket selection
    bucket_names: Optional[List[str]] = None
    # Optional: update an existing connection
    connection_id: Optional[str] = None


# ---------------------------------------------------------------------------
# src/connectors/ibm_cos/support.py  (new file)
# ---------------------------------------------------------------------------
"""Support helpers for IBM COS API endpoints.

Contains pure (non-async) business logic for credential resolution and
config dict construction, keeping the route handlers thin.
"""

import os
from typing import Dict, Optional, Tuple

from .models import IBMCOSConfigureBody


def build_ibm_cos_config(
    body: IBMCOSConfigureBody,
    existing_config: Dict,
) -> Tuple[Dict, Optional[str]]:
    """Resolve IBM COS credentials and build the connection config dict.

    Resolution order for each credential: request body → environment variable
    → existing connection config.

    Returns:
        (conn_config, None) on success
        ({}, error_message) on validation failure
    """

    def _resolve(from_body: Optional[str], env_var: str, cfg_key: str) -> Optional[str]:
        # Precedence: explicit request value, then env var, then saved config.
        return from_body or os.getenv(env_var) or existing_config.get(cfg_key)

    conn_config: Dict = {
        "auth_mode": body.auth_mode,
        "endpoint_url": body.endpoint,
    }

    if body.auth_mode == "iam":
        api_key = _resolve(body.api_key, "IBM_COS_API_KEY", "api_key")
        svc_id = _resolve(
            body.service_instance_id,
            "IBM_COS_SERVICE_INSTANCE_ID",
            "service_instance_id",
        )
        if not (api_key and svc_id):
            return {}, "IAM mode requires api_key and service_instance_id"
        conn_config["api_key"] = api_key
        conn_config["service_instance_id"] = svc_id
        if body.auth_endpoint:
            conn_config["auth_endpoint"] = body.auth_endpoint
    else:
        # Any non-"iam" mode is treated as HMAC.
        hmac_access = _resolve(
            body.hmac_access_key, "IBM_COS_HMAC_ACCESS_KEY_ID", "hmac_access_key"
        )
        hmac_secret = _resolve(
            body.hmac_secret_key, "IBM_COS_HMAC_SECRET_ACCESS_KEY", "hmac_secret_key"
        )
        if not (hmac_access and hmac_secret):
            return {}, "HMAC mode requires hmac_access_key and hmac_secret_key"
        conn_config["hmac_access_key"] = hmac_access
        conn_config["hmac_secret_key"] = hmac_secret

    if body.bucket_names is not None:
        conn_config["bucket_names"] = body.bucket_names

    return conn_config, None
+ + Returns: + (conn_config, None) on success + ({}, error_message) on validation failure + """ + conn_config: Dict = { + "auth_mode": body.auth_mode, + "endpoint_url": body.endpoint, + } + + if body.auth_mode == "iam": + api_key = ( + body.api_key + or os.getenv("IBM_COS_API_KEY") + or existing_config.get("api_key") + ) + svc_id = ( + body.service_instance_id + or os.getenv("IBM_COS_SERVICE_INSTANCE_ID") + or existing_config.get("service_instance_id") + ) + if not api_key or not svc_id: + return {}, "IAM mode requires api_key and service_instance_id" + conn_config["api_key"] = api_key + conn_config["service_instance_id"] = svc_id + if body.auth_endpoint: + conn_config["auth_endpoint"] = body.auth_endpoint + else: + # HMAC mode + hmac_access = ( + body.hmac_access_key + or os.getenv("IBM_COS_HMAC_ACCESS_KEY_ID") + or existing_config.get("hmac_access_key") + ) + hmac_secret = ( + body.hmac_secret_key + or os.getenv("IBM_COS_HMAC_SECRET_ACCESS_KEY") + or existing_config.get("hmac_secret_key") + ) + if not hmac_access or not hmac_secret: + return {}, "HMAC mode requires hmac_access_key and hmac_secret_key" + conn_config["hmac_access_key"] = hmac_access + conn_config["hmac_secret_key"] = hmac_secret + + if body.bucket_names is not None: + conn_config["bucket_names"] = body.bucket_names + + return conn_config, None diff --git a/src/connectors/langflow_connector_service.py b/src/connectors/langflow_connector_service.py index f80df0b02..1c981ba5a 100644 --- a/src/connectors/langflow_connector_service.py +++ b/src/connectors/langflow_connector_service.py @@ -331,39 +331,40 @@ async def sync_specific_files( original_folder_ids = getattr(cfg, "folder_ids", None) expanded_file_ids = file_ids # Default to original IDs - - try: - # Set the file_ids we want to sync in the connector's config - if cfg is not None: + + # Only attempt folder expansion for connectors that use cfg-based filtering + # (Google Drive, OneDrive, SharePoint). Connectors without a cfg attribute + # (e.g. 
IBM COS) receive pre-filtered file IDs and must NOT call list_files() + # here — doing so would re-list all files from all buckets, overwriting the + # carefully selected IDs passed in. + if cfg is not None: + try: cfg.file_ids = file_ids # type: ignore cfg.folder_ids = None # type: ignore - # Get the expanded list of file IDs (folders will be expanded to their contents) - # This uses the connector's list_files() which calls _iter_selected_items() - result = await connector.list_files() - expanded_file_ids = [f["id"] for f in result.get("files", [])] + # Expand file IDs — folders become their individual file contents + result = await connector.list_files() + expanded_file_ids = [f["id"] for f in result.get("files", [])] + + if not expanded_file_ids: + logger.warning( + f"No files found after expanding file_ids. " + f"Original IDs: {file_ids}. This may indicate all IDs were folders " + f"with no contents, or files that were filtered out." + ) + # If we have file_infos with download URLs, use original file_ids + # (OneDrive sharing IDs can't be expanded but can be downloaded directly) + if file_infos: + logger.info("Using original file IDs with cached download URLs") + expanded_file_ids = file_ids + else: + raise ValueError("No files to sync after expanding folders") - if not expanded_file_ids: - logger.warning( - f"No files found after expanding file_ids. " - f"Original IDs: {file_ids}. This may indicate all IDs were folders " - f"with no contents, or files that were filtered out." 
- ) - # If we have file_infos with download URLs, use original file_ids - # (OneDrive sharing IDs can't be expanded but can be downloaded directly) - if file_infos: - logger.info("Using original file IDs with cached download URLs") - expanded_file_ids = file_ids - else: - raise ValueError("No files to sync after expanding folders") - - except Exception as e: - logger.error(f"Failed to expand file_ids via list_files(): {e}") - # Fallback to original file_ids if expansion fails - expanded_file_ids = file_ids - finally: - # Restore original config values - if cfg is not None: + except Exception as e: + logger.error(f"Failed to expand file_ids via list_files(): {e}") + # Fallback to original file_ids if expansion fails + expanded_file_ids = file_ids + finally: cfg.file_ids = original_file_ids # type: ignore cfg.folder_ids = original_folder_ids # type: ignore diff --git a/src/main.py b/src/main.py index d264d1959..c168f6c6a 100644 --- a/src/main.py +++ b/src/main.py @@ -44,6 +44,18 @@ ) from api.connector_router import ConnectorRouter +from connectors.ibm_cos.api import ( + ibm_cos_defaults, + ibm_cos_configure, + ibm_cos_list_buckets, + ibm_cos_bucket_status, +) +from connectors.aws_s3.api import ( + s3_defaults, + s3_configure, + s3_list_buckets, + s3_bucket_status, +) from services.api_key_service import APIKeyService from api import keys as api_keys from api.v1 import ( @@ -1492,45 +1504,23 @@ async def create_app(): ) # Connector endpoints - app.add_api_route( - "/connectors", connectors.list_connectors, methods=["GET"], tags=["internal"] - ) - app.add_api_route( - "/connectors/{connector_type}/sync", - connectors.connector_sync, - methods=["POST"], - tags=["internal"], - ) - app.add_api_route( - "/connectors/sync-all", - connectors.sync_all_connectors, - methods=["POST"], - tags=["internal"], - ) - app.add_api_route( - "/connectors/{connector_type}/status", - connectors.connector_status, - methods=["GET"], - tags=["internal"], - ) - app.add_api_route( - 
"/connectors/{connector_type}/token", - connectors.connector_token, - methods=["GET"], - tags=["internal"], - ) - app.add_api_route( - "/connectors/{connector_type}/disconnect", - connectors.connector_disconnect, - methods=["DELETE"], - tags=["internal"], - ) - app.add_api_route( - "/connectors/{connector_type}/webhook", - connectors.connector_webhook, - methods=["POST", "GET"], - tags=["internal"], - ) + app.add_api_route("/connectors", connectors.list_connectors, methods=["GET"], tags=["internal"]) + # IBM COS-specific routes (registered before generic /{connector_type}/... to avoid shadowing) + app.add_api_route("/connectors/ibm_cos/defaults", ibm_cos_defaults, methods=["GET"], tags=["internal"]) + app.add_api_route("/connectors/ibm_cos/configure", ibm_cos_configure, methods=["POST"], tags=["internal"]) + app.add_api_route("/connectors/ibm_cos/{connection_id}/buckets", ibm_cos_list_buckets, methods=["GET"], tags=["internal"]) + app.add_api_route("/connectors/ibm_cos/{connection_id}/bucket-status", ibm_cos_bucket_status, methods=["GET"], tags=["internal"]) + # AWS S3-specific routes (registered before generic /{connector_type}/... 
to avoid shadowing) + app.add_api_route("/connectors/aws_s3/defaults", s3_defaults, methods=["GET"], tags=["internal"]) + app.add_api_route("/connectors/aws_s3/configure", s3_configure, methods=["POST"], tags=["internal"]) + app.add_api_route("/connectors/aws_s3/{connection_id}/buckets", s3_list_buckets, methods=["GET"], tags=["internal"]) + app.add_api_route("/connectors/aws_s3/{connection_id}/bucket-status", s3_bucket_status, methods=["GET"], tags=["internal"]) + app.add_api_route("/connectors/{connector_type}/sync", connectors.connector_sync, methods=["POST"], tags=["internal"]) + app.add_api_route("/connectors/sync-all", connectors.sync_all_connectors, methods=["POST"], tags=["internal"]) + app.add_api_route("/connectors/{connector_type}/status", connectors.connector_status, methods=["GET"], tags=["internal"]) + app.add_api_route("/connectors/{connector_type}/token", connectors.connector_token, methods=["GET"], tags=["internal"]) + app.add_api_route("/connectors/{connector_type}/disconnect", connectors.connector_disconnect, methods=["DELETE"], tags=["internal"]) + app.add_api_route("/connectors/{connector_type}/webhook", connectors.connector_webhook, methods=["POST", "GET"], tags=["internal"]) # Document endpoints app.add_api_route( diff --git a/src/services/auth_service.py b/src/services/auth_service.py index 1adc26522..b9bdc9044 100644 --- a/src/services/auth_service.py +++ b/src/services/auth_service.py @@ -1,4 +1,3 @@ -import os import uuid import json import httpx @@ -20,6 +19,9 @@ from connectors.onedrive import OneDriveConnector from connectors.sharepoint import SharePointConnector +# Connectors that authenticate directly (no OAuth redirect required) +_DIRECT_AUTH_CONNECTORS = {"ibm_cos"} + class AuthService: def __init__(self, session_manager: SessionManager, connector_service=None, flows_service=None, langflow_mcp_service: LangflowMCPService | None = None): @@ -57,6 +59,7 @@ async def init_oauth( "google_drive", "onedrive", "sharepoint", + "ibm_cos", 
]: raise ValueError(f"Unsupported connector type: {connector_type}") elif purpose not in ["app_auth", "data_source"]: @@ -92,6 +95,10 @@ async def init_oauth( ) ) + # Direct-auth connectors (HMAC/API-key based, no OAuth redirect) + if connector_type in _DIRECT_AUTH_CONNECTORS: + return await self._init_direct_connection(connector_type, connection_id) + # Get OAuth configuration from connector and OAuth classes import os @@ -148,6 +155,61 @@ def _assert_env_key(name, val): return {"connection_id": connection_id, "oauth_config": oauth_config} + async def _init_direct_connection( + self, connector_type: str, connection_id: str + ) -> dict: + """Authenticate a non-OAuth connector immediately using env var credentials. + + Creates the connection record (already done by the caller) and verifies + that the credentials work by calling authenticate() on the connector. + Returns a response without oauth_config so the frontend knows no redirect + is needed. + """ + try: + connection_config = ( + await self.connector_service.connection_manager.get_connection( + connection_id + ) + ) + if not connection_config: + raise ValueError("Connection not found") + + connector = self.connector_service.connection_manager._create_connector( + connection_config + ) + authenticated = await connector.authenticate() + if not authenticated: + # Remove the connection so the user can retry after fixing credentials + await self.connector_service.connection_manager.delete_connection( + connection_id + ) + raise ValueError( + f"Could not authenticate with {connector_type}. " + "Check that your credentials and endpoint are correct." 
+ ) + + # Cache the authenticated connector + self.connector_service.connection_manager.active_connectors[ + connection_id + ] = connector + + except ValueError: + raise + except Exception as exc: + await self.connector_service.connection_manager.delete_connection( + connection_id + ) + raise ValueError( + f"Failed to connect {connector_type}: {exc}" + ) from exc + + return { + "connection_id": connection_id, + "status": "connected", + "connector_type": connector_type, + # No oauth_config — frontend must not attempt an OAuth redirect + } + async def handle_oauth_callback( self, connection_id: str, @@ -415,7 +477,7 @@ async def _handle_data_source_auth( else: logger.warning("_handle_data_source_auth: _detect_base_url returned None") else: - logger.warning(f"_handle_data_source_auth: Connector not available or doesn't have _detect_base_url") + logger.warning("_handle_data_source_auth: Connector not available or doesn't have _detect_base_url") # Clear the cached connector so next get_connector() creates a fresh instance # with the updated config (including base_url) @@ -431,6 +493,8 @@ async def _handle_data_source_auth( async def get_user_info(self, request) -> Optional[dict]: """Get current user information from request""" + from config.settings import IBM_AUTH_ENABLED + # In no-auth mode, return a consistent response if is_no_auth_mode(): return {"authenticated": False, "user": None, "no_auth_mode": True} @@ -440,6 +504,7 @@ async def get_user_info(self, request) -> Optional[dict]: if user: user_data = { "authenticated": True, + "ibm_auth_mode": IBM_AUTH_ENABLED, "user": { "user_id": user.user_id, "email": user.email, @@ -451,7 +516,7 @@ async def get_user_info(self, request) -> Optional[dict]: else None, }, } - + return user_data else: - return {"authenticated": False, "user": None} + return {"authenticated": False, "ibm_auth_mode": IBM_AUTH_ENABLED, "user": None} diff --git a/src/services/chat_service.py b/src/services/chat_service.py index d0451918a..7c5fb5302 
100644 --- a/src/services/chat_service.py +++ b/src/services/chat_service.py @@ -607,20 +607,21 @@ async def delete_session(self, user_id: str, session_id: str): from agent import delete_user_conversation local_deleted = await delete_user_conversation(user_id, session_id) - # Delete from Langflow using the monitor API - langflow_deleted = await self._delete_langflow_session(session_id) - - success = local_deleted or langflow_deleted - error_msg = None + if not local_deleted: + return { + "success": False, + "not_found": True, + "error": "Conversation not found", + } - if not success: - error_msg = "Session not found in local storage or Langflow" + # Delete from Langflow using the monitor API (best-effort) + langflow_deleted = await self._delete_langflow_session(session_id) return { - "success": success, + "success": True, "local_deleted": local_deleted, "langflow_deleted": langflow_deleted, - "error": error_msg + "error": None, } except Exception as e: diff --git a/src/tui/config_fields.py b/src/tui/config_fields.py index c13e2e126..184365a32 100644 --- a/src/tui/config_fields.py +++ b/src/tui/config_fields.py @@ -181,8 +181,49 @@ class ConfigSection: "aws_secret_access_key", "AWS_SECRET_ACCESS_KEY", "Secret Access Key", placeholder="", secret=True, ), + ConfigField( + "aws_s3_endpoint", "AWS_S3_ENDPOINT", "S3 Endpoint URL (optional)", + placeholder="", + helper_text="Leave empty for AWS S3. For MinIO, R2, or other S3-compatible services, enter the endpoint URL.", + ), + ConfigField( + "aws_region", "AWS_REGION", "AWS Region (optional)", + placeholder="us-east-1", + default="us-east-1", + helper_text="AWS region (e.g. us-east-1, eu-west-1). 
Default: us-east-1.", + ), ], advanced=True, gate_prompt="Configure AWS credentials?"), + # ── IBM Cloud Object Storage ───────────────────────────────── + ConfigSection("IBM Cloud Object Storage", [ + ConfigField( + "ibm_cos_api_key", "IBM_COS_API_KEY", "API Key", + placeholder="", + helper_text="Create API key at https://cloud.ibm.com/iam/apikeys", + secret=True, + ), + ConfigField( + "ibm_cos_service_instance_id", "IBM_COS_SERVICE_INSTANCE_ID", + "Service Instance ID (CRN)", + placeholder="crn:v1:bluemix:...", + ), + ConfigField( + "ibm_cos_endpoint", "IBM_COS_ENDPOINT", "Service Endpoint", + placeholder="https://s3.us-south.cloud-object-storage.appdomain.cloud", + helper_text="Endpoints: https://cloud.ibm.com/docs/cloud-object-storage?topic=cloud-object-storage-endpoints", + ), + ConfigField( + "ibm_cos_hmac_access_key_id", "IBM_COS_HMAC_ACCESS_KEY_ID", + "HMAC Access Key ID (optional)", + placeholder="", + ), + ConfigField( + "ibm_cos_hmac_secret_access_key", "IBM_COS_HMAC_SECRET_ACCESS_KEY", + "HMAC Secret Access Key (optional)", + placeholder="", secret=True, + ), + ], advanced=True, gate_prompt="Configure IBM Cloud Object Storage?"), + # ── Langfuse ──────────────────────────────────────────────── ConfigSection("Langfuse", [ ConfigField( diff --git a/src/tui/managers/env_manager.py b/src/tui/managers/env_manager.py index a349ae04f..1a7780b65 100644 --- a/src/tui/managers/env_manager.py +++ b/src/tui/managers/env_manager.py @@ -58,8 +58,18 @@ class EnvConfig: webhook_base_url: str = "" aws_access_key_id: str = "" aws_secret_access_key: str = "" + aws_s3_endpoint: str = "" + aws_region: str = "" langflow_public_url: str = "" + # IBM Cloud Object Storage settings + ibm_cos_api_key: str = "" + ibm_cos_service_instance_id: str = "" + ibm_cos_endpoint: str = "" + ibm_cos_hmac_access_key_id: str = "" + ibm_cos_hmac_secret_access_key: str = "" + ibm_cos_auth_endpoint: str = "" # Optional: override IAM token endpoint + # Langfuse settings (optional) 
langfuse_secret_key: str = "" langfuse_public_key: str = "" @@ -186,7 +196,15 @@ def _env_attr_map(self) -> Dict[str, str]: "WEBHOOK_BASE_URL": "webhook_base_url", "AWS_ACCESS_KEY_ID": "aws_access_key_id", "AWS_SECRET_ACCESS_KEY": "aws_secret_access_key", # pragma: allowlist secret + "AWS_S3_ENDPOINT": "aws_s3_endpoint", + "AWS_REGION": "aws_region", "LANGFLOW_PUBLIC_URL": "langflow_public_url", + "IBM_COS_API_KEY": "ibm_cos_api_key", # pragma: allowlist secret + "IBM_COS_SERVICE_INSTANCE_ID": "ibm_cos_service_instance_id", + "IBM_COS_ENDPOINT": "ibm_cos_endpoint", + "IBM_COS_HMAC_ACCESS_KEY_ID": "ibm_cos_hmac_access_key_id", + "IBM_COS_HMAC_SECRET_ACCESS_KEY": "ibm_cos_hmac_secret_access_key", # pragma: allowlist secret + "IBM_COS_AUTH_ENDPOINT": "ibm_cos_auth_endpoint", "OPENRAG_DOCUMENTS_PATHS": "openrag_documents_paths", "OPENRAG_DOCUMENTS_PATH": "openrag_documents_path", "OPENRAG_KEYS_PATH": "openrag_keys_path", @@ -558,7 +576,15 @@ def save_env_file(self) -> bool: ("WEBHOOK_BASE_URL", self.config.webhook_base_url), ("AWS_ACCESS_KEY_ID", self.config.aws_access_key_id), ("AWS_SECRET_ACCESS_KEY", self.config.aws_secret_access_key), + ("AWS_S3_ENDPOINT", self.config.aws_s3_endpoint), + ("AWS_REGION", self.config.aws_region), ("LANGFLOW_PUBLIC_URL", self.config.langflow_public_url), + ("IBM_COS_API_KEY", self.config.ibm_cos_api_key), + ("IBM_COS_SERVICE_INSTANCE_ID", self.config.ibm_cos_service_instance_id), + ("IBM_COS_ENDPOINT", self.config.ibm_cos_endpoint), + ("IBM_COS_HMAC_ACCESS_KEY_ID", self.config.ibm_cos_hmac_access_key_id), + ("IBM_COS_HMAC_SECRET_ACCESS_KEY", self.config.ibm_cos_hmac_secret_access_key), + ("IBM_COS_AUTH_ENDPOINT", self.config.ibm_cos_auth_endpoint), ] optional_written = False diff --git a/src/utils/file_utils.py b/src/utils/file_utils.py index 764f90694..44d0bb98e 100644 --- a/src/utils/file_utils.py +++ b/src/utils/file_utils.py @@ -61,7 +61,7 @@ def safe_unlink(path: str) -> None: def get_file_extension(mimetype: str) -> str: - """Get 
file extension based on MIME type""" + """Get file extension based on MIME type. Returns None if the type is unknown.""" mime_to_ext = { "application/pdf": ".pdf", "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx", @@ -69,19 +69,33 @@ def get_file_extension(mimetype: str) -> str: "application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx", "application/vnd.ms-powerpoint": ".ppt", "text/plain": ".txt", + "text/markdown": ".md", + "text/x-markdown": ".md", "text/html": ".html", + "text/csv": ".csv", + "application/json": ".json", + "application/xml": ".xml", + "text/xml": ".xml", "application/rtf": ".rtf", "application/vnd.google-apps.document": ".pdf", # Exported as PDF "application/vnd.google-apps.presentation": ".pdf", "application/vnd.google-apps.spreadsheet": ".pdf", } - return mime_to_ext.get(mimetype, ".bin") + return mime_to_ext.get(mimetype) def clean_connector_filename(filename: str, mimetype: str) -> str: - """Clean filename and ensure correct extension""" - suffix = get_file_extension(mimetype) + """Clean filename and ensure correct extension. + + If the MIME type maps to a known extension, it is enforced. + If the MIME type is unknown, the original filename (and its extension) is kept as-is + rather than appending a meaningless .bin suffix. + """ clean_name = filename.replace(" ", "_").replace("/", "_") + suffix = get_file_extension(mimetype) + if suffix is None: + # Unknown type — keep whatever extension the file already has + return clean_name if not clean_name.lower().endswith(suffix.lower()): return clean_name + suffix return clean_name \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 1da7c20f4..ffe80467e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,7 +26,14 @@ async def onboard_system(): This ensures the OpenRAG config is marked as edited and properly initialized so that tests can use the /settings endpoint. 
+ + Skips in-process backend setup when SDK_TESTS_ONLY=true (SDK tests talk to + an already-running external stack and must not wipe its state). """ + if os.environ.get("SDK_TESTS_ONLY") == "true": + yield + return + from pathlib import Path import shutil @@ -34,7 +41,7 @@ async def onboard_system(): config_file = Path("config/config.yaml") if config_file.exists(): config_file.unlink() - + # Clean up OpenSearch data directory to ensure fresh state for tests opensearch_data_path = Path(os.getenv("OPENSEARCH_DATA_PATH", "./opensearch-data")) if opensearch_data_path.exists(): diff --git a/tests/integration/core/__init__.py b/tests/integration/core/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/test_api_endpoints.py b/tests/integration/core/test_api_endpoints.py similarity index 100% rename from tests/integration/test_api_endpoints.py rename to tests/integration/core/test_api_endpoints.py diff --git a/tests/integration/test_startup_ingest.py b/tests/integration/core/test_startup_ingest.py similarity index 100% rename from tests/integration/test_startup_ingest.py rename to tests/integration/core/test_startup_ingest.py diff --git a/tests/integration/sdk/README.md b/tests/integration/sdk/README.md new file mode 100644 index 000000000..2c40f2703 --- /dev/null +++ b/tests/integration/sdk/README.md @@ -0,0 +1,126 @@ +# OpenRAG Python SDK — QA Test Checklist + +Live integration tests against a running OpenRAG instance (`http://localhost:3000` by default). 
+ +**Run all SDK tests:** +```bash +make test-sdk +``` + +--- + +## Authentication (`test_auth.py`) + +| # | Test | Expected | +|---|------|----------| +| 1 | Construct client with no API key | Raises `AuthenticationError` immediately | +| 2 | Send request with invalid API key | Raises `AuthenticationError` with status 401 or 403 | +| 3 | Send request with well-formed but non-existent key | Raises `AuthenticationError` | + +--- + +## Chat (`test_chat.py`) + +| # | Test | Expected | +|---|------|----------| +| 4 | Non-streaming chat | Returns non-empty response string | +| 5 | Streaming chat (`create(stream=True)`) | Yields content events with text deltas | +| 6 | Streaming via context manager (`stream()`) | Accumulated `stream.text` is non-empty | +| 7 | `text_stream` async iterator | Yields plain text chunks | +| 8 | `final_text()` | Returns full accumulated response | +| 9 | Conversation continuation (pass `chat_id`) | Second reply uses same conversation | +| 10 | List conversations | Returns list of conversations | +| 11 | Get conversation by ID | Returns conversation with message history | +| 12 | Delete existing conversation | Returns `True` | +| 13 | Chat with ingested document (RAG) | Response sources include the ingested file | +| 14 | Stream continuation with `chat_id` | Follow-up stream uses existing conversation | +| 15 | Every response includes `chat_id` | `chat_id` is a non-empty string | +| 16 | `chat_id` available after stream consumed | `stream.chat_id` is populated | +| 17 | `sources` field on response | Always a list (may be empty) | + +--- + +## Documents (`test_documents.py`) + +| # | Test | Expected | +|---|------|----------| +| 18 | Ingest file (async, `wait=False`) | Returns `task_id`; polling reaches terminal state | +| 19 | Ingest file (blocking, `wait=True`) | Returns terminal status with `successful_files >= 0` | +| 20 | Delete ingested document | `success=True`, `deleted_chunks > 0` | +| 21 | Delete never-ingested filename | 
`success=False`, `deleted_chunks=0`, error message present | +| 22 | Ingest via file object (`io.BytesIO`) | Accepted and processed without error | +| 23 | Re-ingest same filename twice | Does not raise; second call returns a status | +| 24 | Ingest `.md` file | Accepted and processed without error | +| 25 | Poll task status manually | `get_task_status()` returns a status; `wait_for_task()` returns `completed` or `failed` | + +--- + +## Search (`test_search.py`) + +| # | Test | Expected | +|---|------|----------| +| 26 | Basic search query | Returns a results list | +| 27 | Search with `limit=1` | Returns at most 1 result | +| 28 | Search with `score_threshold=0.99` | Returns a list (may be empty) without error | +| 29 | Nonsense/obscure query | Returns empty list, no error | +| 30 | Unicode and emoji in query | Returns list, no error | +| 31 | Result fields | Each result has `text` (non-empty string) | + +--- + +## Settings (`test_settings.py`) + +| # | Test | Expected | +|---|------|----------| +| 32 | Get settings | Response includes `agent` and `knowledge` sections | +| 33 | Update `chunk_size` setting | Update succeeds; value readable back unchanged | + +--- + +## Models (`test_models.py`) + +| # | Test | Expected | +|---|------|----------| +| 34 | List models for a provider (`openai`) | Returns `language_models` and `embedding_models` as lists | + +--- + +## Knowledge Filters (`test_filters.py`) + +| # | Test | Expected | +|---|------|----------| +| 35 | Create filter | `success=True`, `id` returned | +| 36 | Search filters by name | Returns list containing the created filter | +| 37 | Get filter by ID | Returns filter with correct `id` and `name` | +| 38 | Update filter description | Update returns `True`; description readable back | +| 39 | Delete filter | Returns `True` | +| 40 | Get deleted filter | Returns `None` | +| 41 | Pass `filter_id` to `chat.create()` | No error; response returned | +| 42 | Pass `filter_id` to `search.query()` | No error; results 
returned | + +--- + +## Error Handling (`test_errors.py`) + +| # | Test | Expected | +|---|------|----------| +| 43 | Connect to dead port | Raises a network exception within timeout | +| 44 | Get conversation with random UUID | Raises `NotFoundError` | +| 45 | Delete conversation with random UUID | Returns `False` | +| 46 | Update settings with invalid value (`chunk_size=-999999`) | Raises `OpenRAGError` subclass | +| 47 | Call `ingest()` with no arguments | Raises `ValueError` | +| 48 | Call `ingest()` with `BytesIO` but no filename | Raises `ValueError` | + +--- + +## End-to-End (`test_e2e.py`) + +| # | Test | Expected | +|---|------|----------| +| 49 | Full RAG pipeline: ingest → search → chat | Chat sources include the ingested document | +| 50 | Multi-turn conversation with RAG | Second turn uses same `chat_id`; context carried over | +| 51 | Knowledge filter scopes search and chat | Search and chat succeed with `filter_id`; filter cleaned up | + +--- + +**Total: 51 tests across 8 domains.** diff --git a/tests/integration/sdk/__init__.py b/tests/integration/sdk/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/sdk/conftest.py b/tests/integration/sdk/conftest.py new file mode 100644 index 000000000..099ed1fc7 --- /dev/null +++ b/tests/integration/sdk/conftest.py @@ -0,0 +1,102 @@ +""" +Shared fixtures and setup for OpenRAG SDK integration tests. + +All tests in this directory require a running OpenRAG instance. +Set OPENRAG_URL (default: http://localhost:3000) before running. +""" + +import os +import uuid +from pathlib import Path + +import httpx +import pytest +import pytest_asyncio + +_cached_api_key: str | None = None +_base_url = os.environ.get("OPENRAG_URL", "http://localhost:3000") +_onboarding_done = False + + +@pytest_asyncio.fixture(scope="session", autouse=True) +async def ensure_onboarding(): + """Ensure the OpenRAG instance is onboarded before running tests. 
+ + Uses httpx.AsyncClient so the async event loop is never blocked, + even on a slow or unreachable server. + """ + global _onboarding_done + if _onboarding_done: + return + + onboarding_payload = { + "llm_provider": "openai", + "embedding_provider": "openai", + "embedding_model": "text-embedding-3-small", + "llm_model": "gpt-4o-mini", + } + + try: + async with httpx.AsyncClient(timeout=30.0) as ac: + response = await ac.post( + f"{_base_url}/api/onboarding", + json=onboarding_payload, + ) + if response.status_code in (200, 204): + print("[SDK Tests] Onboarding completed successfully") + else: + print(f"[SDK Tests] Onboarding returned {response.status_code}: {response.text[:200]}") + except Exception as e: + print(f"[SDK Tests] Onboarding request failed: {e}") + + _onboarding_done = True + + +async def _fetch_api_key() -> str: + """Fetch or create a test API key from the running instance (async, cached).""" + global _cached_api_key + if _cached_api_key is not None: + return _cached_api_key + + async with httpx.AsyncClient(timeout=30.0) as ac: + response = await ac.post( + f"{_base_url}/api/keys", + json={"name": "SDK Integration Test"}, + ) + + if response.status_code == 401: + pytest.skip("Cannot create API key — authentication required") + + assert response.status_code == 200, f"Failed to create API key: {response.text}" + _cached_api_key = response.json()["api_key"] + return _cached_api_key + + +@pytest_asyncio.fixture +async def client(): + """OpenRAG client authenticated with a valid test API key.""" + from openrag_sdk import OpenRAGClient + + api_key = await _fetch_api_key() + c = OpenRAGClient(api_key=api_key, base_url=_base_url) + yield c + await c.close() + + +@pytest.fixture +def base_url() -> str: + """The base URL of the running OpenRAG instance.""" + return _base_url + + +@pytest.fixture +def test_file(tmp_path) -> Path: + """A uniquely-named markdown file ready for ingestion.""" + file_path = tmp_path / f"sdk_test_doc_{uuid.uuid4().hex[:8]}.md" + 
file_path.write_text( + f"# SDK Integration Test Document\n\n" + f"ID: {uuid.uuid4()}\n\n" + "This document tests the OpenRAG Python SDK.\n\n" + "It contains unique content about purple elephants dancing.\n" + ) + return file_path diff --git a/tests/integration/sdk/test_auth.py b/tests/integration/sdk/test_auth.py new file mode 100644 index 000000000..af233dbca --- /dev/null +++ b/tests/integration/sdk/test_auth.py @@ -0,0 +1,59 @@ +"""Tests for authentication and API key behaviour.""" + +import os + +import pytest + +from .conftest import _base_url + +pytestmark = pytest.mark.skipif( + os.environ.get("SKIP_SDK_INTEGRATION_TESTS") == "true", + reason="SDK integration tests skipped", +) + + +class TestAuth: + """Test authentication and API key behaviour.""" + + def test_missing_api_key_raises_at_construction(self): + """Client must raise AuthenticationError immediately if no api_key is given.""" + from openrag_sdk import OpenRAGClient + from openrag_sdk.exceptions import AuthenticationError + + env_backup = os.environ.pop("OPENRAG_API_KEY", None) + try: + with pytest.raises(AuthenticationError): + OpenRAGClient() + finally: + if env_backup is not None: + os.environ["OPENRAG_API_KEY"] = env_backup + + @pytest.mark.asyncio + async def test_invalid_api_key_raises_auth_error(self): + """Requests with a bogus key must raise AuthenticationError (401/403).""" + from openrag_sdk import OpenRAGClient + from openrag_sdk.exceptions import AuthenticationError + + bad_client = OpenRAGClient(api_key="orag_invalid_key_for_testing", base_url=_base_url) + try: + with pytest.raises(AuthenticationError) as exc_info: + await bad_client.settings.get() + assert exc_info.value.status_code in (401, 403) + finally: + await bad_client.close() + + @pytest.mark.asyncio + async def test_revoked_api_key_raises_auth_error(self): + """A well-formed but non-existent key must be rejected.""" + from openrag_sdk import OpenRAGClient + from openrag_sdk.exceptions import AuthenticationError + + 
fake_client = OpenRAGClient( + api_key="orag_0000000000000000000000000000000000000000", + base_url=_base_url, + ) + try: + with pytest.raises(AuthenticationError): + await fake_client.chat.list() + finally: + await fake_client.close() diff --git a/tests/integration/sdk/test_chat.py b/tests/integration/sdk/test_chat.py new file mode 100644 index 000000000..501337c87 --- /dev/null +++ b/tests/integration/sdk/test_chat.py @@ -0,0 +1,181 @@ +"""Tests for the chat endpoint — non-streaming, streaming, conversations, and RAG.""" + +import os +from pathlib import Path + +import pytest + +pytestmark = pytest.mark.skipif( + os.environ.get("SKIP_SDK_INTEGRATION_TESTS") == "true", + reason="SDK integration tests skipped", +) + + +class TestChat: + """Core chat operation tests.""" + + @pytest.mark.asyncio + async def test_chat_non_streaming(self, client): + """Non-streaming chat returns a non-empty response string.""" + response = await client.chat.create(message="Say hello in exactly 3 words.") + + assert response.response is not None + assert isinstance(response.response, str) + assert len(response.response) > 0 + + @pytest.mark.asyncio + async def test_chat_streaming_create(self, client): + """create(stream=True) yields content events with text deltas.""" + collected_text = "" + + async for event in await client.chat.create( + message="Say 'test' and nothing else.", + stream=True, + ): + if event.type == "content": + collected_text += event.delta + + assert len(collected_text) > 0 + + @pytest.mark.asyncio + async def test_chat_streaming_context_manager(self, client): + """stream() context manager accumulates text in stream.text.""" + async with client.chat.stream(message="Say 'hello' and nothing else.") as stream: + async for _ in stream: + pass + assert len(stream.text) > 0 + + @pytest.mark.asyncio + async def test_chat_text_stream(self, client): + """text_stream yields plain text deltas.""" + collected = "" + + async with client.chat.stream(message="Say 'world' and nothing 
else.") as stream: + async for text in stream.text_stream: + collected += text + + assert len(collected) > 0 + + @pytest.mark.asyncio + async def test_chat_final_text(self, client): + """final_text() returns the complete accumulated response.""" + async with client.chat.stream(message="Say 'done' and nothing else.") as stream: + text = await stream.final_text() + + assert len(text) > 0 + + @pytest.mark.asyncio + async def test_chat_conversation_continuation(self, client): + """A second message with chat_id continues the same conversation.""" + response1 = await client.chat.create(message="Remember the number 42.") + assert response1.chat_id is not None + + response2 = await client.chat.create( + message="What number did I ask you to remember?", + chat_id=response1.chat_id, + ) + assert response2.response is not None + + @pytest.mark.asyncio + async def test_list_conversations(self, client): + """list() returns a ConversationListResponse with a list of conversations.""" + await client.chat.create(message="Test message for listing.") + + result = await client.chat.list() + + assert result.conversations is not None + assert isinstance(result.conversations, list) + + @pytest.mark.asyncio + async def test_get_conversation(self, client): + """get() returns the full conversation with message history.""" + response = await client.chat.create(message="Test message for get.") + assert response.chat_id is not None + + conversation = await client.chat.get(response.chat_id) + + assert conversation.chat_id == response.chat_id + assert conversation.messages is not None + assert isinstance(conversation.messages, list) + assert len(conversation.messages) >= 1 + + @pytest.mark.asyncio + async def test_delete_conversation(self, client): + """delete() returns True for a conversation that exists.""" + response = await client.chat.create(message="Test message for delete.") + assert response.chat_id is not None + + result = await client.chat.delete(response.chat_id) + + assert result is 
True + + @pytest.mark.asyncio + async def test_chat_with_sources(self, client, test_file: Path): + """Chat response must cite the ingested document as a source (RAG).""" + result = await client.documents.ingest(file_path=str(test_file)) + if result.status == "failed" or result.successful_files == 0: + pytest.skip("Document ingestion failed — cannot test RAG sources") + + response = await client.chat.create( + message="What is the color of the dancing animals mentioned in my documents?" + ) + + assert response.sources is not None + assert len(response.sources) > 0 + source_filenames = [s.filename for s in response.sources] + assert any(test_file.name in name for name in source_filenames) + + +class TestChatExtended: + """Additional chat edge-case tests.""" + + @pytest.mark.asyncio + async def test_stream_continuation_with_chat_id(self, client): + """Streaming a follow-up message in an existing conversation works.""" + r1 = await client.chat.create(message="Remember the colour blue.") + assert r1.chat_id is not None + + collected = "" + async with client.chat.stream( + message="What colour did I ask you to remember?", + chat_id=r1.chat_id, + ) as stream: + async for text in stream.text_stream: + collected += text + + assert len(collected) > 0 + await client.chat.delete(r1.chat_id) + + @pytest.mark.asyncio + async def test_chat_response_has_chat_id(self, client): + """Every non-streaming response must include a chat_id for continuation.""" + response = await client.chat.create(message="Hello.") + assert response.chat_id is not None + assert isinstance(response.chat_id, str) + assert len(response.chat_id) > 0 + await client.chat.delete(response.chat_id) + + @pytest.mark.asyncio + async def test_stream_chat_id_available_after_iteration(self, client): + """chat_id must be populated on ChatStream after the stream is consumed.""" + async with client.chat.stream(message="Say one word.") as stream: + await stream.final_text() + assert stream.chat_id is not None + + 
@pytest.mark.asyncio + async def test_chat_sources_field_is_list(self, client): + """sources on ChatResponse is always a list (may be empty).""" + response = await client.chat.create(message="What time is it?") + assert response.sources is not None + assert isinstance(response.sources, list) + if response.chat_id: + await client.chat.delete(response.chat_id) + + @pytest.mark.asyncio + async def test_list_conversations_returns_list(self, client): + """list() always returns a ConversationListResponse with a list.""" + r = await client.chat.create(message="List test message.") + result = await client.chat.list() + assert isinstance(result.conversations, list) + if r.chat_id: + await client.chat.delete(r.chat_id) diff --git a/tests/integration/sdk/test_documents.py b/tests/integration/sdk/test_documents.py new file mode 100644 index 000000000..9fcab83ab --- /dev/null +++ b/tests/integration/sdk/test_documents.py @@ -0,0 +1,121 @@ +"""Tests for document ingestion and deletion.""" + +import io +import os +import uuid +from pathlib import Path + +import pytest + +pytestmark = pytest.mark.skipif( + os.environ.get("SKIP_SDK_INTEGRATION_TESTS") == "true", + reason="SDK integration tests skipped", +) + + +class TestDocuments: + """Core document ingestion and deletion tests.""" + + @pytest.mark.asyncio + async def test_ingest_document_no_wait(self, client, test_file: Path): + """wait=False returns a task_id immediately; polling reaches a terminal state.""" + result = await client.documents.ingest(file_path=str(test_file), wait=False) + assert result.task_id is not None + + final_status = await client.documents.wait_for_task(result.task_id) + assert final_status.status is not None + assert final_status.successful_files >= 0 + + @pytest.mark.asyncio + async def test_ingest_document(self, client, test_file: Path): + """wait=True polls until completion and returns a terminal status.""" + result = await client.documents.ingest(file_path=str(test_file)) + assert result.status is not 
None + assert result.successful_files >= 0 + + @pytest.mark.asyncio + async def test_delete_document(self, client, test_file: Path): + """Deleting an ingested document succeeds when chunks were indexed.""" + ingest_result = await client.documents.ingest(file_path=str(test_file)) + + result = await client.documents.delete(test_file.name) + + if ingest_result.successful_files > 0: + assert result.success is True + assert result.deleted_chunks > 0 + else: + assert result.success is False + assert result.deleted_chunks == 0 + + @pytest.mark.asyncio + async def test_delete_missing_document_is_idempotent(self, client): + """Deleting a never-ingested filename must not raise.""" + missing_filename = f"never_ingested_{uuid.uuid4().hex}.pdf" + result = await client.documents.delete(missing_filename) + + assert result.success is False + assert result.deleted_chunks == 0 + assert result.filename == missing_filename + assert result.error is not None + + +class TestDocumentsExtended: + """Additional document ingestion scenarios.""" + + @pytest.mark.asyncio + async def test_ingest_via_file_object(self, client): + """Ingest using a file-like object (io.BytesIO) instead of a file path.""" + unique_token = uuid.uuid4().hex + content = ( + f"# File Object Test\n\n" + f"Token: {unique_token}\n\n" + f"This document was ingested via a file object.\n" + ).encode() + + filename = f"file_obj_{unique_token[:8]}.md" + result = await client.documents.ingest(file=io.BytesIO(content), filename=filename) + assert result.status is not None + assert result.successful_files >= 0 + + await client.documents.delete(filename) + + @pytest.mark.asyncio + async def test_reingest_same_filename_does_not_raise(self, client, tmp_path): + """Ingesting the same filename twice must not raise an error.""" + unique_token = uuid.uuid4().hex + file_path = tmp_path / f"reingest_{unique_token[:8]}.md" + file_path.write_text(f"# Reingest Test\n\nToken: {unique_token}\n") + + result1 = await 
client.documents.ingest(file_path=str(file_path)) + assert result1.status is not None + + result2 = await client.documents.ingest(file_path=str(file_path)) + assert result2.status is not None + + await client.documents.delete(file_path.name) + + @pytest.mark.asyncio + async def test_ingest_markdown_format(self, client, tmp_path): + """Verify .md files are accepted and processed without error.""" + file_path = tmp_path / f"format_md_{uuid.uuid4().hex[:8]}.md" + file_path.write_text("# Markdown Format\n\n## Section\n\nContent here.\n") + result = await client.documents.ingest(file_path=str(file_path)) + assert result.status is not None + await client.documents.delete(file_path.name) + + @pytest.mark.asyncio + async def test_task_status_polling(self, client, tmp_path): + """wait=False returns a task_id that can be polled and waited on manually.""" + file_path = tmp_path / f"poll_{uuid.uuid4().hex[:8]}.md" + file_path.write_text("# Polling Test\n\nContent for polling test.\n") + + task_response = await client.documents.ingest(file_path=str(file_path), wait=False) + assert task_response.task_id is not None + + status = await client.documents.get_task_status(task_response.task_id) + assert status.status is not None + + final = await client.documents.wait_for_task(task_response.task_id) + assert final.status in ("completed", "failed") + + await client.documents.delete(file_path.name) diff --git a/tests/integration/sdk/test_e2e.py b/tests/integration/sdk/test_e2e.py new file mode 100644 index 000000000..2b13ad45f --- /dev/null +++ b/tests/integration/sdk/test_e2e.py @@ -0,0 +1,102 @@ +"""End-to-end tests covering full multi-step SDK workflows.""" + +import os +import uuid + +import pytest + +pytestmark = pytest.mark.skipif( + os.environ.get("SKIP_SDK_INTEGRATION_TESTS") == "true", + reason="SDK integration tests skipped", +) + + +class TestEndToEnd: + """Full pipeline tests that exercise multiple SDK operations together.""" + + @pytest.mark.asyncio + async def 
test_full_rag_pipeline(self, client, tmp_path): + """Ingest → search → chat: source cited in chat must match the ingested doc.""" + unique_token = uuid.uuid4().hex + file_path = tmp_path / f"rag_e2e_{unique_token[:8]}.md" + file_path.write_text( + f"# E2E RAG Test\n\n" + f"Unique token: {unique_token}\n\n" + f"The flamingo named Zephyr lives on planet Xylox-7.\n" + ) + + ingest_result = await client.documents.ingest(file_path=str(file_path)) + if ingest_result.successful_files == 0: + pytest.skip("Ingestion produced no indexed chunks — skipping E2E RAG test") + + search_results = await client.search.query("flamingo Zephyr planet Xylox") + assert search_results.results is not None + + chat_response = await client.chat.create( + message="What is the name of the flamingo and where does it live?" + ) + assert chat_response.response is not None + assert len(chat_response.response) > 0 + assert chat_response.sources is not None + source_names = [s.filename for s in chat_response.sources] + assert any(file_path.name in name for name in source_names), ( + f"Expected {file_path.name} in sources {source_names}" + ) + + await client.documents.delete(file_path.name) + + @pytest.mark.asyncio + async def test_multiturn_rag_conversation(self, client, tmp_path): + """Multi-turn conversation: context carries across turns.""" + unique_token = uuid.uuid4().hex + file_path = tmp_path / f"multiturn_{unique_token[:8]}.md" + file_path.write_text( + f"# Multiturn Test\n\n" + f"The capital of the fictional country Valdoria is Sunhaven.\n" + f"Token: {unique_token}\n" + ) + + ingest_result = await client.documents.ingest(file_path=str(file_path)) + if ingest_result.successful_files == 0: + pytest.skip("Ingestion produced no indexed chunks — skipping multiturn test") + + r1 = await client.chat.create(message="What is the capital of Valdoria?") + assert r1.chat_id is not None + + r2 = await client.chat.create( + message="Repeat the capital city you just mentioned.", + chat_id=r1.chat_id, + ) 
+ assert r2.response is not None + assert r2.chat_id == r1.chat_id + + await client.documents.delete(file_path.name) + await client.chat.delete(r1.chat_id) + + @pytest.mark.asyncio + async def test_knowledge_filter_scopes_search_results(self, client): + """A knowledge filter must constrain search and chat to its configured scope.""" + unique_token = uuid.uuid4().hex + create_result = await client.knowledge_filters.create({ + "name": f"E2E Scope Filter {unique_token[:8]}", + "description": "Filter for E2E scoping test", + "queryData": { + "query": f"scoped content {unique_token}", + "limit": 5, + "scoreThreshold": 0.0, + }, + }) + assert create_result.success is True + filter_id = create_result.id + + try: + search_results = await client.search.query("test query", filter_id=filter_id) + assert search_results.results is not None + + chat_response = await client.chat.create( + message="Summarise what you know.", + filter_id=filter_id, + ) + assert chat_response.response is not None + finally: + await client.knowledge_filters.delete(filter_id) diff --git a/tests/integration/sdk/test_errors.py b/tests/integration/sdk/test_errors.py new file mode 100644 index 000000000..511cc7f95 --- /dev/null +++ b/tests/integration/sdk/test_errors.py @@ -0,0 +1,68 @@ +"""Tests for SDK error handling and propagation.""" + +import io +import os +import uuid + +import pytest + +from .conftest import _base_url + +pytestmark = pytest.mark.skipif( + os.environ.get("SKIP_SDK_INTEGRATION_TESTS") == "true", + reason="SDK integration tests skipped", +) + + +class TestErrorHandling: + """Verify the SDK surfaces errors correctly rather than swallowing them.""" + + @pytest.mark.asyncio + async def test_connection_refused_raises_exception(self): + """Pointing the client at a dead port must raise a network exception, not hang.""" + from openrag_sdk import OpenRAGClient + + dead_client = OpenRAGClient( + api_key="orag_test", + base_url="http://localhost:19999", + timeout=3.0, + ) + try: + with 
pytest.raises(Exception): + await dead_client.settings.get() + finally: + await dead_client.close() + + @pytest.mark.asyncio + async def test_get_nonexistent_conversation_raises_not_found(self, client): + """Fetching a conversation with a random UUID must raise NotFoundError.""" + from openrag_sdk.exceptions import NotFoundError + + with pytest.raises(NotFoundError): + await client.chat.get(str(uuid.uuid4())) + + @pytest.mark.asyncio + async def test_delete_nonexistent_conversation_returns_false(self, client): + """Deleting a conversation that never existed must return False.""" + result = await client.chat.delete(str(uuid.uuid4())) + assert result is False + + @pytest.mark.asyncio + async def test_invalid_settings_value_raises_error(self, client): + """Sending an invalid settings value must raise a subclass of OpenRAGError.""" + from openrag_sdk.exceptions import OpenRAGError + + with pytest.raises(OpenRAGError): + await client.settings.update({"chunk_size": -999999}) + + @pytest.mark.asyncio + async def test_ingest_without_file_raises_value_error(self, client): + """Calling ingest() with neither file_path nor file must raise ValueError.""" + with pytest.raises(ValueError): + await client.documents.ingest() + + @pytest.mark.asyncio + async def test_ingest_file_object_without_filename_raises_value_error(self, client): + """Providing a file object without a filename must raise ValueError.""" + with pytest.raises(ValueError): + await client.documents.ingest(file=io.BytesIO(b"content")) diff --git a/tests/integration/sdk/test_filters.py b/tests/integration/sdk/test_filters.py new file mode 100644 index 000000000..5cab4ba6a --- /dev/null +++ b/tests/integration/sdk/test_filters.py @@ -0,0 +1,95 @@ +"""Tests for knowledge filter CRUD and usage in chat/search.""" + +import os + +import pytest + +pytestmark = pytest.mark.skipif( + os.environ.get("SKIP_SDK_INTEGRATION_TESTS") == "true", + reason="SDK integration tests skipped", +) + + +class TestKnowledgeFilters: + """Test 
knowledge filter create, read, update, delete and usage.""" + + @pytest.mark.asyncio + async def test_knowledge_filter_crud(self, client): + """Full CRUD lifecycle for a knowledge filter.""" + create_result = await client.knowledge_filters.create({ + "name": "Python SDK Test Filter", + "description": "Filter created by Python SDK integration tests", + "queryData": { + "query": "test documents", + "limit": 10, + "scoreThreshold": 0.5, + }, + }) + assert create_result.success is True + assert create_result.id is not None + filter_id = create_result.id + + # Search + filters = await client.knowledge_filters.search("Python SDK Test") + assert isinstance(filters, list) + assert any(f.name == "Python SDK Test Filter" for f in filters) + + # Get + filter_obj = await client.knowledge_filters.get(filter_id) + assert filter_obj is not None + assert filter_obj.id == filter_id + assert filter_obj.name == "Python SDK Test Filter" + + # Update + update_success = await client.knowledge_filters.update( + filter_id, + {"description": "Updated description from Python SDK test"}, + ) + assert update_success is True + + updated_filter = await client.knowledge_filters.get(filter_id) + assert updated_filter.description == "Updated description from Python SDK test" + + # Delete + delete_success = await client.knowledge_filters.delete(filter_id) + assert delete_success is True + + deleted_filter = await client.knowledge_filters.get(filter_id) + assert deleted_filter is None + + @pytest.mark.asyncio + async def test_filter_id_in_chat(self, client): + """A filter_id can be passed to chat without error.""" + create_result = await client.knowledge_filters.create({ + "name": "Chat Test Filter Python", + "description": "Filter for testing chat with filter_id", + "queryData": {"query": "test", "limit": 5}, + }) + assert create_result.success is True + filter_id = create_result.id + + try: + response = await client.chat.create( + message="Hello with filter", + filter_id=filter_id, + ) + assert 
response.response is not None + finally: + await client.knowledge_filters.delete(filter_id) + + @pytest.mark.asyncio + async def test_filter_id_in_search(self, client): + """A filter_id can be passed to search without error.""" + create_result = await client.knowledge_filters.create({ + "name": "Search Test Filter Python", + "description": "Filter for testing search with filter_id", + "queryData": {"query": "test", "limit": 5}, + }) + assert create_result.success is True + filter_id = create_result.id + + try: + results = await client.search.query("test query", filter_id=filter_id) + assert results.results is not None + finally: + await client.knowledge_filters.delete(filter_id) diff --git a/tests/integration/sdk/test_models.py b/tests/integration/sdk/test_models.py new file mode 100644 index 000000000..c86366116 --- /dev/null +++ b/tests/integration/sdk/test_models.py @@ -0,0 +1,24 @@ +"""Tests for the models endpoint.""" + +import os + +import pytest + +pytestmark = pytest.mark.skipif( + os.environ.get("SKIP_SDK_INTEGRATION_TESTS") == "true", + reason="SDK integration tests skipped", +) + + +class TestModels: + """Test model listing per provider.""" + + @pytest.mark.asyncio + async def test_list_models(self, client): + """Listing models for a provider must return language and embedding model lists.""" + models = await client.models.list("openai") + + assert models.language_models is not None + assert isinstance(models.language_models, list) + assert models.embedding_models is not None + assert isinstance(models.embedding_models, list) diff --git a/tests/integration/sdk/test_search.py b/tests/integration/sdk/test_search.py new file mode 100644 index 000000000..c917f335f --- /dev/null +++ b/tests/integration/sdk/test_search.py @@ -0,0 +1,71 @@ +"""Tests for the search endpoint.""" + +import os +from pathlib import Path + +import pytest + +pytestmark = pytest.mark.skipif( + os.environ.get("SKIP_SDK_INTEGRATION_TESTS") == "true", + reason="SDK integration tests 
skipped", +) + + +class TestSearch: + """Core search query tests.""" + + @pytest.mark.asyncio + async def test_search_query(self, client, test_file: Path): + """A basic search query returns a results list.""" + await client.documents.ingest(file_path=str(test_file)) + + results = await client.search.query("purple elephants dancing") + assert results.results is not None + + +class TestSearchExtended: + """Additional search parameter and edge-case tests.""" + + @pytest.mark.asyncio + async def test_search_with_limit(self, client, test_file: Path): + """limit parameter caps the number of results returned.""" + await client.documents.ingest(file_path=str(test_file)) + + results = await client.search.query("test", limit=1) + assert results.results is not None + assert len(results.results) <= 1 + + @pytest.mark.asyncio + async def test_search_with_high_score_threshold_returns_empty(self, client, test_file: Path): + """A score_threshold of 0.99 should filter out most or all results.""" + await client.documents.ingest(file_path=str(test_file)) + + results = await client.search.query("test", score_threshold=0.99) + assert results.results is not None + assert isinstance(results.results, list) + + @pytest.mark.asyncio + async def test_search_no_results_for_obscure_query(self, client): + """A nonsense query must return an empty list, not raise an error.""" + results = await client.search.query( + "zzz_xyzzy_nonexistent_content_abc123_qwerty_999" + ) + assert results.results is not None + assert isinstance(results.results, list) + + @pytest.mark.asyncio + async def test_search_unicode_query(self, client): + """Unicode and emoji characters in the query must not cause an error.""" + results = await client.search.query("こんにちは 🦩 Ñoño résumé") + assert results.results is not None + assert isinstance(results.results, list) + + @pytest.mark.asyncio + async def test_search_returns_result_fields(self, client, test_file: Path): + """Each search result must have text populated as a 
string.""" + await client.documents.ingest(file_path=str(test_file)) + + results = await client.search.query("purple elephants dancing", limit=5) + for result in results.results: + assert result.text is not None + assert isinstance(result.text, str) diff --git a/tests/integration/sdk/test_settings.py b/tests/integration/sdk/test_settings.py new file mode 100644 index 000000000..0f28b12d0 --- /dev/null +++ b/tests/integration/sdk/test_settings.py @@ -0,0 +1,34 @@ +"""Tests for the settings endpoint.""" + +import os + +import pytest + +pytestmark = pytest.mark.skipif( + os.environ.get("SKIP_SDK_INTEGRATION_TESTS") == "true", + reason="SDK integration tests skipped", +) + + +class TestSettings: + """Test settings get and update operations.""" + + @pytest.mark.asyncio + async def test_get_settings(self, client): + """Settings response must include agent and knowledge sections.""" + settings = await client.settings.get() + + assert settings.agent is not None + assert settings.knowledge is not None + + @pytest.mark.asyncio + async def test_update_settings(self, client): + """Updating a setting must persist and be readable back.""" + current_settings = await client.settings.get() + current_chunk_size = current_settings.knowledge.chunk_size or 1000 + + result = await client.settings.update({"chunk_size": current_chunk_size}) + assert result.message is not None + + updated_settings = await client.settings.get() + assert updated_settings.knowledge.chunk_size == current_chunk_size diff --git a/uv.lock b/uv.lock index b1a605db4..f523b6195 100644 --- a/uv.lock +++ b/uv.lock @@ -819,6 +819,38 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d5/ae/2f6d96b4e6c5478d87d606a1934b5d436c4a2bce6bb7c6fdece891c128e3/huggingface_hub-1.4.1-py3-none-any.whl", hash = "sha256:9931d075fb7a79af5abc487106414ec5fba2c0ae86104c0c62fd6cae38873d18", size = 553326, upload-time = "2026-02-06T09:20:00.728Z" }, ] +[[package]] +name = "ibm-cos-sdk" +version = "2.16.0" +source = { registry = 
"https://pypi.org/simple" } +dependencies = [ + { name = "ibm-cos-sdk-core" }, + { name = "ibm-cos-sdk-s3transfer" }, + { name = "jmespath" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d8/6c/ffbe556fd148e58d5b8e4b1f71fce604dcf531d791a8599cc8a7a1ee2a09/ibm_cos_sdk-2.16.0.tar.gz", hash = "sha256:ef11ceb121dc5c90050e87a82b6d27e67394c58f1c7abcc991a3d9a0964a290a", size = 58939, upload-time = "2026-01-06T10:30:11.235Z" } + +[[package]] +name = "ibm-cos-sdk-core" +version = "2.16.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jmespath" }, + { name = "python-dateutil" }, + { name = "requests" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/26/bb/4d6908fd4fe9d1994ad6aae483e352859b4ae7e501b8743f5db5c4840124/ibm_cos_sdk_core-2.16.0.tar.gz", hash = "sha256:2707d0ca62dd6f85455c4ac02ed37aa35ea1f2b143666271237f979d4fe904f7", size = 1119533, upload-time = "2026-01-06T10:30:01.931Z" } + +[[package]] +name = "ibm-cos-sdk-s3transfer" +version = "2.16.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ibm-cos-sdk-core" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b3/6f/660a6e0b7015a512304487c17c4ac87c280180b0a6142ee196570efc221e/ibm_cos_sdk_s3transfer-2.16.0.tar.gz", hash = "sha256:5cca69c48dcb7a1442b39cadb3a635162fbc2cf002820c8cd6cb170056d27c5c", size = 141138, upload-time = "2026-01-06T10:30:06.375Z" } + [[package]] name = "idna" version = "3.11" @@ -914,11 +946,11 @@ wheels = [ [[package]] name = "jmespath" -version = "1.1.0" +version = "1.0.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d3/59/322338183ecda247fb5d1763a6cbe46eff7222eaeebafd9fa65d4bf5cb11/jmespath-1.1.0.tar.gz", hash = "sha256:472c87d80f36026ae83c6ddd0f1d05d4e510134ed462851fd5f754c8c3cbb88d", size = 27377, upload-time = "2026-01-22T16:35:26.279Z" } +sdist = { url = 
"https://files.pythonhosted.org/packages/00/2a/e867e8531cf3e36b41201936b7fa7ba7b5702dbef42922193f05c8976cd6/jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe", size = 25843, upload-time = "2022-06-17T18:00:12.224Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/14/2f/967ba146e6d58cf6a652da73885f52fc68001525b4197effc174321d70b4/jmespath-1.1.0-py3-none-any.whl", hash = "sha256:a5663118de4908c91729bea0acadca56526eb2698e83de10cd116ae0f4e97c64", size = 20419, upload-time = "2026-01-22T16:35:24.919Z" }, + { url = "https://files.pythonhosted.org/packages/31/b4/b9b800c45527aadd64d5b442f9b932b00648617eb5d63d2c7a6587b7cafc/jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980", size = 20256, upload-time = "2022-06-17T18:00:10.251Z" }, ] [[package]] @@ -1265,6 +1297,7 @@ dependencies = [ { name = "google-auth-httplib2" }, { name = "google-auth-oauthlib" }, { name = "httpx" }, + { name = "ibm-cos-sdk" }, { name = "msal" }, { name = "opensearch-py", extra = ["async"] }, { name = "psutil" }, @@ -1298,6 +1331,7 @@ requires-dist = [ { name = "google-auth-httplib2", specifier = ">=0.2.0" }, { name = "google-auth-oauthlib", specifier = ">=1.2.0" }, { name = "httpx", specifier = ">=0.27.0" }, + { name = "ibm-cos-sdk", specifier = ">=2.13.0" }, { name = "msal", specifier = ">=1.29.0" }, { name = "opensearch-py", extras = ["async"], specifier = ">=3.0.0" }, { name = "psutil", specifier = ">=7.0.0" },