From e6e219eddeb0f447ca75486b74e7d7628409f92f Mon Sep 17 00:00:00 2001 From: capybara-brain346 Date: Sun, 20 Jul 2025 11:08:12 +0530 Subject: [PATCH 01/18] Add GitHub Actions workflow and ECS task definition Introduces a GitHub Actions workflow for building and deploying the application to AWS ECS on pushes to the dev branch. Adds an ECS task definition JSON for the knowflow-backend service, including container configuration, secrets, health checks, and required roles. --- .github/workflows/deploy-to-ecs.yml | 60 ++++++++++++++++++ ecs/task-definition.json | 94 +++++++++++++++++++++++++++++ 2 files changed, 154 insertions(+) create mode 100644 .github/workflows/deploy-to-ecs.yml create mode 100644 ecs/task-definition.json diff --git a/.github/workflows/deploy-to-ecs.yml b/.github/workflows/deploy-to-ecs.yml new file mode 100644 index 0000000..fd2fe56 --- /dev/null +++ b/.github/workflows/deploy-to-ecs.yml @@ -0,0 +1,60 @@ +name: Build & Deploy to ECS + +on: + push: + branches: + - dev + +env: + AWS_REGION: ${{ vars.AWS_REGION }} + ECR_REPOSITORY: ${{ vars.ECR_REPOSITORY }} + ECS_CLUSTER: ${{ vars.ECS_CLUSTER }} + ECS_SERVICE: ${{ vars.ECS_SERVICE }} + CONTAINER_NAME: ${{ vars.CONTAINER_NAME }} + ECS_TASK_DEF: ${{ vars.ECS_TASK_DEF }} + +jobs: + deploy: + runs-on: ubuntu-latest + environment: production + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v2 + with: + aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} + aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + aws-region: ${{ env.AWS_REGION }} + + - name: Login to Amazon ECR + id: login-ecr + uses: aws-actions/amazon-ecr-login@v1 + + - name: Build & Push Docker Image + id: build-image + env: + ECR_REGISTRY: ${{ steps.login-ecr.outputs.registry }} + IMAGE_TAG: ${{ github.sha }} + run: | + docker build -t $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG . + docker push $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG + echo "image=$ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG" >> $GITHUB_OUTPUT + + - name: Render Updated ECS Task Definition + id: render-task-def + uses: aws-actions/amazon-ecs-render-task-definition@v1 + with: + task-definition: ${{ env.ECS_TASK_DEF }} + container-name: ${{ env.CONTAINER_NAME }} + image: ${{ steps.build-image.outputs.image }} + + - name: Deploy to Amazon ECS + uses: aws-actions/amazon-ecs-deploy-task-definition@v1 + with: + task-definition: ${{ steps.render-task-def.outputs.task-definition }} + service: ${{ env.ECS_SERVICE }} + cluster: ${{ env.ECS_CLUSTER }} + wait-for-service-stability: true diff --git a/ecs/task-definition.json b/ecs/task-definition.json new file mode 100644 index 0000000..c40bf2b --- /dev/null +++ b/ecs/task-definition.json @@ -0,0 +1,94 @@ +{ + "taskDefinition": { + "taskDefinitionArn": "arn:aws:ecs:ap-south-1:953685791553:task-definition/knowflow-backend-task:1", + "containerDefinitions": [ + { + "name": "knowflow-backend", + "image": "953685791553.dkr.ecr.ap-south-1.amazonaws.com/knowflow-backend:latest", + "cpu": 0, + "portMappings": [ + { + "containerPort": 8000, + "hostPort": 8000, + "protocol": "tcp", + "name": "knowflow-backend-8000-tcp", + "appProtocol": "http" + } + ], + "essential": true, + "secrets": [ + { + "name": "AWS_ACCESS_KEY_ID", + "valueFrom": "arn:aws:secretsmanager:ap-south-1:953685791553:secret:knowflow/app-secrets:AWS_ACCESS_KEY_ID::" + }, + { + "name": "AWS_SECRET_ACCESS_KEY", + "valueFrom": "arn:aws:secretsmanager:ap-south-1:953685791553:secret:knowflow/app-secrets:AWS_SECRET_ACCESS_KEY::" + } + ], + "environmentFiles": [], + "mountPoints": [], + "volumesFrom": [], + "ulimits": [], + "healthCheck": { + "command": [ + "CMD-SHELL", + "curl -f http://localhost:8000/health || exit 1" + ], + "interval": 10, + "timeout": 5, + "retries": 3 + }, + "systemControls": [] + } + ], + "family": "knowflow-backend-task", + "taskRoleArn": "arn:aws:iam::953685791553:role/ECSTaskRole", + "executionRoleArn": "arn:aws:iam::953685791553:role/ecsTaskExecutionRole", + "networkMode": "awsvpc", + "revision": 1, + "volumes": [], + "status": "ACTIVE", + "requiresAttributes": [ + { + "name": "com.amazonaws.ecs.capability.docker-remote-api.1.24" + }, + { + "name": "com.amazonaws.ecs.capability.ecr-auth" + }, + { + "name": "com.amazonaws.ecs.capability.task-iam-role" + }, + { + "name": "ecs.capability.container-health-check" + }, + { + "name": "ecs.capability.execution-role-ecr-pull" + }, + { + "name": "com.amazonaws.ecs.capability.docker-remote-api.1.18" + }, + { + "name": "ecs.capability.task-eni" + } + ], + "placementConstraints": [], + "compatibilities": [ + "EC2", + "FARGATE" + ], + "runtimePlatform": { + "cpuArchitecture": "X86_64", + "operatingSystemFamily": "LINUX" + }, + "requiresCompatibilities": [ + "FARGATE" + ], + "cpu": "2048", + "memory": "8192", + "registeredAt": "2025-07-20T08:14:39.290000+05:30", + "registeredBy": "arn:aws:iam::953685791553:root", + "enableFaultInjection": false + }, + "tags": [] +} \ No newline at end of file From 1fc68b45a5cd59a452b5136023e02483abac20af Mon Sep 17 00:00:00 2001 From: capybara-brain346 Date: Sun, 20 Jul 2025 11:11:12 +0530 Subject: [PATCH 02/18] Update ECS deployment workflow to specify Dockerfile location Modified the GitHub Actions workflow for ECS deployment to explicitly use the backend Dockerfile located in the docker directory, ensuring the correct image is built for deployment. --- .github/workflows/deploy-to-ecs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/deploy-to-ecs.yml b/.github/workflows/deploy-to-ecs.yml index fd2fe56..6ba7c62 100644 --- a/.github/workflows/deploy-to-ecs.yml +++ b/.github/workflows/deploy-to-ecs.yml @@ -39,7 +39,7 @@ jobs: ECR_REGISTRY: ${{ steps.login-ecr.outputs.registry }} IMAGE_TAG: ${{ github.sha }} run: | - docker build -t $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG . + docker build -f docker/Dockerfile.backend -t $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG . docker push $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG echo "image=$ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG" >> $GITHUB_OUTPUT From 5f7e703b8818d145316860fd3c5b87b58779d5eb Mon Sep 17 00:00:00 2001 From: capybara-brain346 Date: Sun, 20 Jul 2025 11:14:47 +0530 Subject: [PATCH 03/18] Update Dockerfile and configuration settings --- docker/Dockerfile.backend | 2 +- src/core/config.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/docker/Dockerfile.backend b/docker/Dockerfile.backend index 3cda427..e4d8064 100644 --- a/docker/Dockerfile.backend +++ b/docker/Dockerfile.backend @@ -13,7 +13,7 @@ RUN sh /uv-installer.sh && rm /uv-installer.sh ENV PATH="/root/.local/bin/:$PATH" -COPY .env requirements.txt pyproject.toml uv.lock .python-version ./ +COPY requirements.txt pyproject.toml uv.lock .python-version ./ RUN uv sync --locked diff --git a/src/core/config.py b/src/core/config.py index 5a59fad..c3c33f5 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -65,12 +65,14 @@ class Settings(BaseSettings): # AI Models GOOGLE_API_KEY: str = Field(default="") - GEMINI_EMBEDDING_MODEL: str = Field(default="", env="GEMINI_EMBEDDING_MODEL") - GEMINI_MODEL_NAME: str = Field(default="", env="GEMINI_MODEL_NAME") + GEMINI_EMBEDDING_MODEL: str = Field( + default="models/embedding-001", env="GEMINI_EMBEDDING_MODEL" + ) + GEMINI_MODEL_NAME: str = Field(default="gemini-2.0-flash", env="GEMINI_MODEL_NAME") # Vector Store VECTOR_COLLECTION_NAME: str = Field( - default="knowflow_vecotr_db", env="VECTOR_COLLECTION_NAME" + default="knowflow_vector_db", env="VECTOR_COLLECTION_NAME" ) CHUNK_SIZE: int = Field(default=700) CHUNK_OVERLAP: int = Field(default=80) From 9ac6f9d5a9be30daa5a2a3f75b0e0c5b0da3f7f0 Mon Sep 17 00:00:00 2001 From: capybara-brain346 Date: Sun, 20 Jul 2025 11:20:44 +0530 Subject: [PATCH 04/18] Update ECS deployment workflow to use task definition family instead of task definition --- .github/workflows/deploy-to-ecs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/deploy-to-ecs.yml b/.github/workflows/deploy-to-ecs.yml index 6ba7c62..ebcdfcd 100644 --- a/.github/workflows/deploy-to-ecs.yml +++ b/.github/workflows/deploy-to-ecs.yml @@ -47,7 +47,7 @@ jobs: id: render-task-def uses: aws-actions/amazon-ecs-render-task-definition@v1 with: - task-definition: ${{ env.ECS_TASK_DEF }} + task-definition-family: ${{ env.ECS_TASK_DEF }} container-name: ${{ env.CONTAINER_NAME }} image: ${{ steps.build-image.outputs.image }} From fa50364812542a6dd15b913161d8bf25270bd2b4 Mon Sep 17 00:00:00 2001 From: capybara-brain346 Date: Sun, 20 Jul 2025 17:41:41 +0530 Subject: [PATCH 05/18] Update Dockerfile for Neo4j to set default listen address --- docker/Dockerfile.neo4j | 1 + docs/FLOWCHART.md | 200 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 201 insertions(+) create mode 100644 docs/FLOWCHART.md diff --git a/docker/Dockerfile.neo4j b/docker/Dockerfile.neo4j index bf72468..be2abd5 100644 --- a/docker/Dockerfile.neo4j +++ b/docker/Dockerfile.neo4j @@ -3,6 +3,7 @@ FROM neo4j:5.15.0 ENV NEO4J_AUTH=neo4j/Pstm!tr0ae#123 ENV NEO4J_PLUGINS=["apoc"] ENV NEO4J_dbms_security_procedures_unrestricted=apoc.* +ENV NEO4J_dbms_default__listen__address=0.0.0.0 ENV NEO4J_dbms_memory_heap_initial__size=512m ENV NEO4J_dbms_memory_heap_max__size=2G diff --git a/docs/FLOWCHART.md b/docs/FLOWCHART.md new file mode 100644 index 0000000..2461b0d --- /dev/null +++ b/docs/FLOWCHART.md @@ -0,0 +1,200 @@ +# KnowFlow System Architecture + +## Overview + +KnowFlow is an AI-powered knowledge management system that combines multiple advanced components to provide intelligent document processing, chat interactions, and knowledge retrieval. + +## System Components + +### 1. Client Layer + +- **Frontend UI**: User interface for interacting with the system + +### 2. Routes Layer + +- **Auth Routes**: Authentication and user management endpoints +- **Chat Routes**: Chat interaction endpoints +- **Document Routes**: Document processing endpoints +- **Session Routes**: Chat session management endpoints + +### 3. Core Services + +- **Auth Service**: User authentication and authorization +- **Chat Service**: Main chat processing and orchestration +- **Document Service**: Document processing and indexing +- **Session Service**: Chat session management +- **Graph Service**: Knowledge graph operations +- **S3 Service**: File storage management + +### 4. AI Components + +#### LLM Services + +- **Base LLM Client**: Gemini Pro integration for natural language processing +- **Query Decomposition**: Breaks complex queries into manageable sub-questions +- **Retrieval Evaluation**: Evaluates and improves retrieval quality + +#### Vector Store + +- **PGVector Store**: PostgreSQL-based vector storage for semantic search +- **Embeddings**: Gemini Embedding model for document vectorization + +#### Knowledge Graph + +- **Neo4j Graph DB**: Graph database for structured knowledge +- **Knowledge Extraction**: Converts text to graph structures + +### 5. External Services + +- **Google Cloud (Gemini API)**: Powers LLM and embedding operations +- **AWS S3**: Document storage + +## Key Flows + +### 1. Document Processing Flow + +1. Documents uploaded through Document Routes +2. Document Service processes them +3. Raw content stored in S3 +4. Embeddings generated and stored in PGVector +5. Knowledge extracted and stored in Neo4j + +### 2. Chat Flow + +1. User query received via Chat Routes +2. Chat Service orchestrates: + - Query decomposition for complex queries + - Vector search for relevant content + - Knowledge graph querying + - Retrieval evaluation and improvement + - LLM response generation + +### 3. AI Processing Flow + +- Base LLM Client manages all Gemini API interactions +- Query Decomposition handles complex queries +- Retrieval Evaluation ensures response quality +- Knowledge Graph maintains structured information +- Vector Store enables semantic search + +## Mermaid Diagram + +```mermaid +flowchart TB + subgraph Client + UI["Frontend UI"] + end + + subgraph Routes + AR["Auth Routes"] + CR["Chat Routes"] + DR["Document Routes"] + SR["Session Routes"] + end + + subgraph Core_Services + AS["Auth Service"] + CS["Chat Service"] + DS["Document Service"] + SS["Session Service"] + GS["Graph Service"] + S3["S3 Service"] + end + + subgraph AI_Components + subgraph LLM["LLM Services"] + BC["Base LLM Client
(Gemini Pro)"] + QD["Query Decomposition
Service"] + RE["Retrieval Evaluation
Service"] + end + + subgraph Vector_Store + PV["PGVector Store
(PostgreSQL)"] + EM["Embeddings
(Gemini Embedding)"] + end + + subgraph Knowledge_Graph + Neo["Neo4j Graph DB"] + KE["Knowledge Extraction"] + end + end + + subgraph External_Services + GCP["Google Cloud
(Gemini API)"] + AWS["AWS S3"] + end + + %% Client to Routes + UI --> AR + UI --> CR + UI --> DR + UI --> SR + + %% Routes to Services + AR --> AS + CR --> CS + DR --> DS + SR --> SS + + %% Core Service Dependencies + CS --> BC + CS --> QD + CS --> RE + CS --> PV + CS --> GS + CS --> SS + + DS --> S3 + DS --> PV + DS --> GS + DS --> EM + + %% AI Component Interactions + BC --> GCP + QD --> BC + RE --> BC + EM --> GCP + GS --> Neo + GS --> BC + GS --> KE + KE --> BC + + %% Storage Connections + S3 --> AWS + PV --> EM + + %% Data Flow + DS --> PV + DS --> KE + CS --> PV + CS --> Neo +``` + +## System Features + +1. **Intelligent Document Processing** + + - Automatic content extraction + - Semantic embedding generation + - Knowledge graph construction + - Structured storage + +2. **Advanced Query Processing** + + - Query decomposition for complex questions + - Multi-source information retrieval + - Quality evaluation and improvement + - Context-aware responses + +3. **Knowledge Management** + + - Semantic search capabilities + - Structured knowledge representation + - Relationship mapping + - Context preservation + +4. **Security & Organization** + - User authentication and authorization + - Session management + - Secure file storage + - Access control From d8b37c5442ea2cdbaabba6c8c5714be63a6f473e Mon Sep 17 00:00:00 2001 From: capybara-brain346 Date: Sun, 20 Jul 2025 17:44:59 +0530 Subject: [PATCH 06/18] Revise FLOWCHART.md to enhance documentation of KnowFlow's technical architecture. Updated sections include Document Processing Pipeline, Query Processing Pipeline, and Storage Architecture, with detailed descriptions of components, processes, and code snippets for clarity and usability. Improved Mermaid diagram to reflect new architecture and flows. --- docs/FLOWCHART.md | 431 ++++++++++++++++++++++++++-------------------- 1 file changed, 249 insertions(+), 182 deletions(-) diff --git a/docs/FLOWCHART.md b/docs/FLOWCHART.md index 2461b0d..3f3dd00 100644 --- a/docs/FLOWCHART.md +++ b/docs/FLOWCHART.md @@ -1,200 +1,267 @@ -# KnowFlow System Architecture - -## Overview - -KnowFlow is an AI-powered knowledge management system that combines multiple advanced components to provide intelligent document processing, chat interactions, and knowledge retrieval. - -## System Components - -### 1. Client Layer - -- **Frontend UI**: User interface for interacting with the system - -### 2. Routes Layer - -- **Auth Routes**: Authentication and user management endpoints -- **Chat Routes**: Chat interaction endpoints -- **Document Routes**: Document processing endpoints -- **Session Routes**: Chat session management endpoints - -### 3. Core Services - -- **Auth Service**: User authentication and authorization -- **Chat Service**: Main chat processing and orchestration -- **Document Service**: Document processing and indexing -- **Session Service**: Chat session management -- **Graph Service**: Knowledge graph operations -- **S3 Service**: File storage management - -### 4. AI Components - -#### LLM Services - -- **Base LLM Client**: Gemini Pro integration for natural language processing -- **Query Decomposition**: Breaks complex queries into manageable sub-questions -- **Retrieval Evaluation**: Evaluates and improves retrieval quality - -#### Vector Store - -- **PGVector Store**: PostgreSQL-based vector storage for semantic search -- **Embeddings**: Gemini Embedding model for document vectorization - -#### Knowledge Graph - -- **Neo4j Graph DB**: Graph database for structured knowledge -- **Knowledge Extraction**: Converts text to graph structures - -### 5. External Services - -- **Google Cloud (Gemini API)**: Powers LLM and embedding operations -- **AWS S3**: Document storage - -## Key Flows - -### 1. Document Processing Flow - -1. Documents uploaded through Document Routes -2. Document Service processes them -3. Raw content stored in S3 -4. Embeddings generated and stored in PGVector -5. Knowledge extracted and stored in Neo4j - -### 2. Chat Flow - -1. User query received via Chat Routes -2. Chat Service orchestrates: - - Query decomposition for complex queries - - Vector search for relevant content - - Knowledge graph querying - - Retrieval evaluation and improvement - - LLM response generation - -### 3. AI Processing Flow - -- Base LLM Client manages all Gemini API interactions -- Query Decomposition handles complex queries -- Retrieval Evaluation ensures response quality -- Knowledge Graph maintains structured information -- Vector Store enables semantic search - -## Mermaid Diagram +# KnowFlow Technical Architecture + +## Document Processing Pipeline + +### 1. Document Upload & Initial Processing + +- **Supported Formats**: PDF, DOCX, CSV, TXT +- **Document Loaders**: + - PDF: `PyMuPDFLoader` + - DOCX: `Docx2txtLoader` + - CSV: `CSVLoader` + - TXT: `TextLoader` + - Fallback: `UnstructuredFileLoader` + +### 2. Text Processing + +- **Text Splitter**: `RecursiveCharacterTextSplitter` + ```python + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=1000, + chunk_overlap=200, + separators=["\n\n", "\n", " ", ""], + keep_separator=True + ) + ``` + +### 3. Vector Generation + +- **Embedding Model**: Gemini Embedding Model +- **Vector Store**: PGVector (PostgreSQL) + ```python + vector_store = PGVector( + connection=DATABASE_URL, + embeddings=embeddings, + collection_name=VECTOR_COLLECTION_NAME + ) + ``` + +### 4. Knowledge Graph Construction + +- **Database**: Neo4j +- **Node Types**: + - Document: Main document node + - Section: Document sections + - Entity: Named entities + - Concept: Key ideas/terms + - Tag: Categories +- **Relationship Types**: + - CONTAINS: Hierarchy + - RELATED_TO: General connections + - MENTIONS: References + - HAS_TAG: Classifications + +## Query Processing Pipeline + +### 1. Query Analysis + +- **Query Decomposition**: + ```python + messages = [ + SystemMessage(content="Query decomposition prompt"), + HumanMessage(content=query) + ] + sub_questions = llm.invoke(messages) + ``` + +### 2. Retrieval Process + +- **Vector Search**: + ```python + docs_and_scores = vector_store.similarity_search_with_score_by_vector( + embedding=query_embedding, + k=TOP_K_RESULTS, + filter=filter_dict + ) + ``` +- **Graph Search**: + - Cypher query generation + - Pattern matching + - Context retrieval + +### 3. Retrieval Evaluation + +- **Quality Metrics**: + ```json + { + "chunk_scores": [ + {"chunk": "text", "relevance_score": 0-10, "reasoning": "explanation"} + ], + "missing_aspects": ["list of missing information"], + "redundant_information": ["list of redundancies"], + "suggested_improvements": { + "additional_info_needed": ["missing info"], + "alternative_search_terms": ["suggested terms"] + }, + "overall_quality_score": 0-10 + } + ``` + +### 4. Response Generation + +- **Context Assembly**: + + ```python + def _merge_results(vector_results, graph_results): + graph_texts = [] + for result in graph_results: + text = f"Type: {result.get('type')}\n" + text += f"Properties: {result.get('properties')}\n" + text += f"Relationships: {result.get('relationships')}" + graph_texts.append(text) + + all_texts = vector_results[:3] + graph_texts[:3] + return "\n\n".join(all_texts) + ``` + +- **LLM Prompt Structure**: + ```python + system_prompt = f""" + You are a helpful, reasoning assistant. Answer based on: + Context: {context} + Guidelines: + - Rephrase/summarize from context + - Use reasoning to clarify + - No fabrication + - Indicate if answer unavailable + """ + ``` + +## Storage Architecture + +### 1. Vector Store (PostgreSQL) + +- **Schema**: + ```sql + CREATE TABLE document_chunks ( + id SERIAL PRIMARY KEY, + document_id INTEGER, + chunk_index INTEGER, + content TEXT, + embedding_vector vector(768), + metadata JSONB + ); + ``` + +### 2. Graph Database (Neo4j) + +- **Node Properties**: + ```json + { + "id": "unique_id", + "type": "node_type", + "name": "node_name", + "content": "text_content", + "created_at": "timestamp" + } + ``` +- **Relationship Properties**: + ```json + { + "type": "relationship_type", + "context": "relationship_context", + "confidence": 0.0-1.0, + "created_at": "timestamp" + } + ``` + +### 3. Document Storage (S3) + +- **Structure**: + ``` + s3://bucket/ + └── users/ + └── {user_id}/ + └── documents/ + └── {doc_id}.{extension} + ``` + +## Mermaid Technical Flowchart ```mermaid flowchart TB - subgraph Client - UI["Frontend UI"] - end + %% Document Processing Pipeline + subgraph Document_Processing["Document Processing Pipeline"] + direction TB + Upload["Document Upload
Supported: PDF, DOCX, CSV, TXT"] + + subgraph Processing["Document Processing"] + Loader["Document Loaders
PyMuPDF/Docx2txt/CSV/Text"] + Splitter["Text Splitter
RecursiveCharacterTextSplitter
chunk_size=1000, overlap=200"] + VectorGen["Vector Generation
Gemini Embedding Model"] + end - subgraph Routes - AR["Auth Routes"] - CR["Chat Routes"] - DR["Document Routes"] - SR["Session Routes"] - end + subgraph Storage["Storage Layer"] + S3Store["S3 Storage
Raw Documents"] + PGVector["PGVector Store
- document_id
- chunk_index
- embedding_vector
- metadata"] + Neo4j["Neo4j Graph DB
Node Types: Document, Section,
Entity, Concept, Tag
Relations: CONTAINS, RELATED_TO,
MENTIONS, HAS_TAG"] + end - subgraph Core_Services - AS["Auth Service"] - CS["Chat Service"] - DS["Document Service"] - SS["Session Service"] - GS["Graph Service"] - S3["S3 Service"] + Upload --> Loader + Loader --> Splitter + Splitter --> VectorGen + VectorGen --> PGVector + Upload --> S3Store + Splitter --> Neo4j end - subgraph AI_Components - subgraph LLM["LLM Services"] - BC["Base LLM Client
(Gemini Pro)"] - QD["Query Decomposition
Service"] - RE["Retrieval Evaluation
Service"] + %% Query Processing Pipeline + subgraph Query_Processing["Query Processing Pipeline"] + direction TB + Query["User Query"] + + subgraph Query_Analysis["Query Analysis"] + QDecomp["Query Decomposition
Complex Query → Sub-questions
Using Gemini Pro"] + QEmbed["Query Embedding
Gemini Embedding Model"] end - subgraph Vector_Store - PV["PGVector Store
(PostgreSQL)"] - EM["Embeddings
(Gemini Embedding)"] + subgraph Retrieval["Retrieval Layer"] + VecSearch["Vector Search
Similarity Search with Scores
TOP_K=5"] + GraphSearch["Graph Search
Cypher Query Generation
Pattern Matching"] end - subgraph Knowledge_Graph - Neo["Neo4j Graph DB"] - KE["Knowledge Extraction"] + subgraph Evaluation["Retrieval Evaluation"] + RelCheck["Relevance Check
Score: 0-10"] + ImproveSuggestions["Improvement Suggestions
- Missing Aspects
- Alternative Terms"] end - end - subgraph External_Services - GCP["Google Cloud
(Gemini API)"] - AWS["AWS S3"] + Query --> QDecomp + QDecomp --> QEmbed + QEmbed --> VecSearch + Query --> GraphSearch + VecSearch --> RelCheck + GraphSearch --> RelCheck + RelCheck --> ImproveSuggestions + ImproveSuggestions -.-> QEmbed end - %% Client to Routes - UI --> AR - UI --> CR - UI --> DR - UI --> SR - - %% Routes to Services - AR --> AS - CR --> CS - DR --> DS - SR --> SS - - %% Core Service Dependencies - CS --> BC - CS --> QD - CS --> RE - CS --> PV - CS --> GS - CS --> SS - - DS --> S3 - DS --> PV - DS --> GS - DS --> EM - - %% AI Component Interactions - BC --> GCP - QD --> BC - RE --> BC - EM --> GCP - GS --> Neo - GS --> BC - GS --> KE - KE --> BC - - %% Storage Connections - S3 --> AWS - PV --> EM - - %% Data Flow - DS --> PV - DS --> KE - CS --> PV - CS --> Neo -``` - -## System Features - -1. **Intelligent Document Processing** - - - Automatic content extraction - - Semantic embedding generation - - Knowledge graph construction - - Structured storage + %% Response Generation + subgraph Response_Gen["Response Generation"] + direction TB + Context["Context Assembly
Vector + Graph Results"] + LLMPrompt["LLM Prompt Construction
System + Context + Query"] + Response["Response Generation
Gemini Pro"] -2. **Advanced Query Processing** - - - Query decomposition for complex questions - - Multi-source information retrieval - - Quality evaluation and improvement - - Context-aware responses - -3. **Knowledge Management** + Context --> LLMPrompt + LLMPrompt --> Response + end - - Semantic search capabilities - - Structured knowledge representation - - Relationship mapping - - Context preservation + %% Data Connections + PGVector --> VecSearch + Neo4j --> GraphSearch + VecSearch --> Context + GraphSearch --> Context + + %% External Services + subgraph External["External Services"] + direction LR + Gemini["Google Gemini Pro
- Chat Completion
- Embeddings"] + AWS["AWS S3
Document Storage"] + end -4. **Security & Organization** - - User authentication and authorization - - Session management - - Secure file storage - - Access control + %% Service Connections + VectorGen --> Gemini + QEmbed --> Gemini + Response --> Gemini + S3Store --> AWS +``` From e013e34b5c2c2b08514966a7ff0dd56512e07f68 Mon Sep 17 00:00:00 2001 From: Piyush Choudhari Date: Sun, 20 Jul 2025 19:26:11 +0530 Subject: [PATCH 07/18] Update README.md --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 6d80416..9039e9b 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,10 @@ ![Gemini_Generated_Image_j5414ij5414ij541](https://github.com/user-attachments/assets/032a9bc6-b0fb-437d-93cd-a9f063220e03) + +diagram-export-7-20-2025-6_03_33-PM +diagram-export-7-20-2025-6_57_03-PM + KnowFlow is a powerful hybrid Retrieval-Augmented Generation (RAG) system that combines semantic search with knowledge graph capabilities for intelligent document processing and querying. ## 🌟 Features From e683f4cd42db29f6d84e086480f172760979d623 Mon Sep 17 00:00:00 2001 From: Piyush Choudhari Date: Mon, 21 Jul 2025 06:48:42 +0530 Subject: [PATCH 08/18] Create pylint.yml --- .github/workflows/pylint.yml | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 .github/workflows/pylint.yml diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml new file mode 100644 index 0000000..6e21c90 --- /dev/null +++ b/.github/workflows/pylint.yml @@ -0,0 +1,23 @@ +name: Pylint + +on: [push] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.11"] + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pylint + - name: Analysing the code with pylint + run: | + pylint $(git ls-files '*.py') From 7c029dd24b1157e23b19e4a69da8f58b3c1d2e32 Mon Sep 17 00:00:00 2001 From: Piyush Choudhari Date: Tue, 22 Jul 2025 08:30:07 +0530 Subject: [PATCH 09/18] Update README.md --- README.md | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 9039e9b..0dc6e15 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,13 @@ -# KNOWFLOW - +# KnowFlow +--- ![Gemini_Generated_Image_j5414ij5414ij541](https://github.com/user-attachments/assets/032a9bc6-b0fb-437d-93cd-a9f063220e03) - +# Deployment +--- diagram-export-7-20-2025-6_03_33-PM + +# AI Infra +--- diagram-export-7-20-2025-6_57_03-PM KnowFlow is a powerful hybrid Retrieval-Augmented Generation (RAG) system that combines semantic search with knowledge graph capabilities for intelligent document processing and querying. From a2af00e7e8815d870fef5116c3e8f033ce195c32 Mon Sep 17 00:00:00 2001 From: capybara-brain346 Date: Tue, 22 Jul 2025 08:42:43 +0530 Subject: [PATCH 10/18] code cleanup --- pyproject.toml | 1 + src/core/auth.py | 4 ++-- src/core/config.py | 4 ++-- src/core/database.py | 2 -- src/core/exceptions.py | 2 +- src/core/logging.py | 2 +- src/core/middleware.py | 5 +---- src/main.py | 6 +++--- src/models/database.py | 5 +++-- src/models/graph.py | 2 +- src/models/request.py | 2 +- src/models/response.py | 1 + src/routes/auth_routes.py | 8 +++---- src/routes/chat_routes.py | 7 +++--- src/routes/document_routes.py | 2 +- src/routes/session_routes.py | 7 +++--- src/services/auth_service.py | 2 +- src/services/base_client.py | 16 +++++++++++++- src/services/chat/chat_service.py | 29 +++++-------------------- src/services/document_service.py | 28 ++++++++---------------- src/services/graph_service.py | 12 +++++------ src/services/s3_service.py | 2 +- src/services/session_service.py | 2 +- uv.lock | 36 +++++++++++++++++++++++++++++++ 24 files changed, 102 insertions(+), 85 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7b97a7b..c1995d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,4 +26,5 @@ dependencies = [ "constructs>=10.4.2", "psycopg2-binary>=2.9.10", "pymupdf>=1.26.3", + "pytest>=8.4.1", ] diff --git a/src/core/auth.py b/src/core/auth.py index 73083d1..d144f00 100644 --- a/src/core/auth.py +++ b/src/core/auth.py @@ -1,11 +1,11 @@ from typing import Annotated -from fastapi import Depends, HTTPException, status +from fastapi import Depends from fastapi.security import OAuth2PasswordBearer from sqlalchemy.orm import Session from src.core.database import get_db -from src.services.auth_service import AuthService from src.models.database import User +from src.services.auth_service import AuthService oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth/token") diff --git a/src/core/config.py b/src/core/config.py index c3c33f5..275bcaa 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -1,8 +1,8 @@ +import json +import boto3 from functools import lru_cache from pathlib import Path from typing import Optional, List -import json -import boto3 from botocore.exceptions import ClientError from pydantic import Field from pydantic_settings import BaseSettings diff --git a/src/core/database.py b/src/core/database.py index a22dc1c..62ec7d0 100644 --- a/src/core/database.py +++ b/src/core/database.py @@ -1,9 +1,7 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker, declarative_base -from neo4j import GraphDatabase from src.core.config import settings -from src.core.logging import logger Base = declarative_base() diff --git a/src/core/exceptions.py b/src/core/exceptions.py index d607ad9..0d1efcd 100644 --- a/src/core/exceptions.py +++ b/src/core/exceptions.py @@ -1,5 +1,5 @@ -from typing import Any, Dict, Optional from fastapi import status +from typing import Any, Dict, Optional class AppException(Exception): diff --git a/src/core/logging.py b/src/core/logging.py index 21c2e5d..93f2f54 100644 --- a/src/core/logging.py +++ b/src/core/logging.py @@ -1,6 +1,6 @@ +import sys import logging import logging.handlers -import sys from pathlib import Path from typing import Optional diff --git a/src/core/middleware.py b/src/core/middleware.py index 0b80f14..cb71bf4 100644 --- a/src/core/middleware.py +++ b/src/core/middleware.py @@ -1,13 +1,10 @@ -from datetime import datetime, timezone import time import uuid -from typing import Callable, Optional - +from datetime import datetime, timezone from fastapi import FastAPI, Request, Response from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint -from starlette.types import Message from src.core.config import settings from src.core.logging import logger diff --git a/src/main.py b/src/main.py index dd758f2..a2a2fce 100644 --- a/src/main.py +++ b/src/main.py @@ -2,15 +2,15 @@ from fastapi import FastAPI from src.core.config import settings +from src.core.middleware import setup_middleware +from src.core.database import init_db +from src.core.logging import logger from src.routes import ( auth_routes, chat_routes, document_routes, session_routes, ) -from src.core.middleware import setup_middleware -from src.core.database import init_db -from src.core.logging import logger @asynccontextmanager diff --git a/src/models/database.py b/src/models/database.py index fa37ec7..de1b2b6 100644 --- a/src/models/database.py +++ b/src/models/database.py @@ -1,9 +1,10 @@ +import enum +import sqlalchemy from datetime import datetime, timezone from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, JSON, Text, Enum from sqlalchemy.orm import relationship -import enum -import sqlalchemy from sqlalchemy.sql import func + from src.core.database import Base diff --git a/src/models/graph.py b/src/models/graph.py index f306c6d..cba5f4d 100644 --- a/src/models/graph.py +++ b/src/models/graph.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Any, Optional +from typing import List, Optional from pydantic import BaseModel, Field, field_validator diff --git a/src/models/request.py b/src/models/request.py index cd76d38..cacff9d 100644 --- a/src/models/request.py +++ b/src/models/request.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, EmailStr, Field, constr +from pydantic import BaseModel, EmailStr, Field from typing import Optional, Dict, Any, List from datetime import datetime diff --git a/src/models/response.py b/src/models/response.py index bbd2f67..1ee337c 100644 --- a/src/models/response.py +++ b/src/models/response.py @@ -1,6 +1,7 @@ from datetime import datetime from typing import List, Optional, Dict, Any from pydantic import BaseModel, Field + from src.models.database import DocumentStatus diff --git a/src/routes/auth_routes.py b/src/routes/auth_routes.py index 681ab22..18830b6 100644 --- a/src/routes/auth_routes.py +++ b/src/routes/auth_routes.py @@ -1,11 +1,10 @@ from datetime import timedelta from typing import Annotated from fastapi import APIRouter, Depends, HTTPException, status -from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm -from sqlalchemy.orm import Session +from fastapi.security import OAuth2PasswordRequestForm +from src.core.auth import get_current_user, get_auth_service from src.core.config import settings -from src.core.database import get_db from src.models.request import UserLogin, UserRegister from src.models.response import ( TokenResponse, @@ -13,9 +12,8 @@ RegisterResponse, MessageResponse, ) -from src.services.auth_service import AuthService from src.models.database import User -from src.core.auth import get_current_user, get_auth_service +from src.services.auth_service import AuthService router = APIRouter() diff --git a/src/routes/chat_routes.py b/src/routes/chat_routes.py index 3466b8c..632526f 100644 --- a/src/routes/chat_routes.py +++ b/src/routes/chat_routes.py @@ -1,6 +1,8 @@ from fastapi import APIRouter, Depends from sqlalchemy.orm import Session -from src.services.chat.chat_service import ChatService + +from src.core.auth import get_current_user +from src.core.database import get_db from src.core.exceptions import ExternalServiceException from src.core.logging import logger from src.models.request import ChatRequest, FollowUpChatRequest, RenameChatRequest @@ -10,9 +12,8 @@ RenameChatResponse, DeleteChatResponse, ) -from src.core.auth import get_current_user from src.models.database import User -from src.core.database import get_db +from src.services.chat.chat_service import ChatService from src.services.session_service import SessionService diff --git a/src/routes/document_routes.py b/src/routes/document_routes.py index d6730ad..b6e3d7f 100644 --- a/src/routes/document_routes.py +++ b/src/routes/document_routes.py @@ -2,13 +2,13 @@ from fastapi import APIRouter, Depends, UploadFile, BackgroundTasks, Query from src.core.auth import get_current_user -from src.services.document_service import DocumentService from src.models.database import User from src.models.request import DocumentIndexRequest from src.models.response import ( DocumentIndexResponse, MultiDocumentUploadResponse, ) +from src.services.document_service import DocumentService router = APIRouter() diff --git a/src/routes/session_routes.py b/src/routes/session_routes.py index 3e401f3..1db1dd1 100644 --- a/src/routes/session_routes.py +++ b/src/routes/session_routes.py @@ -1,17 +1,16 @@ -from typing import List, Optional +from typing import List from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session from src.core.database import get_db +from src.core.auth import get_current_user +from src.models.database import User from src.models.request import CreateSessionRequest, SendMessageRequest from src.models.response import ( ChatSessionResponse, ChatSessionListResponse, - MessageResponse, ) from src.services.session_service import SessionService -from src.core.auth import get_current_user -from src.models.database import User router = APIRouter() diff --git a/src/services/auth_service.py b/src/services/auth_service.py index c9d1957..1b5e547 100644 --- a/src/services/auth_service.py +++ b/src/services/auth_service.py @@ -7,7 +7,7 @@ from fastapi.security import OAuth2PasswordBearer from src.core.config import settings -from src.models.database import User, Document, DocumentChunk, DocumentStatus +from src.models.database import User pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") diff --git a/src/services/base_client.py b/src/services/base_client.py index eb5d187..fdbea09 100644 --- a/src/services/base_client.py +++ b/src/services/base_client.py @@ -1,4 +1,6 @@ -from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings +from langchain_postgres import PGVector + from src.core.config import settings from src.core.exceptions import ExternalServiceException from src.core.logging import logger @@ -13,6 +15,18 @@ def __init__(self, service_name: str): model=settings.GEMINI_MODEL_NAME, convert_system_message_to_human=True, ) + + self.embeddings = GoogleGenerativeAIEmbeddings( + model=settings.GEMINI_EMBEDDING_MODEL, + google_api_key=settings.GOOGLE_API_KEY, + ) + + self.vector_store = PGVector( + connection=settings.DATABASE_URL, + embeddings=self.embeddings, + collection_name=settings.VECTOR_COLLECTION_NAME, + ) + logger.info(f"{service_name} initialized successfully") except Exception as e: logger.error( diff --git a/src/services/chat/chat_service.py b/src/services/chat/chat_service.py index 0a5b3cc..3295951 100644 --- a/src/services/chat/chat_service.py +++ b/src/services/chat/chat_service.py @@ -1,31 +1,23 @@ -import json from typing import List, Dict, Any, Optional from fastapi import HTTPException, status -from langchain_google_genai import GoogleGenerativeAIEmbeddings -from langchain_postgres import PGVector -from langchain_google_genai import ChatGoogleGenerativeAI from langchain.schema import HumanMessage, SystemMessage from sqlalchemy.orm import Session from datetime import datetime, timezone from neo4j import GraphDatabase +from src.core.database import get_db from src.core.config import settings from src.core.exceptions import ExternalServiceException from src.core.logging import logger +from src.models.request import FollowUpChatRequest +from src.models.response import FollowUpChatResponse +from src.models.database import ChatSession +from src.models.database import Message from src.services.graph_service import GraphService from src.services.auth_service import AuthService from src.services.base_client import BaseLLMClient from src.services.chat.query_decomposition import QueryDecompositionService from src.services.chat.retrieval_evaluation import RetrievalEvaluationService -from src.models.database import Document -from src.models.database import DocumentChunk -from src.models.database import DocumentStatus -from src.models.database import Message -from src.core.database import get_db -from src.models.request import FollowUpChatRequest -from src.models.response import FollowUpChatResponse -from src.models.database import ChatSession -from src.utils.utils import clean_llm_response class ChatService(BaseLLMClient): @@ -35,17 +27,6 @@ def __init__(self, db: Session = None): self.db = db or next(get_db()) self.auth_service = AuthService(self.db) - self.embeddings = GoogleGenerativeAIEmbeddings( - model=settings.GEMINI_EMBEDDING_MODEL, - google_api_key=settings.GOOGLE_API_KEY, - ) - - self.vector_store = PGVector( - connection=settings.DATABASE_URL, - embeddings=self.embeddings, - collection_name=settings.VECTOR_COLLECTION_NAME, - ) - self.driver = GraphDatabase.driver( settings.NEO4J_URI, auth=(settings.NEO4J_USER, settings.NEO4J_PASSWORD) ) diff --git a/src/services/document_service.py b/src/services/document_service.py index 98c6605..9fb66eb 100644 --- a/src/services/document_service.py +++ b/src/services/document_service.py @@ -1,3 +1,4 @@ +import tempfile import os import asyncio from typing import List, Optional, Dict, Any @@ -13,21 +14,19 @@ Docx2txtLoader, ) from langchain.text_splitter import RecursiveCharacterTextSplitter -import tempfile from src.core.database import get_db -from src.models.database import Document, DocumentStatus, DocumentChunk, User -from src.services.s3_service import S3Service from src.core.config import settings from src.core.exceptions import ExternalServiceException from src.core.logging import logger -from langchain_google_genai import GoogleGenerativeAIEmbeddings -from langchain_postgres import PGVector +from src.models.database import Document, DocumentStatus, DocumentChunk, User +from src.services.s3_service import S3Service from src.services.graph_service import GraphService +from src.services.base_client import BaseLLMClient from src.utils.utils import clean_whitespaes -class DocumentService: +class DocumentService(BaseLLMClient): SUPPORTED_MIMETYPES = { "application/pdf": ".pdf", "application/msword": ".doc", @@ -41,9 +40,7 @@ class DocumentService: def __init__( self, db: Session = next(get_db()), current_user: Optional[User] = None ): - self.db = db - self.storage_service = S3Service() - self.current_user = current_user + super().__init__("DocumentService") try: try: loop = asyncio.get_event_loop() @@ -51,16 +48,9 @@ def __init__( loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - self.embeddings = GoogleGenerativeAIEmbeddings( - model=settings.GEMINI_EMBEDDING_MODEL, - google_api_key=settings.GOOGLE_API_KEY, - ) - - self.vector_store = PGVector( - connection=settings.DATABASE_URL, - embeddings=self.embeddings, - collection_name=settings.VECTOR_COLLECTION_NAME, - ) + self.db = db + self.storage_service = S3Service() + self.current_user = current_user self.text_splitter = RecursiveCharacterTextSplitter( chunk_size=settings.CHUNK_SIZE, diff --git a/src/services/graph_service.py b/src/services/graph_service.py index d1ca178..9ca5d44 100644 --- a/src/services/graph_service.py +++ b/src/services/graph_service.py @@ -1,17 +1,17 @@ -from typing import List, Dict, Any +import re +import uuid import json +from datetime import datetime +from typing import List, Dict, Any from neo4j import GraphDatabase, Session from langchain.schema import HumanMessage, SystemMessage from src.core.config import settings -from src.models.graph import GraphKnowledge from src.core.exceptions import ExternalServiceException from src.core.logging import logger -from src.utils.utils import clean_llm_response +from src.models.graph import GraphKnowledge from src.services.base_client import BaseLLMClient -from datetime import datetime -import re -import uuid +from src.utils.utils import clean_llm_response class GraphService(BaseLLMClient): diff --git a/src/services/s3_service.py b/src/services/s3_service.py index 25b36a5..752a38b 100644 --- a/src/services/s3_service.py +++ b/src/services/s3_service.py @@ -1,7 +1,7 @@ import boto3 from botocore.exceptions import ClientError from fastapi import HTTPException, status -from typing import BinaryIO, Optional, Dict, Any, List +from typing import Optional, Dict, Any, List from concurrent.futures import ThreadPoolExecutor from src.core.config import settings diff --git a/src/services/session_service.py b/src/services/session_service.py index 0fda8bd..8de0f34 100644 --- a/src/services/session_service.py +++ b/src/services/session_service.py @@ -3,7 +3,7 @@ from sqlalchemy.orm import Session, joinedload import uuid -from src.models.database import ChatSession, Message, User +from src.models.database import ChatSession, Message class SessionService: diff --git a/uv.lock b/uv.lock index 8c69be0..abba2ca 100644 --- a/uv.lock +++ b/uv.lock @@ -919,6 +919,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a4/ed/1f1afb2e9e7f38a545d628f864d562a5ae64fe6f7a10e28ffb9b185b4e89/importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec", size = 37461 }, ] +[[package]] +name = "iniconfig" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050 }, +] + [[package]] name = "jinja2" version = "3.1.6" @@ -1002,6 +1011,7 @@ dependencies = [ { name = "pydantic-settings" }, { name = "pymupdf" }, { name = "pypdf2" }, + { name = "pytest" }, { name = "python-jose" }, { name = "python-magic" }, { name = "uvicorn" }, @@ -1027,6 +1037,7 @@ requires-dist = [ { name = "pydantic-settings", specifier = ">=2.10.1" }, { name = "pymupdf", specifier = ">=1.26.3" }, { name = "pypdf2", specifier = ">=3.0.1" }, + { name = "pytest", specifier = ">=8.4.1" }, { name = "python-jose", specifier = ">=3.5.0" }, { name = "python-magic", specifier = ">=0.4.27" }, { name = "uvicorn", specifier = ">=0.35.0" }, @@ -1491,6 +1502,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/81/f457d6d361e04d061bef413749a6e1ab04d98cfeec6d8abcfe40184750f3/pgvector-0.3.6-py3-none-any.whl", hash = "sha256:f6c269b3c110ccb7496bac87202148ed18f34b390a0189c783e351062400a75a", size = 24880 }, ] +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538 }, +] + [[package]] name = "propcache" version = "0.3.2" @@ -1869,6 +1889,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8e/5e/c86a5643653825d3c913719e788e41386bee415c2b87b4f955432f2de6b2/pypdf2-3.0.1-py3-none-any.whl", hash = "sha256:d16e4205cfee272fbdc0568b68d82be796540b1537508cef59388f839c191928", size = 232572 }, ] +[[package]] +name = "pytest" +version = "8.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/08/ba/45911d754e8eba3d5a841a5ce61a65a685ff1798421ac054f85aa8747dfb/pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c", size = 1517714 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474 }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0" From dc4269a1a1016c23b29dc373a9347fe0abe9b25c Mon Sep 17 00:00:00 2001 From: capybara-brain346 Date: Tue, 22 Jul 2025 08:47:30 +0530 Subject: [PATCH 11/18] create .pylintrc --- .pylintrc | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 .pylintrc diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..cd30a1d --- /dev/null +++ b/.pylintrc @@ -0,0 +1,9 @@ +[MESSAGES CONTROL] +disable = + missing-module-docstring, + missing-class-docstring, + missing-function-docstring, + logging-fstring-interpolation, + raise-missing-from, + broad-exception-caught, + wrong-import-order From 617f9c8e3cef5bc3ed415c6de3d96b44a2e074d6 Mon Sep 17 00:00:00 2001 From: capybara-brain346 Date: Tue, 22 Jul 2025 08:49:48 +0530 Subject: [PATCH 12/18] disable import error linting --- .pylintrc | 1 + tests/conftest.py | 86 ++++++++ tests/services/test_auth_service.py | 112 ++++++++++ tests/services/test_base_client.py | 38 ++++ tests/services/test_chat_service.py | 206 ++++++++++++++++++ tests/services/test_document_service.py | 227 ++++++++++++++++++++ tests/services/test_query_decomposition.py | 93 ++++++++ tests/services/test_retrieval_evaluation.py | 149 +++++++++++++ tests/services/test_s3_service.py | 192 +++++++++++++++++ tests/services/test_session_service.py | 148 +++++++++++++ 10 files changed, 1252 insertions(+) create mode 100644 tests/conftest.py create mode 100644 tests/services/test_auth_service.py create mode 100644 tests/services/test_base_client.py create mode 100644 tests/services/test_chat_service.py create mode 100644 tests/services/test_document_service.py create mode 100644 tests/services/test_query_decomposition.py create mode 100644 tests/services/test_retrieval_evaluation.py create mode 100644 tests/services/test_s3_service.py create mode 100644 tests/services/test_session_service.py diff --git a/.pylintrc b/.pylintrc index cd30a1d..beca798 100644 --- a/.pylintrc +++ b/.pylintrc @@ -7,3 +7,4 @@ disable = raise-missing-from, broad-exception-caught, wrong-import-order + import-error diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..eb0a950 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,86 @@ +import pytest +from unittest.mock import MagicMock +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker, Session +from neo4j import GraphDatabase +from src.core.config import settings +from src.core.database import Base +from src.models.database import User, Document, DocumentChunk, ChatSession, Message + + +@pytest.fixture +def db_engine(): + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + return engine + + +@pytest.fixture +def db_session(db_engine): + Session = sessionmaker(bind=db_engine) + session = Session() + yield session + session.close() + + +@pytest.fixture +def mock_s3_client(): + return MagicMock() + + +@pytest.fixture +def mock_neo4j_driver(): + driver = MagicMock(spec=GraphDatabase.driver) + session = MagicMock() + driver.session.return_value.__enter__.return_value = session + driver.session.return_value.__exit__.return_value = None + return driver + + +@pytest.fixture +def mock_llm(): + llm = MagicMock() + llm.invoke.return_value = MagicMock(content="Test response") + return llm + + +@pytest.fixture +def mock_embeddings(): + embeddings = MagicMock() + embeddings.embed_documents.return_value = [[0.1] * 768] + embeddings.embed_query.return_value = [0.1] * 768 + return embeddings + + +@pytest.fixture +def test_user(db_session): + user = User( + username="testuser", email="test@example.com", hashed_password="hashedpassword" + ) + db_session.add(user) + db_session.commit() + return user + + +@pytest.fixture +def test_document(db_session, test_user): + doc = Document( + doc_id="test_doc_id", + title="Test Document", + content_type="application/pdf", + status="PENDING", + user_id=test_user.id, + ) + db_session.add(doc) + db_session.commit() + return doc + + +@pytest.fixture +def test_chat_session(db_session, test_user): + session = ChatSession( + id="test_session_id", user_id=test_user.id, title="Test Session" + ) + db_session.add(session) + db_session.commit() + return session diff --git a/tests/services/test_auth_service.py b/tests/services/test_auth_service.py new file mode 100644 index 0000000..d4b0444 --- /dev/null +++ b/tests/services/test_auth_service.py @@ -0,0 +1,112 @@ +import pytest +from datetime import datetime, timedelta +from fastapi import HTTPException +from jose import jwt +from src.services.auth_service import AuthService +from src.core.config import settings +from src.models.database import User + + +def test_verify_password(db_session): + auth_service = AuthService(db_session) + hashed_password = auth_service.get_password_hash("testpassword") + assert auth_service.verify_password("testpassword", hashed_password) + assert not auth_service.verify_password("wrongpassword", hashed_password) + + +def test_get_user_by_username(db_session, test_user): + auth_service = AuthService(db_session) + user = auth_service.get_user_by_username("testuser") + assert user is not None + assert user.username == "testuser" + assert auth_service.get_user_by_username("nonexistent") is None + + +def test_get_user_by_email(db_session, test_user): + auth_service = AuthService(db_session) + user = auth_service.get_user_by_email("test@example.com") + assert user is not None + assert user.email == "test@example.com" + assert auth_service.get_user_by_email("nonexistent@example.com") is None + + +def test_create_user_success(db_session): + auth_service = AuthService(db_session) + user = auth_service.create_user( + username="newuser", email="new@example.com", password="password123" + ) + assert user.username == "newuser" + assert user.email == "new@example.com" + assert auth_service.verify_password("password123", user.hashed_password) + + +def test_create_user_duplicate_username(db_session, test_user): + auth_service = AuthService(db_session) + with pytest.raises(HTTPException) as exc_info: + auth_service.create_user( + username="testuser", email="another@example.com", password="password123" + ) + assert exc_info.value.status_code == 400 + assert "Username already registered" in str(exc_info.value.detail) + + +def test_create_user_duplicate_email(db_session, test_user): + auth_service = AuthService(db_session) + with pytest.raises(HTTPException) as exc_info: + auth_service.create_user( + username="anotheruser", email="test@example.com", password="password123" + ) + assert exc_info.value.status_code == 400 + assert "Email already registered" in str(exc_info.value.detail) + + +def test_authenticate_user_success(db_session): + auth_service = AuthService(db_session) + user = auth_service.create_user( + username="authuser", email="auth@example.com", password="password123" + ) + authenticated_user = auth_service.authenticate_user( + "auth@example.com", "password123" + ) + assert authenticated_user is not None + assert authenticated_user.id == user.id + + +def test_authenticate_user_failure(db_session, test_user): + auth_service = AuthService(db_session) + assert auth_service.authenticate_user("test@example.com", "wrongpassword") is None + assert ( + auth_service.authenticate_user("nonexistent@example.com", "password123") is None + ) + + +def test_create_access_token(db_session): + auth_service = AuthService(db_session) + data = {"sub": "1"} + expires_delta = timedelta(minutes=15) + token = auth_service.create_access_token(data, expires_delta) + + decoded = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"]) + assert decoded["sub"] == "1" + assert "exp" in decoded + + +def test_get_current_user_success(db_session, test_user): + auth_service = AuthService(db_session) + token = auth_service.create_access_token({"sub": str(test_user.id)}) + user = auth_service.get_current_user(token) + assert user.id == test_user.id + + +def test_get_current_user_invalid_token(db_session): + auth_service = AuthService(db_session) + with pytest.raises(HTTPException) as exc_info: + auth_service.get_current_user("invalid_token") + assert exc_info.value.status_code == 401 + assert "Could not validate credentials" in str(exc_info.value.detail) + + +def test_get_user_s3_prefix(db_session, test_user): + auth_service = AuthService(db_session) + prefix = auth_service.get_user_s3_prefix(test_user.id) + assert prefix == f"user_{test_user.id}/" diff --git a/tests/services/test_base_client.py b/tests/services/test_base_client.py new file mode 100644 index 0000000..1719aba --- /dev/null +++ b/tests/services/test_base_client.py @@ -0,0 +1,38 @@ +import pytest +from unittest.mock import patch, MagicMock +from src.services.base_client import BaseLLMClient +from src.core.exceptions import ExternalServiceException + + +def test_base_client_initialization_success(mock_llm): + with patch( + "src.services.base_client.ChatGoogleGenerativeAI", return_value=mock_llm + ): + client = BaseLLMClient("TestService") + assert client.service_name == "TestService" + assert client.llm == mock_llm + + +def test_base_client_initialization_failure(): + with patch( + "src.services.base_client.ChatGoogleGenerativeAI", + side_effect=Exception("API Error"), + ): + with pytest.raises(ExternalServiceException) as exc_info: + BaseLLMClient("TestService") + assert "Failed to initialize testservice" in str(exc_info.value) + assert exc_info.value.service_name == "TestService" + + +def test_base_client_llm_invocation(mock_llm): + with patch( + "src.services.base_client.ChatGoogleGenerativeAI", return_value=mock_llm + ): + client = BaseLLMClient("TestService") + mock_llm.invoke.return_value = MagicMock(content="Test response") + + messages = [{"role": "user", "content": "Test message"}] + response = client.llm.invoke(messages) + + assert response.content == "Test response" + mock_llm.invoke.assert_called_once_with(messages) diff --git a/tests/services/test_chat_service.py b/tests/services/test_chat_service.py new file mode 100644 index 0000000..989ef79 --- /dev/null +++ b/tests/services/test_chat_service.py @@ -0,0 +1,206 @@ +import pytest +from unittest.mock import patch, MagicMock, AsyncMock +from datetime import datetime, timezone +from fastapi import HTTPException +from src.services.chat.chat_service import ChatService +from src.models.database import ChatSession, Message, Document, DocumentStatus +from src.models.request import FollowUpChatRequest +from src.core.exceptions import ExternalServiceException + + +@pytest.fixture +def chat_service(db_session, mock_llm, mock_embeddings, mock_neo4j_driver): + with ( + patch( + "src.services.chat.chat_service.ChatGoogleGenerativeAI", + return_value=mock_llm, + ), + patch( + "src.services.chat.chat_service.GoogleGenerativeAIEmbeddings", + return_value=mock_embeddings, + ), + patch( + "src.services.chat.chat_service.GraphDatabase.driver", + return_value=mock_neo4j_driver, + ), + patch("src.services.chat.chat_service.PGVector") as mock_pgvector, + ): + mock_pgvector.return_value = MagicMock() + service = ChatService(db_session) + return service + + +@pytest.mark.asyncio +async def test_process_query_simple(chat_service, mock_llm): + query = "What is the meaning of life?" + current_user_id = 1 + + chat_service.vector_store.similarity_search_with_score_by_vector.return_value = [ + (MagicMock(page_content="Life is 42"), 0.9) + ] + chat_service.graph_service.query_graph.return_value = [ + {"type": "Concept", "properties": {"name": "Life", "content": "Meaning"}} + ] + + mock_llm.invoke.return_value = MagicMock(content="The meaning of life is 42") + + result = await chat_service.process_query(query, "session_1", current_user_id) + + assert "message" in result + assert result["message"] == "The meaning of life is 42" + assert "context_used" in result + assert len(result["context_used"]["vector_results"]) > 0 + assert len(result["context_used"]["graph_results"]) > 0 + + +@pytest.mark.asyncio +async def test_process_query_with_decomposition(chat_service, mock_llm): + query = "Tell me about the system architecture and performance" + current_user_id = 1 + + # Mock query decomposition + chat_service._query_decomposition_service = MagicMock() + chat_service._query_decomposition_service.decompose_query.return_value = [ + "What is the system architecture?", + "How does the system perform?", + ] + + # Mock vector and graph results for both sub-queries + chat_service.vector_store.similarity_search_with_score_by_vector.return_value = [ + (MagicMock(page_content="Architecture details"), 0.9) + ] + chat_service.graph_service.query_graph.return_value = [ + {"type": "Document", "properties": {"content": "Performance metrics"}} + ] + + mock_llm.invoke.return_value = MagicMock( + content="Combined response about architecture and performance" + ) + + result = await chat_service.process_query( + query, "session_1", current_user_id, use_query_decomposition=True + ) + + assert result["message"] == "Combined response about architecture and performance" + assert "context_used" in result + assert "sub_responses" in result["context_used"] + + +@pytest.mark.asyncio +async def test_process_query_with_retrieval_evaluation(chat_service, mock_llm): + query = "How does authentication work?" + current_user_id = 1 + + chat_service.vector_store.similarity_search_with_score_by_vector.return_value = [ + (MagicMock(page_content="Auth process"), 0.8) + ] + + chat_service._retrieval_evaluation_service = MagicMock() + chat_service._retrieval_evaluation_service.evaluate_retrieval_quality.return_value = { + "overall_quality_score": 6, + "needs_improvement": True, + "suggested_improvements": { + "alternative_search_terms": ["user authentication", "login process"] + }, + } + + result = await chat_service.process_query( + query, "session_1", current_user_id, use_retrieval_evaluation=True + ) + + assert "message" in result + assert "context_used" in result + + +@pytest.mark.asyncio +async def test_follow_up_chat(chat_service, test_user, test_chat_session): + request = FollowUpChatRequest( + message="Follow up question", + referenced_node_ids=["node1", "node2"], + context_window=2, + ) + + mock_session = chat_service.driver.session.return_value.__enter__.return_value + mock_session.run.return_value = [ + {"related": {"id": "node1", "content": "Context 1"}}, + {"related": {"id": "node2", "content": "Context 2"}}, + ] + + response = await chat_service.follow_up_chat( + test_chat_session.id, request, test_user.id + ) + + assert response.response == "Placeholder response" + assert len(response.context_nodes) == 2 + assert response.referenced_entities == ["node1", "node2"] + + +def test_rename_chat_session(chat_service, test_user, test_chat_session): + new_title = "New Session Title" + result = chat_service.rename_chat_session( + test_chat_session.id, new_title, test_user.id + ) + + assert result["session_id"] == test_chat_session.id + assert result["title"] == new_title + + +def test_rename_chat_session_not_found(chat_service, test_user): + with pytest.raises(HTTPException) as exc_info: + chat_service.rename_chat_session("nonexistent", "New Title", test_user.id) + assert exc_info.value.status_code == 404 + + +def test_rename_chat_session_unauthorized(chat_service, test_chat_session): + with pytest.raises(HTTPException) as exc_info: + chat_service.rename_chat_session(test_chat_session.id, "New Title", 999) + assert exc_info.value.status_code == 403 + + +def test_delete_chat_session(chat_service, test_user, test_chat_session): + result = chat_service.delete_chat_session(test_chat_session.id, test_user.id) + + assert result["session_id"] == test_chat_session.id + assert result["status"] == "deleted" + + session = ( + chat_service.db.query(ChatSession) + .filter(ChatSession.id == test_chat_session.id) + .first() + ) + assert session is None + + +def test_delete_chat_session_not_found(chat_service, test_user): + with pytest.raises(HTTPException) as exc_info: + chat_service.delete_chat_session("nonexistent", test_user.id) + assert exc_info.value.status_code == 404 + + +def test_delete_chat_session_unauthorized(chat_service, test_chat_session): + with pytest.raises(HTTPException) as exc_info: + chat_service.delete_chat_session(test_chat_session.id, 999) + assert exc_info.value.status_code == 403 + + +def test_merge_results(chat_service): + vector_results = ["Vector result 1", "Vector result 2"] + graph_results = [ + { + "type": "Concept", + "properties": {"name": "Test", "content": "Content"}, + "relationships": [{"type": "RELATED_TO"}], + } + ] + + merged = chat_service._merge_results(vector_results, graph_results) + assert isinstance(merged, str) + assert "Vector result" in merged + assert "Type: Concept" in merged + assert "RELATED_TO" in merged + + +def test_merge_results_error_handling(chat_service): + with pytest.raises(ExternalServiceException) as exc_info: + chat_service._merge_results(None, None) + assert "Failed to merge results" in str(exc_info.value) diff --git a/tests/services/test_document_service.py b/tests/services/test_document_service.py new file mode 100644 index 0000000..57356a9 --- /dev/null +++ b/tests/services/test_document_service.py @@ -0,0 +1,227 @@ +import pytest +from unittest.mock import patch, MagicMock, AsyncMock +import io +from fastapi import UploadFile, HTTPException +from src.services.document_service import DocumentService +from src.models.database import Document, DocumentStatus, User +from src.core.exceptions import ExternalServiceException + + +@pytest.fixture +def document_service( + db_session, mock_llm, mock_embeddings, mock_neo4j_driver, test_user +): + with ( + patch( + "src.services.document_service.ChatGoogleGenerativeAI", + return_value=mock_llm, + ), + patch( + "src.services.document_service.GoogleGenerativeAIEmbeddings", + return_value=mock_embeddings, + ), + patch( + "src.services.document_service.GraphDatabase.driver", + return_value=mock_neo4j_driver, + ), + patch("src.services.document_service.PGVector") as mock_pgvector, + patch("src.services.document_service.S3Service") as mock_s3, + ): + mock_pgvector.return_value = MagicMock() + mock_s3.return_value = MagicMock() + service = DocumentService(db_session, test_user) + return service + + +@pytest.fixture +def sample_pdf_file(): + return UploadFile( + filename="test.pdf", + file=io.BytesIO(b"PDF content"), + content_type="application/pdf", + ) + + +@pytest.fixture +def sample_docx_file(): + return UploadFile( + filename="test.docx", + file=io.BytesIO(b"DOCX content"), + content_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ) + + +@pytest.mark.asyncio +async def test_upload_documents_success(document_service, sample_pdf_file): + result = await document_service.upload_documents([sample_pdf_file]) + + assert len(result) == 1 + assert result[0]["title"] == "test.pdf" + assert result[0]["status"] == "success" + + # Verify document was created in database + doc = ( + document_service.db.query(Document).filter(Document.title == "test.pdf").first() + ) + assert doc is not None + assert doc.status == DocumentStatus.PROCESSING + + +@pytest.mark.asyncio +async def test_upload_documents_unsupported_type(document_service): + unsupported_file = UploadFile( + filename="test.xyz", file=io.BytesIO(b"content"), content_type="application/xyz" + ) + + with pytest.raises(HTTPException) as exc_info: + await document_service.upload_documents([unsupported_file]) + assert exc_info.value.status_code == 400 + assert "Unsupported file type" in str(exc_info.value.detail) + + +@pytest.mark.asyncio +async def test_upload_documents_storage_error(document_service, sample_pdf_file): + # Mock storage service to raise an error + document_service.storage_service.upload_file.side_effect = Exception( + "Storage error" + ) + + result = await document_service.upload_documents([sample_pdf_file]) + assert result[0]["status"] == "failed" + assert "Failed to upload document" in result[0]["message"] + + # Verify document status was updated to FAILED + doc = ( + document_service.db.query(Document).filter(Document.title == "test.pdf").first() + ) + assert doc.status == DocumentStatus.FAILED + + +@pytest.mark.asyncio +async def test_index_document_success(document_service, test_document): + # Mock successful document processing + document_service.storage_service.get_file.return_value = b"Document content" + document_service.text_splitter.split_text.return_value = ["chunk1", "chunk2"] + document_service.embeddings.embed_documents.return_value = [ + [0.1] * 768, + [0.2] * 768, + ] + + result = await document_service.index_document(test_document.doc_id) + + assert result["doc_id"] == test_document.doc_id + assert result["status"] == "INDEXED" + assert result["chunks_count"] == 2 + + # Verify document was updated in database + doc = ( + document_service.db.query(Document) + .filter(Document.doc_id == test_document.doc_id) + .first() + ) + assert doc.status == DocumentStatus.INDEXED + + +@pytest.mark.asyncio +async def test_index_document_already_indexed(document_service, test_document): + test_document.status = DocumentStatus.INDEXED + document_service.db.commit() + + result = await document_service.index_document(test_document.doc_id) + assert result["message"] == "Document already indexed" + + +@pytest.mark.asyncio +async def test_index_document_not_found(document_service): + with pytest.raises(HTTPException) as exc_info: + await document_service.index_document("nonexistent") + assert exc_info.value.status_code == 404 + + +@pytest.mark.asyncio +async def test_index_document_unauthorized(document_service, test_document): + # Change document user_id + test_document.user_id = 999 + document_service.db.commit() + + with pytest.raises(HTTPException) as exc_info: + await document_service.index_document(test_document.doc_id) + assert exc_info.value.status_code == 403 + + +@pytest.mark.asyncio +async def test_index_document_processing_error(document_service, test_document): + document_service.storage_service.get_file.side_effect = Exception( + "Processing error" + ) + + with pytest.raises(HTTPException) as exc_info: + await document_service.index_document(test_document.doc_id) + assert exc_info.value.status_code == 500 + + # Verify document status was updated to FAILED + doc = ( + document_service.db.query(Document) + .filter(Document.doc_id == test_document.doc_id) + .first() + ) + assert doc.status == DocumentStatus.FAILED + + +@pytest.mark.asyncio +async def test_list_documents(document_service, test_document): + documents = await document_service.list_documents() + assert len(documents) == 1 + assert documents[0].doc_id == test_document.doc_id + + +@pytest.mark.asyncio +async def test_list_documents_with_status(document_service, test_document): + documents = await document_service.list_documents(document_status="PENDING") + assert len(documents) == 1 + assert all(doc.status == DocumentStatus.PENDING for doc in documents) + + +@pytest.mark.asyncio +async def test_list_documents_invalid_status(document_service): + with pytest.raises(HTTPException) as exc_info: + await document_service.list_documents(document_status="INVALID") + assert exc_info.value.status_code == 400 + + +@pytest.mark.asyncio +async def test_get_document(document_service, test_document): + doc = await document_service.get_document(test_document.doc_id) + assert doc.doc_id == test_document.doc_id + + +@pytest.mark.asyncio +async def test_get_document_not_found(document_service): + with pytest.raises(HTTPException) as exc_info: + await document_service.get_document("nonexistent") + assert exc_info.value.status_code == 404 + + +def test_get_document_loader(document_service): + # Test PDF loader + loader = document_service._get_document_loader("test.pdf", "application/pdf") + assert "PyMuPDFLoader" in str(type(loader)) + + # Test DOCX loader + loader = document_service._get_document_loader( + "test.docx", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ) + assert "Docx2txtLoader" in str(type(loader)) + + # Test CSV loader + loader = document_service._get_document_loader("test.csv", "text/csv") + assert "CSVLoader" in str(type(loader)) + + # Test TXT loader + loader = document_service._get_document_loader("test.txt", "text/plain") + assert "TextLoader" in str(type(loader)) + + # Test fallback loader + loader = document_service._get_document_loader("test.other", "application/other") + assert "UnstructuredFileLoader" in str(type(loader)) diff --git a/tests/services/test_query_decomposition.py b/tests/services/test_query_decomposition.py new file mode 100644 index 0000000..d2d2ddd --- /dev/null +++ b/tests/services/test_query_decomposition.py @@ -0,0 +1,93 @@ +import pytest +from unittest.mock import patch, MagicMock +from src.services.chat.query_decomposition import QueryDecompositionService +from langchain.schema import HumanMessage, SystemMessage + + +def test_decompose_query_simple(mock_llm): + with patch( + "src.services.chat.query_decomposition.ChatGoogleGenerativeAI", + return_value=mock_llm, + ): + service = QueryDecompositionService() + mock_llm.invoke.return_value = MagicMock( + content="What is the capital of France?" + ) + + result = service.decompose_query("What is the capital of France?") + assert len(result) == 1 + assert result[0] == "What is the capital of France?" + + # Verify correct prompt construction + mock_llm.invoke.assert_called_once() + args = mock_llm.invoke.call_args[0][0] + assert len(args) == 2 + assert isinstance(args[0], SystemMessage) + assert isinstance(args[1], HumanMessage) + assert "query decomposition assistant" in args[0].content.lower() + + +def test_decompose_query_complex(mock_llm): + with patch( + "src.services.chat.query_decomposition.ChatGoogleGenerativeAI", + return_value=mock_llm, + ): + service = QueryDecompositionService() + mock_llm.invoke.return_value = MagicMock( + content="1. What are the main features of the product?\n2. How much does it cost?\n3. What are the customer reviews?" + ) + + result = service.decompose_query( + "Tell me about the product, its pricing and customer feedback" + ) + assert len(result) == 3 + assert "main features" in result[0] + assert "cost" in result[1] + assert "reviews" in result[2] + + +def test_decompose_query_error_handling(mock_llm): + with patch( + "src.services.chat.query_decomposition.ChatGoogleGenerativeAI", + return_value=mock_llm, + ): + service = QueryDecompositionService() + mock_llm.invoke.side_effect = Exception("API Error") + + query = "What is the meaning of life?" + result = service.decompose_query(query) + + # Should return original query on error + assert len(result) == 1 + assert result[0] == query + + +def test_decompose_query_empty_response(mock_llm): + with patch( + "src.services.chat.query_decomposition.ChatGoogleGenerativeAI", + return_value=mock_llm, + ): + service = QueryDecompositionService() + mock_llm.invoke.return_value = MagicMock(content="\n\n \n") + + query = "What is the weather?" + result = service.decompose_query(query) + + # Should return original query for empty response + assert len(result) == 1 + assert result[0] == query + + +def test_decompose_query_whitespace_handling(mock_llm): + with patch( + "src.services.chat.query_decomposition.ChatGoogleGenerativeAI", + return_value=mock_llm, + ): + service = QueryDecompositionService() + mock_llm.invoke.return_value = MagicMock( + content=" 1. First question \n\n2. Second question\n 3. Third question " + ) + + result = service.decompose_query("Complex multi-part question") + assert len(result) == 3 + assert all(not q.startswith(" ") and not q.endswith(" ") for q in result) diff --git a/tests/services/test_retrieval_evaluation.py b/tests/services/test_retrieval_evaluation.py new file mode 100644 index 0000000..568c6af --- /dev/null +++ b/tests/services/test_retrieval_evaluation.py @@ -0,0 +1,149 @@ +import pytest +from unittest.mock import patch, MagicMock +import json +from src.services.chat.retrieval_evaluation import RetrievalEvaluationService +from langchain.schema import HumanMessage, SystemMessage + + +@pytest.fixture +def sample_evaluation_response(): + return { + "chunk_scores": [ + { + "chunk": "text1", + "relevance_score": 8, + "reasoning": "Directly answers the query", + }, + {"chunk": "text2", "relevance_score": 5, "reasoning": "Partially relevant"}, + ], + "missing_aspects": ["technical details"], + "redundant_information": ["repeated context"], + "suggested_improvements": { + "additional_info_needed": ["implementation steps"], + "alternative_search_terms": [ + "technical specification", + "implementation guide", + ], + }, + "overall_quality_score": 7, + "quality_summary": "Good overall coverage but missing technical details", + } + + +def test_evaluate_retrieval_quality_success(mock_llm, sample_evaluation_response): + with patch( + "src.services.chat.retrieval_evaluation.ChatGoogleGenerativeAI", + return_value=mock_llm, + ): + service = RetrievalEvaluationService() + mock_llm.invoke.return_value = MagicMock( + content=json.dumps(sample_evaluation_response) + ) + + result = service.evaluate_retrieval_quality( + "How does the system work?", ["Text chunk 1", "Text chunk 2"] + ) + + assert result["overall_quality_score"] == 7 + assert len(result["chunk_scores"]) == 2 + assert "needs_improvement" in result + assert result["needs_improvement"] == False # Score >= 7 + + # Verify prompt construction + mock_llm.invoke.assert_called_once() + args = mock_llm.invoke.call_args[0][0] + assert len(args) == 2 + assert isinstance(args[0], SystemMessage) + assert isinstance(args[1], HumanMessage) + + +def test_evaluate_retrieval_quality_low_score(mock_llm): + with patch( + "src.services.chat.retrieval_evaluation.ChatGoogleGenerativeAI", + return_value=mock_llm, + ): + service = RetrievalEvaluationService() + low_score_response = { + "chunk_scores": [ + {"chunk": "text", "relevance_score": 3, "reasoning": "Not relevant"} + ], + "missing_aspects": ["important context"], + "overall_quality_score": 3, + } + mock_llm.invoke.return_value = MagicMock(content=json.dumps(low_score_response)) + + result = service.evaluate_retrieval_quality("Query", ["Text"]) + assert result["needs_improvement"] == True + assert result["overall_quality_score"] == 3 + + +def test_evaluate_retrieval_quality_error_handling(mock_llm): + with patch( + "src.services.chat.retrieval_evaluation.ChatGoogleGenerativeAI", + return_value=mock_llm, + ): + service = RetrievalEvaluationService() + mock_llm.invoke.side_effect = Exception("API Error") + + result = service.evaluate_retrieval_quality("Query", ["Text"]) + assert result["overall_quality_score"] == 0 + assert result["needs_improvement"] == True + + +def test_evaluate_retrieval_quality_invalid_json(mock_llm): + with patch( + "src.services.chat.retrieval_evaluation.ChatGoogleGenerativeAI", + return_value=mock_llm, + ): + service = RetrievalEvaluationService() + mock_llm.invoke.return_value = MagicMock(content="Invalid JSON") + + result = service.evaluate_retrieval_quality("Query", ["Text"]) + assert result["overall_quality_score"] == 0 + assert result["needs_improvement"] == True + + +def test_improve_retrieval_suggestions(mock_llm, sample_evaluation_response): + with patch( + "src.services.chat.retrieval_evaluation.ChatGoogleGenerativeAI", + return_value=mock_llm, + ): + service = RetrievalEvaluationService() + + # Test when improvement is needed + evaluation = { + **sample_evaluation_response, + "overall_quality_score": 5, + "needs_improvement": True, + } + + alternative_queries = service._improve_retrieval("Original query", evaluation) + assert len(alternative_queries) > 0 + assert "technical specification" in alternative_queries + assert "implementation guide" in alternative_queries + + +def test_improve_retrieval_no_improvement_needed(mock_llm): + with patch( + "src.services.chat.retrieval_evaluation.ChatGoogleGenerativeAI", + return_value=mock_llm, + ): + service = RetrievalEvaluationService() + + evaluation = {"overall_quality_score": 9, "needs_improvement": False} + + alternative_queries = service._improve_retrieval("Original query", evaluation) + assert len(alternative_queries) == 0 + + +def test_improve_retrieval_error_handling(mock_llm): + with patch( + "src.services.chat.retrieval_evaluation.ChatGoogleGenerativeAI", + return_value=mock_llm, + ): + service = RetrievalEvaluationService() + + # Test with invalid evaluation data + evaluation = {"invalid": "data"} + alternative_queries = service._improve_retrieval("Original query", evaluation) + assert len(alternative_queries) == 0 diff --git a/tests/services/test_s3_service.py b/tests/services/test_s3_service.py new file mode 100644 index 0000000..66c61e0 --- /dev/null +++ b/tests/services/test_s3_service.py @@ -0,0 +1,192 @@ +import pytest +from unittest.mock import patch, MagicMock +from datetime import datetime +from botocore.exceptions import ClientError +from fastapi import HTTPException +from src.services.s3_service import S3Service + + +@pytest.fixture +def s3_service(mock_s3_client): + with patch("src.services.s3_service.boto3.client", return_value=mock_s3_client): + service = S3Service() + service.s3_client = mock_s3_client + return service + + +def test_get_user_path(s3_service): + user_path = s3_service._get_user_path(123) + assert user_path == "users/123" + + +def test_upload_file_success(s3_service): + file_data = b"test content" + file_path = "documents/test.pdf" + content_type = "application/pdf" + user_id = 123 + + full_path = s3_service.upload_file( + user_id=user_id, + file_path=file_path, + file_data=file_data, + content_type=content_type, + ) + + s3_service.s3_client.put_object.assert_called_once_with( + Bucket=s3_service.bucket_name, + Key=f"users/{user_id}/{file_path}", + Body=file_data, + ContentType=content_type, + ) + assert full_path == f"users/{user_id}/{file_path}" + + +def test_upload_file_without_content_type(s3_service): + file_data = b"test content" + file_path = "documents/test.txt" + user_id = 123 + + s3_service.upload_file(user_id=user_id, file_path=file_path, file_data=file_data) + + s3_service.s3_client.put_object.assert_called_once_with( + Bucket=s3_service.bucket_name, + Key=f"users/{user_id}/{file_path}", + Body=file_data, + ) + + +def test_upload_file_error(s3_service): + s3_service.s3_client.put_object.side_effect = ClientError( + {"Error": {"Code": "InternalError", "Message": "S3 Error"}}, "PutObject" + ) + + with pytest.raises(HTTPException) as exc_info: + s3_service.upload_file(123, "test.txt", b"content") + assert exc_info.value.status_code == 500 + assert "Failed to upload file" in str(exc_info.value.detail) + + +def test_upload_files_batch_success(s3_service): + files = [ + { + "file_path": "test1.txt", + "file_data": b"content1", + "content_type": "text/plain", + }, + { + "file_path": "test2.pdf", + "file_data": b"content2", + "content_type": "application/pdf", + }, + ] + + results = s3_service.upload_files_batch(123, files) + + assert len(results) == 2 + assert all(result["status"] == "success" for result in results) + assert s3_service.s3_client.put_object.call_count == 2 + + +def test_upload_files_batch_partial_failure(s3_service): + def mock_upload(*args, **kwargs): + if "test1.txt" in kwargs.get("Key", ""): + raise ClientError( + {"Error": {"Code": "InternalError", "Message": "S3 Error"}}, "PutObject" + ) + + s3_service.s3_client.put_object.side_effect = mock_upload + + files = [ + {"file_path": "test1.txt", "file_data": b"content1"}, + {"file_path": "test2.txt", "file_data": b"content2"}, + ] + + results = s3_service.upload_files_batch(123, files) + + assert len(results) == 2 + assert any(result["status"] == "failed" for result in results) + assert any(result["status"] == "success" for result in results) + + +def test_get_file_success(s3_service): + s3_service.s3_client.get_object.return_value = { + "Body": MagicMock(read=lambda: b"file content") + } + + content = s3_service.get_file(123, "test.txt") + assert content == b"file content" + + s3_service.s3_client.get_object.assert_called_once_with( + Bucket=s3_service.bucket_name, Key="users/123/test.txt" + ) + + +def test_get_file_not_found(s3_service): + s3_service.s3_client.get_object.side_effect = ClientError( + {"Error": {"Code": "NoSuchKey", "Message": "Not found"}}, "GetObject" + ) + + with pytest.raises(HTTPException) as exc_info: + s3_service.get_file(123, "nonexistent.txt") + assert exc_info.value.status_code == 404 + assert "File not found" in str(exc_info.value.detail) + + +def test_get_file_error(s3_service): + s3_service.s3_client.get_object.side_effect = ClientError( + {"Error": {"Code": "InternalError", "Message": "S3 Error"}}, "GetObject" + ) + + with pytest.raises(HTTPException) as exc_info: + s3_service.get_file(123, "test.txt") + assert exc_info.value.status_code == 500 + assert "Failed to download file" in str(exc_info.value.detail) + + +def test_get_file_unauthorized_access(s3_service): + with pytest.raises(HTTPException) as exc_info: + s3_service.get_file(123, "test.txt", requesting_user_id=456) + assert exc_info.value.status_code == 403 + assert "Access denied" in str(exc_info.value.detail) + + +def test_list_user_files_success(s3_service): + mock_response = { + "Contents": [ + {"Key": "users/123/file1.txt", "Size": 100, "LastModified": datetime.now()}, + {"Key": "users/123/file2.pdf", "Size": 200, "LastModified": datetime.now()}, + ] + } + s3_service.s3_client.list_objects_v2.return_value = mock_response + + files = s3_service.list_user_files(123, requesting_user_id=123) + + assert len(files) == 2 + assert all("path" in f and "size" in f and "last_modified" in f for f in files) + assert any(f["path"] == "file1.txt" for f in files) + assert any(f["path"] == "file2.pdf" for f in files) + + +def test_list_user_files_empty(s3_service): + s3_service.s3_client.list_objects_v2.return_value = {} + + files = s3_service.list_user_files(123, requesting_user_id=123) + assert len(files) == 0 + + +def test_list_user_files_unauthorized(s3_service): + with pytest.raises(HTTPException) as exc_info: + s3_service.list_user_files(123, requesting_user_id=456) + assert exc_info.value.status_code == 403 + assert "Access denied" in str(exc_info.value.detail) + + +def test_list_user_files_error(s3_service): + s3_service.s3_client.list_objects_v2.side_effect = ClientError( + {"Error": {"Code": "InternalError", "Message": "S3 Error"}}, "ListObjectsV2" + ) + + with pytest.raises(HTTPException) as exc_info: + s3_service.list_user_files(123, requesting_user_id=123) + assert exc_info.value.status_code == 500 + assert "Failed to list files" in str(exc_info.value.detail) diff --git a/tests/services/test_session_service.py b/tests/services/test_session_service.py new file mode 100644 index 0000000..0eab3ec --- /dev/null +++ b/tests/services/test_session_service.py @@ -0,0 +1,148 @@ +import pytest +from datetime import datetime, timezone +from src.services.session_service import SessionService +from src.models.database import ChatSession, Message + + +@pytest.fixture +def session_service(db_session): + return SessionService(db_session) + + +@pytest.mark.asyncio +async def test_create_session_with_title(session_service, test_user): + title = "Test Chat Session" + session = await session_service.create_session(test_user.id, title) + + assert session.id is not None + assert session.user_id == test_user.id + assert session.title == title + + # Verify session was saved to database + db_session = ( + session_service.db.query(ChatSession) + .filter(ChatSession.id == session.id) + .first() + ) + assert db_session is not None + assert db_session.title == title + + +@pytest.mark.asyncio +async def test_create_session_without_title(session_service, test_user): + session = await session_service.create_session(test_user.id) + + assert session.id is not None + assert session.user_id == test_user.id + assert "Chat" in session.title + assert datetime.now(timezone.utc).strftime("%Y-%m-%d") in session.title + + +@pytest.mark.asyncio +async def test_get_user_sessions(session_service, test_user, test_chat_session): + # Create additional session + await session_service.create_session(test_user.id, "Second Session") + + sessions = await session_service.get_user_sessions(test_user.id) + assert len(sessions) == 2 + assert all(s.user_id == test_user.id for s in sessions) + + +@pytest.mark.asyncio +async def test_get_session(session_service, test_user, test_chat_session): + session = await session_service.get_session(test_chat_session.id, test_user.id) + + assert session is not None + assert session.id == test_chat_session.id + assert session.user_id == test_user.id + + +@pytest.mark.asyncio +async def test_get_session_not_found(session_service, test_user): + session = await session_service.get_session("nonexistent", test_user.id) + assert session is None + + +@pytest.mark.asyncio +async def test_get_session_wrong_user(session_service, test_chat_session): + wrong_user_id = test_chat_session.user_id + 1 + session = await session_service.get_session(test_chat_session.id, wrong_user_id) + assert session is None + + +@pytest.mark.asyncio +async def test_add_message(session_service, test_chat_session): + message = await session_service.add_message( + session_id=test_chat_session.id, + sender="user", + content="Test message", + context_used={"test": "context"}, + ) + + assert message.chat_session_id == test_chat_session.id + assert message.sender == "user" + assert message.content == "Test message" + assert message.context_used == {"test": "context"} + + # Verify message was saved to database + db_message = ( + session_service.db.query(Message).filter(Message.id == message.id).first() + ) + assert db_message is not None + assert db_message.content == "Test message" + + +@pytest.mark.asyncio +async def test_add_message_without_context(session_service, test_chat_session): + message = await session_service.add_message( + session_id=test_chat_session.id, sender="user", content="Test message" + ) + + assert message.context_used == {} + + +@pytest.mark.asyncio +async def test_delete_session(session_service, test_user, test_chat_session): + # Add a message to the session + await session_service.add_message(test_chat_session.id, "user", "Test message") + + await session_service.delete_session(test_chat_session.id, test_user.id) + + # Verify session and messages were deleted + session = ( + session_service.db.query(ChatSession) + .filter(ChatSession.id == test_chat_session.id) + .first() + ) + assert session is None + + messages = ( + session_service.db.query(Message) + .filter(Message.chat_session_id == test_chat_session.id) + .all() + ) + assert len(messages) == 0 + + +@pytest.mark.asyncio +async def test_delete_session_not_found(session_service, test_user): + # Should not raise any exception + await session_service.delete_session("nonexistent", test_user.id) + + +@pytest.mark.asyncio +async def test_get_session_messages(session_service, test_chat_session): + # Add multiple messages + message1 = await session_service.add_message( + test_chat_session.id, "user", "Message 1" + ) + message2 = await session_service.add_message( + test_chat_session.id, "assistant", "Message 2" + ) + + messages = await session_service.get_session_messages(test_chat_session.id) + + assert len(messages) == 2 + assert messages[0].content == "Message 1" + assert messages[1].content == "Message 2" + assert messages[0].created_at <= messages[1].created_at From 8e7da622e356b53c10c1034172890d05c9a13ab6 Mon Sep 17 00:00:00 2001 From: capybara-brain346 Date: Tue, 22 Jul 2025 08:50:14 +0530 Subject: [PATCH 13/18] remove tests --- tests/conftest.py | 86 -------- tests/services/test_auth_service.py | 112 ---------- tests/services/test_base_client.py | 38 ---- tests/services/test_chat_service.py | 206 ------------------ tests/services/test_document_service.py | 227 -------------------- tests/services/test_query_decomposition.py | 93 -------- tests/services/test_retrieval_evaluation.py | 149 ------------- tests/services/test_s3_service.py | 192 ----------------- tests/services/test_session_service.py | 148 ------------- 9 files changed, 1251 deletions(-) delete mode 100644 tests/conftest.py delete mode 100644 tests/services/test_auth_service.py delete mode 100644 tests/services/test_base_client.py delete mode 100644 tests/services/test_chat_service.py delete mode 100644 tests/services/test_document_service.py delete mode 100644 tests/services/test_query_decomposition.py delete mode 100644 tests/services/test_retrieval_evaluation.py delete mode 100644 tests/services/test_s3_service.py delete mode 100644 tests/services/test_session_service.py diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index eb0a950..0000000 --- a/tests/conftest.py +++ /dev/null @@ -1,86 +0,0 @@ -import pytest -from unittest.mock import MagicMock -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, Session -from neo4j import GraphDatabase -from src.core.config import settings -from src.core.database import Base -from src.models.database import User, Document, DocumentChunk, ChatSession, Message - - -@pytest.fixture -def db_engine(): - engine = create_engine("sqlite:///:memory:") - Base.metadata.create_all(engine) - return engine - - -@pytest.fixture -def db_session(db_engine): - Session = sessionmaker(bind=db_engine) - session = Session() - yield session - session.close() - - -@pytest.fixture -def mock_s3_client(): - return MagicMock() - - -@pytest.fixture -def mock_neo4j_driver(): - driver = MagicMock(spec=GraphDatabase.driver) - session = MagicMock() - driver.session.return_value.__enter__.return_value = session - driver.session.return_value.__exit__.return_value = None - return driver - - -@pytest.fixture -def mock_llm(): - llm = MagicMock() - llm.invoke.return_value = MagicMock(content="Test response") - return llm - - -@pytest.fixture -def mock_embeddings(): - embeddings = MagicMock() - embeddings.embed_documents.return_value = [[0.1] * 768] - embeddings.embed_query.return_value = [0.1] * 768 - return embeddings - - -@pytest.fixture -def test_user(db_session): - user = User( - username="testuser", email="test@example.com", hashed_password="hashedpassword" - ) - db_session.add(user) - db_session.commit() - return user - - -@pytest.fixture -def test_document(db_session, test_user): - doc = Document( - doc_id="test_doc_id", - title="Test Document", - content_type="application/pdf", - status="PENDING", - user_id=test_user.id, - ) - db_session.add(doc) - db_session.commit() - return doc - - -@pytest.fixture -def test_chat_session(db_session, test_user): - session = ChatSession( - id="test_session_id", user_id=test_user.id, title="Test Session" - ) - db_session.add(session) - db_session.commit() - return session diff --git a/tests/services/test_auth_service.py b/tests/services/test_auth_service.py deleted file mode 100644 index d4b0444..0000000 --- a/tests/services/test_auth_service.py +++ /dev/null @@ -1,112 +0,0 @@ -import pytest -from datetime import datetime, timedelta -from fastapi import HTTPException -from jose import jwt -from src.services.auth_service import AuthService -from src.core.config import settings -from src.models.database import User - - -def test_verify_password(db_session): - auth_service = AuthService(db_session) - hashed_password = auth_service.get_password_hash("testpassword") - assert auth_service.verify_password("testpassword", hashed_password) - assert not auth_service.verify_password("wrongpassword", hashed_password) - - -def test_get_user_by_username(db_session, test_user): - auth_service = AuthService(db_session) - user = auth_service.get_user_by_username("testuser") - assert user is not None - assert user.username == "testuser" - assert auth_service.get_user_by_username("nonexistent") is None - - -def test_get_user_by_email(db_session, test_user): - auth_service = AuthService(db_session) - user = auth_service.get_user_by_email("test@example.com") - assert user is not None - assert user.email == "test@example.com" - assert auth_service.get_user_by_email("nonexistent@example.com") is None - - -def test_create_user_success(db_session): - auth_service = AuthService(db_session) - user = auth_service.create_user( - username="newuser", email="new@example.com", password="password123" - ) - assert user.username == "newuser" - assert user.email == "new@example.com" - assert auth_service.verify_password("password123", user.hashed_password) - - -def test_create_user_duplicate_username(db_session, test_user): - auth_service = AuthService(db_session) - with pytest.raises(HTTPException) as exc_info: - auth_service.create_user( - username="testuser", email="another@example.com", password="password123" - ) - assert exc_info.value.status_code == 400 - assert "Username already registered" in str(exc_info.value.detail) - - -def test_create_user_duplicate_email(db_session, test_user): - auth_service = AuthService(db_session) - with pytest.raises(HTTPException) as exc_info: - auth_service.create_user( - username="anotheruser", email="test@example.com", password="password123" - ) - assert exc_info.value.status_code == 400 - assert "Email already registered" in str(exc_info.value.detail) - - -def test_authenticate_user_success(db_session): - auth_service = AuthService(db_session) - user = auth_service.create_user( - username="authuser", email="auth@example.com", password="password123" - ) - authenticated_user = auth_service.authenticate_user( - "auth@example.com", "password123" - ) - assert authenticated_user is not None - assert authenticated_user.id == user.id - - -def test_authenticate_user_failure(db_session, test_user): - auth_service = AuthService(db_session) - assert auth_service.authenticate_user("test@example.com", "wrongpassword") is None - assert ( - auth_service.authenticate_user("nonexistent@example.com", "password123") is None - ) - - -def test_create_access_token(db_session): - auth_service = AuthService(db_session) - data = {"sub": "1"} - expires_delta = timedelta(minutes=15) - token = auth_service.create_access_token(data, expires_delta) - - decoded = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"]) - assert decoded["sub"] == "1" - assert "exp" in decoded - - -def test_get_current_user_success(db_session, test_user): - auth_service = AuthService(db_session) - token = auth_service.create_access_token({"sub": str(test_user.id)}) - user = auth_service.get_current_user(token) - assert user.id == test_user.id - - -def test_get_current_user_invalid_token(db_session): - auth_service = AuthService(db_session) - with pytest.raises(HTTPException) as exc_info: - auth_service.get_current_user("invalid_token") - assert exc_info.value.status_code == 401 - assert "Could not validate credentials" in str(exc_info.value.detail) - - -def test_get_user_s3_prefix(db_session, test_user): - auth_service = AuthService(db_session) - prefix = auth_service.get_user_s3_prefix(test_user.id) - assert prefix == f"user_{test_user.id}/" diff --git a/tests/services/test_base_client.py b/tests/services/test_base_client.py deleted file mode 100644 index 1719aba..0000000 --- a/tests/services/test_base_client.py +++ /dev/null @@ -1,38 +0,0 @@ -import pytest -from unittest.mock import patch, MagicMock -from src.services.base_client import BaseLLMClient -from src.core.exceptions import ExternalServiceException - - -def test_base_client_initialization_success(mock_llm): - with patch( - "src.services.base_client.ChatGoogleGenerativeAI", return_value=mock_llm - ): - client = BaseLLMClient("TestService") - assert client.service_name == "TestService" - assert client.llm == mock_llm - - -def test_base_client_initialization_failure(): - with patch( - "src.services.base_client.ChatGoogleGenerativeAI", - side_effect=Exception("API Error"), - ): - with pytest.raises(ExternalServiceException) as exc_info: - BaseLLMClient("TestService") - assert "Failed to initialize testservice" in str(exc_info.value) - assert exc_info.value.service_name == "TestService" - - -def test_base_client_llm_invocation(mock_llm): - with patch( - "src.services.base_client.ChatGoogleGenerativeAI", return_value=mock_llm - ): - client = BaseLLMClient("TestService") - mock_llm.invoke.return_value = MagicMock(content="Test response") - - messages = [{"role": "user", "content": "Test message"}] - response = client.llm.invoke(messages) - - assert response.content == "Test response" - mock_llm.invoke.assert_called_once_with(messages) diff --git a/tests/services/test_chat_service.py b/tests/services/test_chat_service.py deleted file mode 100644 index 989ef79..0000000 --- a/tests/services/test_chat_service.py +++ /dev/null @@ -1,206 +0,0 @@ -import pytest -from unittest.mock import patch, MagicMock, AsyncMock -from datetime import datetime, timezone -from fastapi import HTTPException -from src.services.chat.chat_service import ChatService -from src.models.database import ChatSession, Message, Document, DocumentStatus -from src.models.request import FollowUpChatRequest -from src.core.exceptions import ExternalServiceException - - -@pytest.fixture -def chat_service(db_session, mock_llm, mock_embeddings, mock_neo4j_driver): - with ( - patch( - "src.services.chat.chat_service.ChatGoogleGenerativeAI", - return_value=mock_llm, - ), - patch( - "src.services.chat.chat_service.GoogleGenerativeAIEmbeddings", - return_value=mock_embeddings, - ), - patch( - "src.services.chat.chat_service.GraphDatabase.driver", - return_value=mock_neo4j_driver, - ), - patch("src.services.chat.chat_service.PGVector") as mock_pgvector, - ): - mock_pgvector.return_value = MagicMock() - service = ChatService(db_session) - return service - - -@pytest.mark.asyncio -async def test_process_query_simple(chat_service, mock_llm): - query = "What is the meaning of life?" - current_user_id = 1 - - chat_service.vector_store.similarity_search_with_score_by_vector.return_value = [ - (MagicMock(page_content="Life is 42"), 0.9) - ] - chat_service.graph_service.query_graph.return_value = [ - {"type": "Concept", "properties": {"name": "Life", "content": "Meaning"}} - ] - - mock_llm.invoke.return_value = MagicMock(content="The meaning of life is 42") - - result = await chat_service.process_query(query, "session_1", current_user_id) - - assert "message" in result - assert result["message"] == "The meaning of life is 42" - assert "context_used" in result - assert len(result["context_used"]["vector_results"]) > 0 - assert len(result["context_used"]["graph_results"]) > 0 - - -@pytest.mark.asyncio -async def test_process_query_with_decomposition(chat_service, mock_llm): - query = "Tell me about the system architecture and performance" - current_user_id = 1 - - # Mock query decomposition - chat_service._query_decomposition_service = MagicMock() - chat_service._query_decomposition_service.decompose_query.return_value = [ - "What is the system architecture?", - "How does the system perform?", - ] - - # Mock vector and graph results for both sub-queries - chat_service.vector_store.similarity_search_with_score_by_vector.return_value = [ - (MagicMock(page_content="Architecture details"), 0.9) - ] - chat_service.graph_service.query_graph.return_value = [ - {"type": "Document", "properties": {"content": "Performance metrics"}} - ] - - mock_llm.invoke.return_value = MagicMock( - content="Combined response about architecture and performance" - ) - - result = await chat_service.process_query( - query, "session_1", current_user_id, use_query_decomposition=True - ) - - assert result["message"] == "Combined response about architecture and performance" - assert "context_used" in result - assert "sub_responses" in result["context_used"] - - -@pytest.mark.asyncio -async def test_process_query_with_retrieval_evaluation(chat_service, mock_llm): - query = "How does authentication work?" - current_user_id = 1 - - chat_service.vector_store.similarity_search_with_score_by_vector.return_value = [ - (MagicMock(page_content="Auth process"), 0.8) - ] - - chat_service._retrieval_evaluation_service = MagicMock() - chat_service._retrieval_evaluation_service.evaluate_retrieval_quality.return_value = { - "overall_quality_score": 6, - "needs_improvement": True, - "suggested_improvements": { - "alternative_search_terms": ["user authentication", "login process"] - }, - } - - result = await chat_service.process_query( - query, "session_1", current_user_id, use_retrieval_evaluation=True - ) - - assert "message" in result - assert "context_used" in result - - -@pytest.mark.asyncio -async def test_follow_up_chat(chat_service, test_user, test_chat_session): - request = FollowUpChatRequest( - message="Follow up question", - referenced_node_ids=["node1", "node2"], - context_window=2, - ) - - mock_session = chat_service.driver.session.return_value.__enter__.return_value - mock_session.run.return_value = [ - {"related": {"id": "node1", "content": "Context 1"}}, - {"related": {"id": "node2", "content": "Context 2"}}, - ] - - response = await chat_service.follow_up_chat( - test_chat_session.id, request, test_user.id - ) - - assert response.response == "Placeholder response" - assert len(response.context_nodes) == 2 - assert response.referenced_entities == ["node1", "node2"] - - -def test_rename_chat_session(chat_service, test_user, test_chat_session): - new_title = "New Session Title" - result = chat_service.rename_chat_session( - test_chat_session.id, new_title, test_user.id - ) - - assert result["session_id"] == test_chat_session.id - assert result["title"] == new_title - - -def test_rename_chat_session_not_found(chat_service, test_user): - with pytest.raises(HTTPException) as exc_info: - chat_service.rename_chat_session("nonexistent", "New Title", test_user.id) - assert exc_info.value.status_code == 404 - - -def test_rename_chat_session_unauthorized(chat_service, test_chat_session): - with pytest.raises(HTTPException) as exc_info: - chat_service.rename_chat_session(test_chat_session.id, "New Title", 999) - assert exc_info.value.status_code == 403 - - -def test_delete_chat_session(chat_service, test_user, test_chat_session): - result = chat_service.delete_chat_session(test_chat_session.id, test_user.id) - - assert result["session_id"] == test_chat_session.id - assert result["status"] == "deleted" - - session = ( - chat_service.db.query(ChatSession) - .filter(ChatSession.id == test_chat_session.id) - .first() - ) - assert session is None - - -def test_delete_chat_session_not_found(chat_service, test_user): - with pytest.raises(HTTPException) as exc_info: - chat_service.delete_chat_session("nonexistent", test_user.id) - assert exc_info.value.status_code == 404 - - -def test_delete_chat_session_unauthorized(chat_service, test_chat_session): - with pytest.raises(HTTPException) as exc_info: - chat_service.delete_chat_session(test_chat_session.id, 999) - assert exc_info.value.status_code == 403 - - -def test_merge_results(chat_service): - vector_results = ["Vector result 1", "Vector result 2"] - graph_results = [ - { - "type": "Concept", - "properties": {"name": "Test", "content": "Content"}, - "relationships": [{"type": "RELATED_TO"}], - } - ] - - merged = chat_service._merge_results(vector_results, graph_results) - assert isinstance(merged, str) - assert "Vector result" in merged - assert "Type: Concept" in merged - assert "RELATED_TO" in merged - - -def test_merge_results_error_handling(chat_service): - with pytest.raises(ExternalServiceException) as exc_info: - chat_service._merge_results(None, None) - assert "Failed to merge results" in str(exc_info.value) diff --git a/tests/services/test_document_service.py b/tests/services/test_document_service.py deleted file mode 100644 index 57356a9..0000000 --- a/tests/services/test_document_service.py +++ /dev/null @@ -1,227 +0,0 @@ -import pytest -from unittest.mock import patch, MagicMock, AsyncMock -import io -from fastapi import UploadFile, HTTPException -from src.services.document_service import DocumentService -from src.models.database import Document, DocumentStatus, User -from src.core.exceptions import ExternalServiceException - - -@pytest.fixture -def document_service( - db_session, mock_llm, mock_embeddings, mock_neo4j_driver, test_user -): - with ( - patch( - "src.services.document_service.ChatGoogleGenerativeAI", - return_value=mock_llm, - ), - patch( - "src.services.document_service.GoogleGenerativeAIEmbeddings", - return_value=mock_embeddings, - ), - patch( - "src.services.document_service.GraphDatabase.driver", - return_value=mock_neo4j_driver, - ), - patch("src.services.document_service.PGVector") as mock_pgvector, - patch("src.services.document_service.S3Service") as mock_s3, - ): - mock_pgvector.return_value = MagicMock() - mock_s3.return_value = MagicMock() - service = DocumentService(db_session, test_user) - return service - - -@pytest.fixture -def sample_pdf_file(): - return UploadFile( - filename="test.pdf", - file=io.BytesIO(b"PDF content"), - content_type="application/pdf", - ) - - -@pytest.fixture -def sample_docx_file(): - return UploadFile( - filename="test.docx", - file=io.BytesIO(b"DOCX content"), - content_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document", - ) - - -@pytest.mark.asyncio -async def test_upload_documents_success(document_service, sample_pdf_file): - result = await document_service.upload_documents([sample_pdf_file]) - - assert len(result) == 1 - assert result[0]["title"] == "test.pdf" - assert result[0]["status"] == "success" - - # Verify document was created in database - doc = ( - document_service.db.query(Document).filter(Document.title == "test.pdf").first() - ) - assert doc is not None - assert doc.status == DocumentStatus.PROCESSING - - -@pytest.mark.asyncio -async def test_upload_documents_unsupported_type(document_service): - unsupported_file = UploadFile( - filename="test.xyz", file=io.BytesIO(b"content"), content_type="application/xyz" - ) - - with pytest.raises(HTTPException) as exc_info: - await document_service.upload_documents([unsupported_file]) - assert exc_info.value.status_code == 400 - assert "Unsupported file type" in str(exc_info.value.detail) - - -@pytest.mark.asyncio -async def test_upload_documents_storage_error(document_service, sample_pdf_file): - # Mock storage service to raise an error - document_service.storage_service.upload_file.side_effect = Exception( - "Storage error" - ) - - result = await document_service.upload_documents([sample_pdf_file]) - assert result[0]["status"] == "failed" - assert "Failed to upload document" in result[0]["message"] - - # Verify document status was updated to FAILED - doc = ( - document_service.db.query(Document).filter(Document.title == "test.pdf").first() - ) - assert doc.status == DocumentStatus.FAILED - - -@pytest.mark.asyncio -async def test_index_document_success(document_service, test_document): - # Mock successful document processing - document_service.storage_service.get_file.return_value = b"Document content" - document_service.text_splitter.split_text.return_value = ["chunk1", "chunk2"] - document_service.embeddings.embed_documents.return_value = [ - [0.1] * 768, - [0.2] * 768, - ] - - result = await document_service.index_document(test_document.doc_id) - - assert result["doc_id"] == test_document.doc_id - assert result["status"] == "INDEXED" - assert result["chunks_count"] == 2 - - # Verify document was updated in database - doc = ( - document_service.db.query(Document) - .filter(Document.doc_id == test_document.doc_id) - .first() - ) - assert doc.status == DocumentStatus.INDEXED - - -@pytest.mark.asyncio -async def test_index_document_already_indexed(document_service, test_document): - test_document.status = DocumentStatus.INDEXED - document_service.db.commit() - - result = await document_service.index_document(test_document.doc_id) - assert result["message"] == "Document already indexed" - - -@pytest.mark.asyncio -async def test_index_document_not_found(document_service): - with pytest.raises(HTTPException) as exc_info: - await document_service.index_document("nonexistent") - assert exc_info.value.status_code == 404 - - -@pytest.mark.asyncio -async def test_index_document_unauthorized(document_service, test_document): - # Change document user_id - test_document.user_id = 999 - document_service.db.commit() - - with pytest.raises(HTTPException) as exc_info: - await document_service.index_document(test_document.doc_id) - assert exc_info.value.status_code == 403 - - -@pytest.mark.asyncio -async def test_index_document_processing_error(document_service, test_document): - document_service.storage_service.get_file.side_effect = Exception( - "Processing error" - ) - - with pytest.raises(HTTPException) as exc_info: - await document_service.index_document(test_document.doc_id) - assert exc_info.value.status_code == 500 - - # Verify document status was updated to FAILED - doc = ( - document_service.db.query(Document) - .filter(Document.doc_id == test_document.doc_id) - .first() - ) - assert doc.status == DocumentStatus.FAILED - - -@pytest.mark.asyncio -async def test_list_documents(document_service, test_document): - documents = await document_service.list_documents() - assert len(documents) == 1 - assert documents[0].doc_id == test_document.doc_id - - -@pytest.mark.asyncio -async def test_list_documents_with_status(document_service, test_document): - documents = await document_service.list_documents(document_status="PENDING") - assert len(documents) == 1 - assert all(doc.status == DocumentStatus.PENDING for doc in documents) - - -@pytest.mark.asyncio -async def test_list_documents_invalid_status(document_service): - with pytest.raises(HTTPException) as exc_info: - await document_service.list_documents(document_status="INVALID") - assert exc_info.value.status_code == 400 - - -@pytest.mark.asyncio -async def test_get_document(document_service, test_document): - doc = await document_service.get_document(test_document.doc_id) - assert doc.doc_id == test_document.doc_id - - -@pytest.mark.asyncio -async def test_get_document_not_found(document_service): - with pytest.raises(HTTPException) as exc_info: - await document_service.get_document("nonexistent") - assert exc_info.value.status_code == 404 - - -def test_get_document_loader(document_service): - # Test PDF loader - loader = document_service._get_document_loader("test.pdf", "application/pdf") - assert "PyMuPDFLoader" in str(type(loader)) - - # Test DOCX loader - loader = document_service._get_document_loader( - "test.docx", - "application/vnd.openxmlformats-officedocument.wordprocessingml.document", - ) - assert "Docx2txtLoader" in str(type(loader)) - - # Test CSV loader - loader = document_service._get_document_loader("test.csv", "text/csv") - assert "CSVLoader" in str(type(loader)) - - # Test TXT loader - loader = document_service._get_document_loader("test.txt", "text/plain") - assert "TextLoader" in str(type(loader)) - - # Test fallback loader - loader = document_service._get_document_loader("test.other", "application/other") - assert "UnstructuredFileLoader" in str(type(loader)) diff --git a/tests/services/test_query_decomposition.py b/tests/services/test_query_decomposition.py deleted file mode 100644 index d2d2ddd..0000000 --- a/tests/services/test_query_decomposition.py +++ /dev/null @@ -1,93 +0,0 @@ -import pytest -from unittest.mock import patch, MagicMock -from src.services.chat.query_decomposition import QueryDecompositionService -from langchain.schema import HumanMessage, SystemMessage - - -def test_decompose_query_simple(mock_llm): - with patch( - "src.services.chat.query_decomposition.ChatGoogleGenerativeAI", - return_value=mock_llm, - ): - service = QueryDecompositionService() - mock_llm.invoke.return_value = MagicMock( - content="What is the capital of France?" - ) - - result = service.decompose_query("What is the capital of France?") - assert len(result) == 1 - assert result[0] == "What is the capital of France?" - - # Verify correct prompt construction - mock_llm.invoke.assert_called_once() - args = mock_llm.invoke.call_args[0][0] - assert len(args) == 2 - assert isinstance(args[0], SystemMessage) - assert isinstance(args[1], HumanMessage) - assert "query decomposition assistant" in args[0].content.lower() - - -def test_decompose_query_complex(mock_llm): - with patch( - "src.services.chat.query_decomposition.ChatGoogleGenerativeAI", - return_value=mock_llm, - ): - service = QueryDecompositionService() - mock_llm.invoke.return_value = MagicMock( - content="1. What are the main features of the product?\n2. How much does it cost?\n3. What are the customer reviews?" - ) - - result = service.decompose_query( - "Tell me about the product, its pricing and customer feedback" - ) - assert len(result) == 3 - assert "main features" in result[0] - assert "cost" in result[1] - assert "reviews" in result[2] - - -def test_decompose_query_error_handling(mock_llm): - with patch( - "src.services.chat.query_decomposition.ChatGoogleGenerativeAI", - return_value=mock_llm, - ): - service = QueryDecompositionService() - mock_llm.invoke.side_effect = Exception("API Error") - - query = "What is the meaning of life?" - result = service.decompose_query(query) - - # Should return original query on error - assert len(result) == 1 - assert result[0] == query - - -def test_decompose_query_empty_response(mock_llm): - with patch( - "src.services.chat.query_decomposition.ChatGoogleGenerativeAI", - return_value=mock_llm, - ): - service = QueryDecompositionService() - mock_llm.invoke.return_value = MagicMock(content="\n\n \n") - - query = "What is the weather?" - result = service.decompose_query(query) - - # Should return original query for empty response - assert len(result) == 1 - assert result[0] == query - - -def test_decompose_query_whitespace_handling(mock_llm): - with patch( - "src.services.chat.query_decomposition.ChatGoogleGenerativeAI", - return_value=mock_llm, - ): - service = QueryDecompositionService() - mock_llm.invoke.return_value = MagicMock( - content=" 1. First question \n\n2. Second question\n 3. Third question " - ) - - result = service.decompose_query("Complex multi-part question") - assert len(result) == 3 - assert all(not q.startswith(" ") and not q.endswith(" ") for q in result) diff --git a/tests/services/test_retrieval_evaluation.py b/tests/services/test_retrieval_evaluation.py deleted file mode 100644 index 568c6af..0000000 --- a/tests/services/test_retrieval_evaluation.py +++ /dev/null @@ -1,149 +0,0 @@ -import pytest -from unittest.mock import patch, MagicMock -import json -from src.services.chat.retrieval_evaluation import RetrievalEvaluationService -from langchain.schema import HumanMessage, SystemMessage - - -@pytest.fixture -def sample_evaluation_response(): - return { - "chunk_scores": [ - { - "chunk": "text1", - "relevance_score": 8, - "reasoning": "Directly answers the query", - }, - {"chunk": "text2", "relevance_score": 5, "reasoning": "Partially relevant"}, - ], - "missing_aspects": ["technical details"], - "redundant_information": ["repeated context"], - "suggested_improvements": { - "additional_info_needed": ["implementation steps"], - "alternative_search_terms": [ - "technical specification", - "implementation guide", - ], - }, - "overall_quality_score": 7, - "quality_summary": "Good overall coverage but missing technical details", - } - - -def test_evaluate_retrieval_quality_success(mock_llm, sample_evaluation_response): - with patch( - "src.services.chat.retrieval_evaluation.ChatGoogleGenerativeAI", - return_value=mock_llm, - ): - service = RetrievalEvaluationService() - mock_llm.invoke.return_value = MagicMock( - content=json.dumps(sample_evaluation_response) - ) - - result = service.evaluate_retrieval_quality( - "How does the system work?", ["Text chunk 1", "Text chunk 2"] - ) - - assert result["overall_quality_score"] == 7 - assert len(result["chunk_scores"]) == 2 - assert "needs_improvement" in result - assert result["needs_improvement"] == False # Score >= 7 - - # Verify prompt construction - mock_llm.invoke.assert_called_once() - args = mock_llm.invoke.call_args[0][0] - assert len(args) == 2 - assert isinstance(args[0], SystemMessage) - assert isinstance(args[1], HumanMessage) - - -def test_evaluate_retrieval_quality_low_score(mock_llm): - with patch( - "src.services.chat.retrieval_evaluation.ChatGoogleGenerativeAI", - return_value=mock_llm, - ): - service = RetrievalEvaluationService() - low_score_response = { - "chunk_scores": [ - {"chunk": "text", "relevance_score": 3, "reasoning": "Not relevant"} - ], - "missing_aspects": ["important context"], - "overall_quality_score": 3, - } - mock_llm.invoke.return_value = MagicMock(content=json.dumps(low_score_response)) - - result = service.evaluate_retrieval_quality("Query", ["Text"]) - assert result["needs_improvement"] == True - assert result["overall_quality_score"] == 3 - - -def test_evaluate_retrieval_quality_error_handling(mock_llm): - with patch( - "src.services.chat.retrieval_evaluation.ChatGoogleGenerativeAI", - return_value=mock_llm, - ): - service = RetrievalEvaluationService() - mock_llm.invoke.side_effect = Exception("API Error") - - result = service.evaluate_retrieval_quality("Query", ["Text"]) - assert result["overall_quality_score"] == 0 - assert result["needs_improvement"] == True - - -def test_evaluate_retrieval_quality_invalid_json(mock_llm): - with patch( - "src.services.chat.retrieval_evaluation.ChatGoogleGenerativeAI", - return_value=mock_llm, - ): - service = RetrievalEvaluationService() - mock_llm.invoke.return_value = MagicMock(content="Invalid JSON") - - result = service.evaluate_retrieval_quality("Query", ["Text"]) - assert result["overall_quality_score"] == 0 - assert result["needs_improvement"] == True - - -def test_improve_retrieval_suggestions(mock_llm, sample_evaluation_response): - with patch( - "src.services.chat.retrieval_evaluation.ChatGoogleGenerativeAI", - return_value=mock_llm, - ): - service = RetrievalEvaluationService() - - # Test when improvement is needed - evaluation = { - **sample_evaluation_response, - "overall_quality_score": 5, - "needs_improvement": True, - } - - alternative_queries = service._improve_retrieval("Original query", evaluation) - assert len(alternative_queries) > 0 - assert "technical specification" in alternative_queries - assert "implementation guide" in alternative_queries - - -def test_improve_retrieval_no_improvement_needed(mock_llm): - with patch( - "src.services.chat.retrieval_evaluation.ChatGoogleGenerativeAI", - return_value=mock_llm, - ): - service = RetrievalEvaluationService() - - evaluation = {"overall_quality_score": 9, "needs_improvement": False} - - alternative_queries = service._improve_retrieval("Original query", evaluation) - assert len(alternative_queries) == 0 - - -def test_improve_retrieval_error_handling(mock_llm): - with patch( - "src.services.chat.retrieval_evaluation.ChatGoogleGenerativeAI", - return_value=mock_llm, - ): - service = RetrievalEvaluationService() - - # Test with invalid evaluation data - evaluation = {"invalid": "data"} - alternative_queries = service._improve_retrieval("Original query", evaluation) - assert len(alternative_queries) == 0 diff --git a/tests/services/test_s3_service.py b/tests/services/test_s3_service.py deleted file mode 100644 index 66c61e0..0000000 --- a/tests/services/test_s3_service.py +++ /dev/null @@ -1,192 +0,0 @@ -import pytest -from unittest.mock import patch, MagicMock -from datetime import datetime -from botocore.exceptions import ClientError -from fastapi import HTTPException -from src.services.s3_service import S3Service - - -@pytest.fixture -def s3_service(mock_s3_client): - with patch("src.services.s3_service.boto3.client", return_value=mock_s3_client): - service = S3Service() - service.s3_client = mock_s3_client - return service - - -def test_get_user_path(s3_service): - user_path = s3_service._get_user_path(123) - assert user_path == "users/123" - - -def test_upload_file_success(s3_service): - file_data = b"test content" - file_path = "documents/test.pdf" - content_type = "application/pdf" - user_id = 123 - - full_path = s3_service.upload_file( - user_id=user_id, - file_path=file_path, - file_data=file_data, - content_type=content_type, - ) - - s3_service.s3_client.put_object.assert_called_once_with( - Bucket=s3_service.bucket_name, - Key=f"users/{user_id}/{file_path}", - Body=file_data, - ContentType=content_type, - ) - assert full_path == f"users/{user_id}/{file_path}" - - -def test_upload_file_without_content_type(s3_service): - file_data = b"test content" - file_path = "documents/test.txt" - user_id = 123 - - s3_service.upload_file(user_id=user_id, file_path=file_path, file_data=file_data) - - s3_service.s3_client.put_object.assert_called_once_with( - Bucket=s3_service.bucket_name, - Key=f"users/{user_id}/{file_path}", - Body=file_data, - ) - - -def test_upload_file_error(s3_service): - s3_service.s3_client.put_object.side_effect = ClientError( - {"Error": {"Code": "InternalError", "Message": "S3 Error"}}, "PutObject" - ) - - with pytest.raises(HTTPException) as exc_info: - s3_service.upload_file(123, "test.txt", b"content") - assert exc_info.value.status_code == 500 - assert "Failed to upload file" in str(exc_info.value.detail) - - -def test_upload_files_batch_success(s3_service): - files = [ - { - "file_path": "test1.txt", - "file_data": b"content1", - "content_type": "text/plain", - }, - { - "file_path": "test2.pdf", - "file_data": b"content2", - "content_type": "application/pdf", - }, - ] - - results = s3_service.upload_files_batch(123, files) - - assert len(results) == 2 - assert all(result["status"] == "success" for result in results) - assert s3_service.s3_client.put_object.call_count == 2 - - -def test_upload_files_batch_partial_failure(s3_service): - def mock_upload(*args, **kwargs): - if "test1.txt" in kwargs.get("Key", ""): - raise ClientError( - {"Error": {"Code": "InternalError", "Message": "S3 Error"}}, "PutObject" - ) - - s3_service.s3_client.put_object.side_effect = mock_upload - - files = [ - {"file_path": "test1.txt", "file_data": b"content1"}, - {"file_path": "test2.txt", "file_data": b"content2"}, - ] - - results = s3_service.upload_files_batch(123, files) - - assert len(results) == 2 - assert any(result["status"] == "failed" for result in results) - assert any(result["status"] == "success" for result in results) - - -def test_get_file_success(s3_service): - s3_service.s3_client.get_object.return_value = { - "Body": MagicMock(read=lambda: b"file content") - } - - content = s3_service.get_file(123, "test.txt") - assert content == b"file content" - - s3_service.s3_client.get_object.assert_called_once_with( - Bucket=s3_service.bucket_name, Key="users/123/test.txt" - ) - - -def test_get_file_not_found(s3_service): - s3_service.s3_client.get_object.side_effect = ClientError( - {"Error": {"Code": "NoSuchKey", "Message": "Not found"}}, "GetObject" - ) - - with pytest.raises(HTTPException) as exc_info: - s3_service.get_file(123, "nonexistent.txt") - assert exc_info.value.status_code == 404 - assert "File not found" in str(exc_info.value.detail) - - -def test_get_file_error(s3_service): - s3_service.s3_client.get_object.side_effect = ClientError( - {"Error": {"Code": "InternalError", "Message": "S3 Error"}}, "GetObject" - ) - - with pytest.raises(HTTPException) as exc_info: - s3_service.get_file(123, "test.txt") - assert exc_info.value.status_code == 500 - assert "Failed to download file" in str(exc_info.value.detail) - - -def test_get_file_unauthorized_access(s3_service): - with pytest.raises(HTTPException) as exc_info: - s3_service.get_file(123, "test.txt", requesting_user_id=456) - assert exc_info.value.status_code == 403 - assert "Access denied" in str(exc_info.value.detail) - - -def test_list_user_files_success(s3_service): - mock_response = { - "Contents": [ - {"Key": "users/123/file1.txt", "Size": 100, "LastModified": datetime.now()}, - {"Key": "users/123/file2.pdf", "Size": 200, "LastModified": datetime.now()}, - ] - } - s3_service.s3_client.list_objects_v2.return_value = mock_response - - files = s3_service.list_user_files(123, requesting_user_id=123) - - assert len(files) == 2 - assert all("path" in f and "size" in f and "last_modified" in f for f in files) - assert any(f["path"] == "file1.txt" for f in files) - assert any(f["path"] == "file2.pdf" for f in files) - - -def test_list_user_files_empty(s3_service): - s3_service.s3_client.list_objects_v2.return_value = {} - - files = s3_service.list_user_files(123, requesting_user_id=123) - assert len(files) == 0 - - -def test_list_user_files_unauthorized(s3_service): - with pytest.raises(HTTPException) as exc_info: - s3_service.list_user_files(123, requesting_user_id=456) - assert exc_info.value.status_code == 403 - assert "Access denied" in str(exc_info.value.detail) - - -def test_list_user_files_error(s3_service): - s3_service.s3_client.list_objects_v2.side_effect = ClientError( - {"Error": {"Code": "InternalError", "Message": "S3 Error"}}, "ListObjectsV2" - ) - - with pytest.raises(HTTPException) as exc_info: - s3_service.list_user_files(123, requesting_user_id=123) - assert exc_info.value.status_code == 500 - assert "Failed to list files" in str(exc_info.value.detail) diff --git a/tests/services/test_session_service.py b/tests/services/test_session_service.py deleted file mode 100644 index 0eab3ec..0000000 --- a/tests/services/test_session_service.py +++ /dev/null @@ -1,148 +0,0 @@ -import pytest -from datetime import datetime, timezone -from src.services.session_service import SessionService -from src.models.database import ChatSession, Message - - -@pytest.fixture -def session_service(db_session): - return SessionService(db_session) - - -@pytest.mark.asyncio -async def test_create_session_with_title(session_service, test_user): - title = "Test Chat Session" - session = await session_service.create_session(test_user.id, title) - - assert session.id is not None - assert session.user_id == test_user.id - assert session.title == title - - # Verify session was saved to database - db_session = ( - session_service.db.query(ChatSession) - .filter(ChatSession.id == session.id) - .first() - ) - assert db_session is not None - assert db_session.title == title - - -@pytest.mark.asyncio -async def test_create_session_without_title(session_service, test_user): - session = await session_service.create_session(test_user.id) - - assert session.id is not None - assert session.user_id == test_user.id - assert "Chat" in session.title - assert datetime.now(timezone.utc).strftime("%Y-%m-%d") in session.title - - -@pytest.mark.asyncio -async def test_get_user_sessions(session_service, test_user, test_chat_session): - # Create additional session - await session_service.create_session(test_user.id, "Second Session") - - sessions = await session_service.get_user_sessions(test_user.id) - assert len(sessions) == 2 - assert all(s.user_id == test_user.id for s in sessions) - - -@pytest.mark.asyncio -async def test_get_session(session_service, test_user, test_chat_session): - session = await session_service.get_session(test_chat_session.id, test_user.id) - - assert session is not None - assert session.id == test_chat_session.id - assert session.user_id == test_user.id - - -@pytest.mark.asyncio -async def test_get_session_not_found(session_service, test_user): - session = await session_service.get_session("nonexistent", test_user.id) - assert session is None - - -@pytest.mark.asyncio -async def test_get_session_wrong_user(session_service, test_chat_session): - wrong_user_id = test_chat_session.user_id + 1 - session = await session_service.get_session(test_chat_session.id, wrong_user_id) - assert session is None - - -@pytest.mark.asyncio -async def test_add_message(session_service, test_chat_session): - message = await session_service.add_message( - session_id=test_chat_session.id, - sender="user", - content="Test message", - context_used={"test": "context"}, - ) - - assert message.chat_session_id == test_chat_session.id - assert message.sender == "user" - assert message.content == "Test message" - assert message.context_used == {"test": "context"} - - # Verify message was saved to database - db_message = ( - session_service.db.query(Message).filter(Message.id == message.id).first() - ) - assert db_message is not None - assert db_message.content == "Test message" - - -@pytest.mark.asyncio -async def test_add_message_without_context(session_service, test_chat_session): - message = await session_service.add_message( - session_id=test_chat_session.id, sender="user", content="Test message" - ) - - assert message.context_used == {} - - -@pytest.mark.asyncio -async def test_delete_session(session_service, test_user, test_chat_session): - # Add a message to the session - await session_service.add_message(test_chat_session.id, "user", "Test message") - - await session_service.delete_session(test_chat_session.id, test_user.id) - - # Verify session and messages were deleted - session = ( - session_service.db.query(ChatSession) - .filter(ChatSession.id == test_chat_session.id) - .first() - ) - assert session is None - - messages = ( - session_service.db.query(Message) - .filter(Message.chat_session_id == test_chat_session.id) - .all() - ) - assert len(messages) == 0 - - -@pytest.mark.asyncio -async def test_delete_session_not_found(session_service, test_user): - # Should not raise any exception - await session_service.delete_session("nonexistent", test_user.id) - - -@pytest.mark.asyncio -async def test_get_session_messages(session_service, test_chat_session): - # Add multiple messages - message1 = await session_service.add_message( - test_chat_session.id, "user", "Message 1" - ) - message2 = await session_service.add_message( - test_chat_session.id, "assistant", "Message 2" - ) - - messages = await session_service.get_session_messages(test_chat_session.id) - - assert len(messages) == 2 - assert messages[0].content == "Message 1" - assert messages[1].content == "Message 2" - assert messages[0].created_at <= messages[1].created_at From feae909d0222d4d6256026edc2e4f777a443441b Mon Sep 17 00:00:00 2001 From: capybara-brain346 Date: Tue, 22 Jul 2025 09:09:37 +0530 Subject: [PATCH 14/18] update pylintrc --- .pylintrc | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/.pylintrc b/.pylintrc index beca798..1f6b2d9 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,10 +1,2 @@ [MESSAGES CONTROL] -disable = - missing-module-docstring, - missing-class-docstring, - missing-function-docstring, - logging-fstring-interpolation, - raise-missing-from, - broad-exception-caught, - wrong-import-order - import-error +disable = all From 175d6d07399ffdb2361361f444da924318397067 Mon Sep 17 00:00:00 2001 From: Piyush Choudhari Date: Tue, 22 Jul 2025 09:13:55 +0530 Subject: [PATCH 15/18] Update pylint.yml --- .github/workflows/pylint.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index 6e21c90..09f6091 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -20,4 +20,4 @@ jobs: pip install pylint - name: Analysing the code with pylint run: | - pylint $(git ls-files '*.py') + pylint src || echo "No Python files found. Skipping pylint." From c0495ac137f22bbcca67d43c3f9b0ffbac1c8011 Mon Sep 17 00:00:00 2001 From: Piyush Choudhari Date: Tue, 22 Jul 2025 09:17:04 +0530 Subject: [PATCH 16/18] Delete .github/workflows/pylint.yml --- .github/workflows/pylint.yml | 23 ----------------------- 1 file changed, 23 deletions(-) delete mode 100644 .github/workflows/pylint.yml diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml deleted file mode 100644 index 09f6091..0000000 --- a/.github/workflows/pylint.yml +++ /dev/null @@ -1,23 +0,0 @@ -name: Pylint - -on: [push] - -jobs: - build: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.11"] - steps: - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v3 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install pylint - - name: Analysing the code with pylint - run: | - pylint src || echo "No Python files found. Skipping pylint." From 58972cac743764573209cc71e14eeb8f1fc04b58 Mon Sep 17 00:00:00 2001 From: Piyush Choudhari Date: Tue, 22 Jul 2025 11:18:40 +0530 Subject: [PATCH 17/18] Update README.md --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 0dc6e15..7d415b2 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,5 @@ # KnowFlow --- -![Gemini_Generated_Image_j5414ij5414ij541](https://github.com/user-attachments/assets/032a9bc6-b0fb-437d-93cd-a9f063220e03) # Deployment --- From 1f75be176f2ccb6e61ab7c2515c6c9a29b83a212 Mon Sep 17 00:00:00 2001 From: capybara-brain346 Date: Tue, 22 Jul 2025 16:27:50 +0530 Subject: [PATCH 18/18] add .dockerignore --- .dockerignore | 91 ++++++++++++++++++++++++++++++++++++++++++++++++++ .pylintrc | 2 -- pyproject.toml | 3 ++ uv.lock | 3 ++ 4 files changed, 97 insertions(+), 2 deletions(-) create mode 100644 .dockerignore delete mode 100644 .pylintrc diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..2ed403b --- /dev/null +++ b/.dockerignore @@ -0,0 +1,91 @@ +# Git +.git +.gitignore +.gitattributes + +# CI +.codeclimate.yml +.travis.yml +.taskcluster.yml + +# Docker +docker-compose.yml +Dockerfile +.docker +.dockerignore + +# Byte-compiled / optimized / DLL files +**/__pycache__/ +**/*.py[cod] + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +data/ +logs/ +saved_models/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.cache +nosetests.xml +coverage.xml + +# Translations +*.mo +*.pot + +# Django stuff: +*.log + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Virtual environment +.env +.venv/ +venv/ + +# PyCharm +.idea + +# Python mode for VIM +.ropeproject +**/.ropeproject + +# Vim swap files +**/*.swp + +# VS Code +.vscode/ \ No newline at end of file diff --git a/.pylintrc b/.pylintrc deleted file mode 100644 index 1f6b2d9..0000000 --- a/.pylintrc +++ /dev/null @@ -1,2 +0,0 @@ -[MESSAGES CONTROL] -disable = all diff --git a/pyproject.toml b/pyproject.toml index c1995d8..64fe16c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,3 +28,6 @@ dependencies = [ "pymupdf>=1.26.3", "pytest>=8.4.1", ] + +[dependency-groups] +dev = [] diff --git a/uv.lock b/uv.lock index abba2ca..119e55c 100644 --- a/uv.lock +++ b/uv.lock @@ -1043,6 +1043,9 @@ requires-dist = [ { name = "uvicorn", specifier = ">=0.35.0" }, ] +[package.metadata.requires-dev] +dev = [] + [[package]] name = "langchain" version = "0.3.26"