diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..2ed403b
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,91 @@
+# Build-context exclusions; a leading `**/` makes a pattern match at any depth.
+# Git
+.git
+.gitignore
+.gitattributes
+
+# CI
+.codeclimate.yml
+.travis.yml
+.taskcluster.yml
+
+# Docker
+docker-compose.yml
+Dockerfile
+.docker
+.dockerignore
+
+# Byte-compiled / optimized / DLL files
+**/__pycache__/
+**/*.py[cod]
+
+# C extensions
+**/*.so
+
+# Distribution / packaging
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+data/
+logs/
+saved_models/
+**/*.egg-info/
+.installed.cfg
+**/*.egg
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.cache
+nosetests.xml
+coverage.xml
+
+# Translations
+**/*.mo
+**/*.pot
+
+# Log files
+**/*.log
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Environment files (may hold secrets) and virtual environments
+**/.env
+.venv/
+venv/
+
+# PyCharm
+.idea
+
+# Python mode for VIM
+**/.ropeproject
+
+# Vim swap files
+**/*.swp
+
+# VS Code
+.vscode/
\ No newline at end of file
diff --git a/.github/workflows/deploy-to-ecs.yml b/.github/workflows/deploy-to-ecs.yml
new file mode 100644
index 0000000..ebcdfcd
--- /dev/null
+++ b/.github/workflows/deploy-to-ecs.yml
@@ -0,0 +1,72 @@
+# Builds the backend Docker image, pushes it to Amazon ECR, and rolls the new
+# image out to an ECS service on every push to the `dev` branch.
+name: Build & Deploy to ECS
+
+on:
+  push:
+    branches:
+      - dev
+
+# The job only reads the repository; AWS access uses the stored access keys.
+permissions:
+  contents: read
+
+# Deployment targets are read from the `vars` configuration-variable context.
+env:
+  AWS_REGION: ${{ vars.AWS_REGION }}
+  ECR_REPOSITORY: ${{ vars.ECR_REPOSITORY }}
+  ECS_CLUSTER: ${{ vars.ECS_CLUSTER }}
+  ECS_SERVICE: ${{ vars.ECS_SERVICE }}
+  CONTAINER_NAME: ${{ vars.CONTAINER_NAME }}
+  ECS_TASK_DEF: ${{ vars.ECS_TASK_DEF }}
+
+jobs:
+  deploy:
+    runs-on: ubuntu-latest
+    # NOTE(review): pushes to `dev` run under the `production` environment;
+    # confirm that this pairing is intended.
+    environment: production
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Configure AWS credentials
+        uses: aws-actions/configure-aws-credentials@v4
+        with:
+          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
+          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+          aws-region: ${{ env.AWS_REGION }}
+
+      - name: Login to Amazon ECR
+        id: login-ecr
+        uses: aws-actions/amazon-ecr-login@v2
+
+      - name: Build & Push Docker Image
+        id: build-image
+        env:
+          ECR_REGISTRY: ${{ steps.login-ecr.outputs.registry }}
+          IMAGE_TAG: ${{ github.sha }}
+        run: |
+          # Tag with the commit SHA so each pushed image maps to one commit.
+          IMAGE_URI="$ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG"
+          docker build -f docker/Dockerfile.backend -t "$IMAGE_URI" .
+          docker push "$IMAGE_URI"
+          # Expose the pushed image URI to the render step below.
+          echo "image=$IMAGE_URI" >> "$GITHUB_OUTPUT"
+
+      - name: Render Updated ECS Task Definition
+        id: render-task-def
+        uses: aws-actions/amazon-ecs-render-task-definition@v1
+        with:
+          task-definition-family: ${{ env.ECS_TASK_DEF }}
+          container-name: ${{ env.CONTAINER_NAME }}
+          image: ${{ steps.build-image.outputs.image }}
+
+      - name: Deploy to Amazon ECS
+        uses: aws-actions/amazon-ecs-deploy-task-definition@v1
+        with:
+          task-definition: ${{ steps.render-task-def.outputs.task-definition }}
+          service: ${{ env.ECS_SERVICE }}
+          cluster: ${{ env.ECS_CLUSTER }}
+          wait-for-service-stability: true
diff --git a/README.md b/README.md
index 6d80416..7d415b2 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,13 @@
-# KNOWFLOW
+# KnowFlow
+---
-
+# Deployment
+---
+
+
+# AI Infra
+---
+
KnowFlow is a powerful hybrid Retrieval-Augmented Generation (RAG) system that combines semantic search with knowledge graph capabilities for intelligent document processing and querying.
diff --git a/docker/Dockerfile.backend b/docker/Dockerfile.backend
index 3cda427..e4d8064 100644
--- a/docker/Dockerfile.backend
+++ b/docker/Dockerfile.backend
@@ -13,7 +13,7 @@ RUN sh /uv-installer.sh && rm /uv-installer.sh
ENV PATH="/root/.local/bin/:$PATH"
-COPY .env requirements.txt pyproject.toml uv.lock .python-version ./
+COPY requirements.txt pyproject.toml uv.lock .python-version ./
RUN uv sync --locked
diff --git a/docker/Dockerfile.neo4j b/docker/Dockerfile.neo4j
index bf72468..be2abd5 100644
--- a/docker/Dockerfile.neo4j
+++ b/docker/Dockerfile.neo4j
@@ -3,6 +3,7 @@ FROM neo4j:5.15.0
ENV NEO4J_AUTH=neo4j/Pstm!tr0ae#123
ENV NEO4J_PLUGINS=["apoc"]
ENV NEO4J_dbms_security_procedures_unrestricted=apoc.*
+ENV NEO4J_dbms_default__listen__address=0.0.0.0
ENV NEO4J_dbms_memory_heap_initial__size=512m
ENV NEO4J_dbms_memory_heap_max__size=2G
diff --git a/docs/FLOWCHART.md b/docs/FLOWCHART.md
new file mode 100644
index 0000000..3f3dd00
--- /dev/null
+++ b/docs/FLOWCHART.md
@@ -0,0 +1,267 @@
+# KnowFlow Technical Architecture
+
+## Document Processing Pipeline
+
+### 1. Document Upload & Initial Processing
+
+- **Supported Formats**: PDF, DOCX, CSV, TXT
+- **Document Loaders**:
+ - PDF: `PyMuPDFLoader`
+ - DOCX: `Docx2txtLoader`
+ - CSV: `CSVLoader`
+ - TXT: `TextLoader`
+ - Fallback: `UnstructuredFileLoader`
+
+### 2. Text Processing
+
+- **Text Splitter**: `RecursiveCharacterTextSplitter`
+ ```python
+ text_splitter = RecursiveCharacterTextSplitter(
+ chunk_size=1000,
+ chunk_overlap=200,
+ separators=["\n\n", "\n", " ", ""],
+ keep_separator=True
+ )
+ ```
+
+### 3. Vector Generation
+
+- **Embedding Model**: Gemini Embedding Model
+- **Vector Store**: PGVector (PostgreSQL)
+ ```python
+ vector_store = PGVector(
+ connection=DATABASE_URL,
+ embeddings=embeddings,
+ collection_name=VECTOR_COLLECTION_NAME
+ )
+ ```
+
+### 4. Knowledge Graph Construction
+
+- **Database**: Neo4j
+- **Node Types**:
+ - Document: Main document node
+ - Section: Document sections
+ - Entity: Named entities
+ - Concept: Key ideas/terms
+ - Tag: Categories
+- **Relationship Types**:
+ - CONTAINS: Hierarchy
+ - RELATED_TO: General connections
+ - MENTIONS: References
+ - HAS_TAG: Classifications
+
+## Query Processing Pipeline
+
+### 1. Query Analysis
+
+- **Query Decomposition**:
+ ```python
+ messages = [
+ SystemMessage(content="Query decomposition prompt"),
+ HumanMessage(content=query)
+ ]
+ sub_questions = llm.invoke(messages)
+ ```
+
+### 2. Retrieval Process
+
+- **Vector Search**:
+ ```python
+ docs_and_scores = vector_store.similarity_search_with_score_by_vector(
+ embedding=query_embedding,
+ k=TOP_K_RESULTS,
+ filter=filter_dict
+ )
+ ```
+- **Graph Search**:
+ - Cypher query generation
+ - Pattern matching
+ - Context retrieval
+
+### 3. Retrieval Evaluation
+
+- **Quality Metrics**:
+ ```json
+ {
+ "chunk_scores": [
+ {"chunk": "text", "relevance_score": 0-10, "reasoning": "explanation"}
+ ],
+ "missing_aspects": ["list of missing information"],
+ "redundant_information": ["list of redundancies"],
+ "suggested_improvements": {
+ "additional_info_needed": ["missing info"],
+ "alternative_search_terms": ["suggested terms"]
+ },
+ "overall_quality_score": 0-10
+ }
+ ```
+
+### 4. Response Generation
+
+- **Context Assembly**:
+
+ ```python
+ def _merge_results(vector_results, graph_results):
+ graph_texts = []
+ for result in graph_results:
+ text = f"Type: {result.get('type')}\n"
+ text += f"Properties: {result.get('properties')}\n"
+ text += f"Relationships: {result.get('relationships')}"
+ graph_texts.append(text)
+
+ all_texts = vector_results[:3] + graph_texts[:3]
+ return "\n\n".join(all_texts)
+ ```
+
+- **LLM Prompt Structure**:
+ ```python
+ system_prompt = f"""
+ You are a helpful, reasoning assistant. Answer based on:
+ Context: {context}
+ Guidelines:
+ - Rephrase/summarize from context
+ - Use reasoning to clarify
+ - No fabrication
+ - Indicate if answer unavailable
+ """
+ ```
+
+## Storage Architecture
+
+### 1. Vector Store (PostgreSQL)
+
+- **Schema**:
+ ```sql
+ CREATE TABLE document_chunks (
+ id SERIAL PRIMARY KEY,
+ document_id INTEGER,
+ chunk_index INTEGER,
+ content TEXT,
+ embedding_vector vector(768),
+ metadata JSONB
+ );
+ ```
+
+### 2. Graph Database (Neo4j)
+
+- **Node Properties**:
+ ```json
+ {
+ "id": "unique_id",
+ "type": "node_type",
+ "name": "node_name",
+ "content": "text_content",
+ "created_at": "timestamp"
+ }
+ ```
+- **Relationship Properties**:
+ ```json
+ {
+ "type": "relationship_type",
+ "context": "relationship_context",
+ "confidence": 0.0-1.0,
+ "created_at": "timestamp"
+ }
+ ```
+
+### 3. Document Storage (S3)
+
+- **Structure**:
+ ```
+ s3://bucket/
+ └── users/
+ └── {user_id}/
+ └── documents/
+ └── {doc_id}.{extension}
+ ```
+
+## Mermaid Technical Flowchart
+
+```mermaid
+flowchart TB
+ %% Document Processing Pipeline
+ subgraph Document_Processing["Document Processing Pipeline"]
+ direction TB
+ Upload["Document Upload<br/>Supported: PDF, DOCX, CSV, TXT"]
+
+ subgraph Processing["Document Processing"]
+ Loader["Document Loaders<br/>PyMuPDF/Docx2txt/CSV/Text"]
+ Splitter["Text Splitter<br/>RecursiveCharacterTextSplitter<br/>chunk_size=1000, overlap=200"]
+ VectorGen["Vector Generation<br/>Gemini Embedding Model"]
+ end
+
+ subgraph Storage["Storage Layer"]
+ S3Store["S3 Storage<br/>Raw Documents"]
+ PGVector["PGVector Store<br/>- document_id<br/>- chunk_index<br/>- embedding_vector<br/>- metadata"]
+ Neo4j["Neo4j Graph DB<br/>Node Types: Document, Section,<br/>Entity, Concept, Tag<br/>Relations: CONTAINS, RELATED_TO,<br/>MENTIONS, HAS_TAG"]
+ end
+
+ Upload --> Loader
+ Loader --> Splitter
+ Splitter --> VectorGen
+ VectorGen --> PGVector
+ Upload --> S3Store
+ Splitter --> Neo4j
+ end
+
+ %% Query Processing Pipeline
+ subgraph Query_Processing["Query Processing Pipeline"]
+ direction TB
+ Query["User Query"]
+
+ subgraph Query_Analysis["Query Analysis"]
+ QDecomp["Query Decomposition<br/>Complex Query → Sub-questions<br/>Using Gemini Pro"]
+ QEmbed["Query Embedding<br/>Gemini Embedding Model"]
+ end
+
+ subgraph Retrieval["Retrieval Layer"]
+ VecSearch["Vector Search<br/>Similarity Search with Scores<br/>TOP_K=5"]
+ GraphSearch["Graph Search<br/>Cypher Query Generation<br/>Pattern Matching"]
+ end
+
+ subgraph Evaluation["Retrieval Evaluation"]
+ RelCheck["Relevance Check<br/>Score: 0-10"]
+ ImproveSuggestions["Improvement Suggestions<br/>- Missing Aspects<br/>- Alternative Terms"]
+ end
+
+ Query --> QDecomp
+ QDecomp --> QEmbed
+ QEmbed --> VecSearch
+ Query --> GraphSearch
+ VecSearch --> RelCheck
+ GraphSearch --> RelCheck
+ RelCheck --> ImproveSuggestions
+ ImproveSuggestions -.-> QEmbed
+ end
+
+ %% Response Generation
+ subgraph Response_Gen["Response Generation"]
+ direction TB
+ Context["Context Assembly<br/>Vector + Graph Results"]
+ LLMPrompt["LLM Prompt Construction<br/>System + Context + Query"]
+ Response["Response Generation<br/>Gemini Pro"]
+
+ Context --> LLMPrompt
+ LLMPrompt --> Response
+ end
+
+ %% Data Connections
+ PGVector --> VecSearch
+ Neo4j --> GraphSearch
+ VecSearch --> Context
+ GraphSearch --> Context
+
+ %% External Services
+ subgraph External["External Services"]
+ direction LR
+ Gemini["Google Gemini Pro<br/>- Chat Completion<br/>- Embeddings"]
+ AWS["AWS S3<br/>Document Storage"]
+ end
+
+ %% Service Connections
+ VectorGen --> Gemini
+ QEmbed --> Gemini
+ Response --> Gemini
+ S3Store --> AWS
+```
diff --git a/ecs/task-definition.json b/ecs/task-definition.json
new file mode 100644
index 0000000..c40bf2b
--- /dev/null
+++ b/ecs/task-definition.json
@@ -0,0 +1,94 @@
+{
+ "taskDefinition": {
+ "taskDefinitionArn": "arn:aws:ecs:ap-south-1:953685791553:task-definition/knowflow-backend-task:1",
+ "containerDefinitions": [
+ {
+ "name": "knowflow-backend",
+ "image": "953685791553.dkr.ecr.ap-south-1.amazonaws.com/knowflow-backend:latest",
+ "cpu": 0,
+ "portMappings": [
+ {
+ "containerPort": 8000,
+ "hostPort": 8000,
+ "protocol": "tcp",
+ "name": "knowflow-backend-8000-tcp",
+ "appProtocol": "http"
+ }
+ ],
+ "essential": true,
+ "secrets": [
+ {
+ "name": "AWS_ACCESS_KEY_ID",
+ "valueFrom": "arn:aws:secretsmanager:ap-south-1:953685791553:secret:knowflow/app-secrets:AWS_ACCESS_KEY_ID::"
+ },
+ {
+ "name": "AWS_SECRET_ACCESS_KEY",
+ "valueFrom": "arn:aws:secretsmanager:ap-south-1:953685791553:secret:knowflow/app-secrets:AWS_SECRET_ACCESS_KEY::"
+ }
+ ],
+ "environmentFiles": [],
+ "mountPoints": [],
+ "volumesFrom": [],
+ "ulimits": [],
+ "healthCheck": {
+ "command": [
+ "CMD-SHELL",
+ "curl -f http://localhost:8000/health || exit 1"
+ ],
+ "interval": 10,
+ "timeout": 5,
+ "retries": 3
+ },
+ "systemControls": []
+ }
+ ],
+ "family": "knowflow-backend-task",
+ "taskRoleArn": "arn:aws:iam::953685791553:role/ECSTaskRole",
+ "executionRoleArn": "arn:aws:iam::953685791553:role/ecsTaskExecutionRole",
+ "networkMode": "awsvpc",
+ "revision": 1,
+ "volumes": [],
+ "status": "ACTIVE",
+ "requiresAttributes": [
+ {
+ "name": "com.amazonaws.ecs.capability.docker-remote-api.1.24"
+ },
+ {
+ "name": "com.amazonaws.ecs.capability.ecr-auth"
+ },
+ {
+ "name": "com.amazonaws.ecs.capability.task-iam-role"
+ },
+ {
+ "name": "ecs.capability.container-health-check"
+ },
+ {
+ "name": "ecs.capability.execution-role-ecr-pull"
+ },
+ {
+ "name": "com.amazonaws.ecs.capability.docker-remote-api.1.18"
+ },
+ {
+ "name": "ecs.capability.task-eni"
+ }
+ ],
+ "placementConstraints": [],
+ "compatibilities": [
+ "EC2",
+ "FARGATE"
+ ],
+ "runtimePlatform": {
+ "cpuArchitecture": "X86_64",
+ "operatingSystemFamily": "LINUX"
+ },
+ "requiresCompatibilities": [
+ "FARGATE"
+ ],
+ "cpu": "2048",
+ "memory": "8192",
+ "registeredAt": "2025-07-20T08:14:39.290000+05:30",
+ "registeredBy": "arn:aws:iam::953685791553:root",
+ "enableFaultInjection": false
+ },
+ "tags": []
+}
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 7b97a7b..64fe16c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,4 +26,8 @@ dependencies = [
"constructs>=10.4.2",
"psycopg2-binary>=2.9.10",
"pymupdf>=1.26.3",
+ "pytest>=8.4.1",
]
+
+[dependency-groups]
+dev = []
diff --git a/src/core/auth.py b/src/core/auth.py
index 73083d1..d144f00 100644
--- a/src/core/auth.py
+++ b/src/core/auth.py
@@ -1,11 +1,11 @@
from typing import Annotated
-from fastapi import Depends, HTTPException, status
+from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.orm import Session
from src.core.database import get_db
-from src.services.auth_service import AuthService
from src.models.database import User
+from src.services.auth_service import AuthService
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth/token")
diff --git a/src/core/config.py b/src/core/config.py
index 5a59fad..275bcaa 100644
--- a/src/core/config.py
+++ b/src/core/config.py
@@ -1,8 +1,8 @@
+import json
+import boto3
from functools import lru_cache
from pathlib import Path
from typing import Optional, List
-import json
-import boto3
from botocore.exceptions import ClientError
from pydantic import Field
from pydantic_settings import BaseSettings
@@ -65,12 +65,14 @@ class Settings(BaseSettings):
# AI Models
GOOGLE_API_KEY: str = Field(default="")
- GEMINI_EMBEDDING_MODEL: str = Field(default="", env="GEMINI_EMBEDDING_MODEL")
- GEMINI_MODEL_NAME: str = Field(default="", env="GEMINI_MODEL_NAME")
+ GEMINI_EMBEDDING_MODEL: str = Field(
+ default="models/embedding-001", env="GEMINI_EMBEDDING_MODEL"
+ )
+ GEMINI_MODEL_NAME: str = Field(default="gemini-2.0-flash", env="GEMINI_MODEL_NAME")
# Vector Store
VECTOR_COLLECTION_NAME: str = Field(
- default="knowflow_vecotr_db", env="VECTOR_COLLECTION_NAME"
+ default="knowflow_vector_db", env="VECTOR_COLLECTION_NAME"
)
CHUNK_SIZE: int = Field(default=700)
CHUNK_OVERLAP: int = Field(default=80)
diff --git a/src/core/database.py b/src/core/database.py
index a22dc1c..62ec7d0 100644
--- a/src/core/database.py
+++ b/src/core/database.py
@@ -1,9 +1,7 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, declarative_base
-from neo4j import GraphDatabase
from src.core.config import settings
-from src.core.logging import logger
Base = declarative_base()
diff --git a/src/core/exceptions.py b/src/core/exceptions.py
index d607ad9..0d1efcd 100644
--- a/src/core/exceptions.py
+++ b/src/core/exceptions.py
@@ -1,5 +1,5 @@
-from typing import Any, Dict, Optional
from fastapi import status
+from typing import Any, Dict, Optional
class AppException(Exception):
diff --git a/src/core/logging.py b/src/core/logging.py
index 21c2e5d..93f2f54 100644
--- a/src/core/logging.py
+++ b/src/core/logging.py
@@ -1,6 +1,6 @@
+import sys
import logging
import logging.handlers
-import sys
from pathlib import Path
from typing import Optional
diff --git a/src/core/middleware.py b/src/core/middleware.py
index 0b80f14..cb71bf4 100644
--- a/src/core/middleware.py
+++ b/src/core/middleware.py
@@ -1,13 +1,10 @@
-from datetime import datetime, timezone
import time
import uuid
-from typing import Callable, Optional
-
+from datetime import datetime, timezone
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
-from starlette.types import Message
from src.core.config import settings
from src.core.logging import logger
diff --git a/src/main.py b/src/main.py
index dd758f2..a2a2fce 100644
--- a/src/main.py
+++ b/src/main.py
@@ -2,15 +2,15 @@
from fastapi import FastAPI
from src.core.config import settings
+from src.core.middleware import setup_middleware
+from src.core.database import init_db
+from src.core.logging import logger
from src.routes import (
auth_routes,
chat_routes,
document_routes,
session_routes,
)
-from src.core.middleware import setup_middleware
-from src.core.database import init_db
-from src.core.logging import logger
@asynccontextmanager
diff --git a/src/models/database.py b/src/models/database.py
index fa37ec7..de1b2b6 100644
--- a/src/models/database.py
+++ b/src/models/database.py
@@ -1,9 +1,10 @@
+import enum
+import sqlalchemy
from datetime import datetime, timezone
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, JSON, Text, Enum
from sqlalchemy.orm import relationship
-import enum
-import sqlalchemy
from sqlalchemy.sql import func
+
from src.core.database import Base
diff --git a/src/models/graph.py b/src/models/graph.py
index f306c6d..cba5f4d 100644
--- a/src/models/graph.py
+++ b/src/models/graph.py
@@ -1,4 +1,4 @@
-from typing import List, Dict, Any, Optional
+from typing import List, Optional
from pydantic import BaseModel, Field, field_validator
diff --git a/src/models/request.py b/src/models/request.py
index cd76d38..cacff9d 100644
--- a/src/models/request.py
+++ b/src/models/request.py
@@ -1,4 +1,4 @@
-from pydantic import BaseModel, EmailStr, Field, constr
+from pydantic import BaseModel, EmailStr, Field
from typing import Optional, Dict, Any, List
from datetime import datetime
diff --git a/src/models/response.py b/src/models/response.py
index bbd2f67..1ee337c 100644
--- a/src/models/response.py
+++ b/src/models/response.py
@@ -1,6 +1,7 @@
from datetime import datetime
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
+
from src.models.database import DocumentStatus
diff --git a/src/routes/auth_routes.py b/src/routes/auth_routes.py
index 681ab22..18830b6 100644
--- a/src/routes/auth_routes.py
+++ b/src/routes/auth_routes.py
@@ -1,11 +1,10 @@
from datetime import timedelta
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
-from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
-from sqlalchemy.orm import Session
+from fastapi.security import OAuth2PasswordRequestForm
+from src.core.auth import get_current_user, get_auth_service
from src.core.config import settings
-from src.core.database import get_db
from src.models.request import UserLogin, UserRegister
from src.models.response import (
TokenResponse,
@@ -13,9 +12,8 @@
RegisterResponse,
MessageResponse,
)
-from src.services.auth_service import AuthService
from src.models.database import User
-from src.core.auth import get_current_user, get_auth_service
+from src.services.auth_service import AuthService
router = APIRouter()
diff --git a/src/routes/chat_routes.py b/src/routes/chat_routes.py
index 3466b8c..632526f 100644
--- a/src/routes/chat_routes.py
+++ b/src/routes/chat_routes.py
@@ -1,6 +1,8 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
-from src.services.chat.chat_service import ChatService
+
+from src.core.auth import get_current_user
+from src.core.database import get_db
from src.core.exceptions import ExternalServiceException
from src.core.logging import logger
from src.models.request import ChatRequest, FollowUpChatRequest, RenameChatRequest
@@ -10,9 +12,8 @@
RenameChatResponse,
DeleteChatResponse,
)
-from src.core.auth import get_current_user
from src.models.database import User
-from src.core.database import get_db
+from src.services.chat.chat_service import ChatService
from src.services.session_service import SessionService
diff --git a/src/routes/document_routes.py b/src/routes/document_routes.py
index d6730ad..b6e3d7f 100644
--- a/src/routes/document_routes.py
+++ b/src/routes/document_routes.py
@@ -2,13 +2,13 @@
from fastapi import APIRouter, Depends, UploadFile, BackgroundTasks, Query
from src.core.auth import get_current_user
-from src.services.document_service import DocumentService
from src.models.database import User
from src.models.request import DocumentIndexRequest
from src.models.response import (
DocumentIndexResponse,
MultiDocumentUploadResponse,
)
+from src.services.document_service import DocumentService
router = APIRouter()
diff --git a/src/routes/session_routes.py b/src/routes/session_routes.py
index 3e401f3..1db1dd1 100644
--- a/src/routes/session_routes.py
+++ b/src/routes/session_routes.py
@@ -1,17 +1,16 @@
-from typing import List, Optional
+from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from src.core.database import get_db
+from src.core.auth import get_current_user
+from src.models.database import User
from src.models.request import CreateSessionRequest, SendMessageRequest
from src.models.response import (
ChatSessionResponse,
ChatSessionListResponse,
- MessageResponse,
)
from src.services.session_service import SessionService
-from src.core.auth import get_current_user
-from src.models.database import User
router = APIRouter()
diff --git a/src/services/auth_service.py b/src/services/auth_service.py
index c9d1957..1b5e547 100644
--- a/src/services/auth_service.py
+++ b/src/services/auth_service.py
@@ -7,7 +7,7 @@
from fastapi.security import OAuth2PasswordBearer
from src.core.config import settings
-from src.models.database import User, Document, DocumentChunk, DocumentStatus
+from src.models.database import User
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
diff --git a/src/services/base_client.py b/src/services/base_client.py
index eb5d187..fdbea09 100644
--- a/src/services/base_client.py
+++ b/src/services/base_client.py
@@ -1,4 +1,6 @@
-from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
+from langchain_postgres import PGVector
+
from src.core.config import settings
from src.core.exceptions import ExternalServiceException
from src.core.logging import logger
@@ -13,6 +15,18 @@ def __init__(self, service_name: str):
model=settings.GEMINI_MODEL_NAME,
convert_system_message_to_human=True,
)
+
+ self.embeddings = GoogleGenerativeAIEmbeddings(
+ model=settings.GEMINI_EMBEDDING_MODEL,
+ google_api_key=settings.GOOGLE_API_KEY,
+ )
+
+ self.vector_store = PGVector(
+ connection=settings.DATABASE_URL,
+ embeddings=self.embeddings,
+ collection_name=settings.VECTOR_COLLECTION_NAME,
+ )
+
logger.info(f"{service_name} initialized successfully")
except Exception as e:
logger.error(
diff --git a/src/services/chat/chat_service.py b/src/services/chat/chat_service.py
index 0a5b3cc..3295951 100644
--- a/src/services/chat/chat_service.py
+++ b/src/services/chat/chat_service.py
@@ -1,31 +1,23 @@
-import json
from typing import List, Dict, Any, Optional
from fastapi import HTTPException, status
-from langchain_google_genai import GoogleGenerativeAIEmbeddings
-from langchain_postgres import PGVector
-from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.schema import HumanMessage, SystemMessage
from sqlalchemy.orm import Session
from datetime import datetime, timezone
from neo4j import GraphDatabase
+from src.core.database import get_db
from src.core.config import settings
from src.core.exceptions import ExternalServiceException
from src.core.logging import logger
+from src.models.request import FollowUpChatRequest
+from src.models.response import FollowUpChatResponse
+from src.models.database import ChatSession
+from src.models.database import Message
from src.services.graph_service import GraphService
from src.services.auth_service import AuthService
from src.services.base_client import BaseLLMClient
from src.services.chat.query_decomposition import QueryDecompositionService
from src.services.chat.retrieval_evaluation import RetrievalEvaluationService
-from src.models.database import Document
-from src.models.database import DocumentChunk
-from src.models.database import DocumentStatus
-from src.models.database import Message
-from src.core.database import get_db
-from src.models.request import FollowUpChatRequest
-from src.models.response import FollowUpChatResponse
-from src.models.database import ChatSession
-from src.utils.utils import clean_llm_response
class ChatService(BaseLLMClient):
@@ -35,17 +27,6 @@ def __init__(self, db: Session = None):
self.db = db or next(get_db())
self.auth_service = AuthService(self.db)
- self.embeddings = GoogleGenerativeAIEmbeddings(
- model=settings.GEMINI_EMBEDDING_MODEL,
- google_api_key=settings.GOOGLE_API_KEY,
- )
-
- self.vector_store = PGVector(
- connection=settings.DATABASE_URL,
- embeddings=self.embeddings,
- collection_name=settings.VECTOR_COLLECTION_NAME,
- )
-
self.driver = GraphDatabase.driver(
settings.NEO4J_URI, auth=(settings.NEO4J_USER, settings.NEO4J_PASSWORD)
)
diff --git a/src/services/document_service.py b/src/services/document_service.py
index 98c6605..9fb66eb 100644
--- a/src/services/document_service.py
+++ b/src/services/document_service.py
@@ -1,3 +1,4 @@
+import tempfile
import os
import asyncio
from typing import List, Optional, Dict, Any
@@ -13,21 +14,19 @@
Docx2txtLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
-import tempfile
from src.core.database import get_db
-from src.models.database import Document, DocumentStatus, DocumentChunk, User
-from src.services.s3_service import S3Service
from src.core.config import settings
from src.core.exceptions import ExternalServiceException
from src.core.logging import logger
-from langchain_google_genai import GoogleGenerativeAIEmbeddings
-from langchain_postgres import PGVector
+from src.models.database import Document, DocumentStatus, DocumentChunk, User
+from src.services.s3_service import S3Service
from src.services.graph_service import GraphService
+from src.services.base_client import BaseLLMClient
from src.utils.utils import clean_whitespaes
-class DocumentService:
+class DocumentService(BaseLLMClient):
SUPPORTED_MIMETYPES = {
"application/pdf": ".pdf",
"application/msword": ".doc",
@@ -41,9 +40,7 @@ class DocumentService:
def __init__(
self, db: Session = next(get_db()), current_user: Optional[User] = None
):
- self.db = db
- self.storage_service = S3Service()
- self.current_user = current_user
+ super().__init__("DocumentService")
try:
try:
loop = asyncio.get_event_loop()
@@ -51,16 +48,9 @@ def __init__(
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
- self.embeddings = GoogleGenerativeAIEmbeddings(
- model=settings.GEMINI_EMBEDDING_MODEL,
- google_api_key=settings.GOOGLE_API_KEY,
- )
-
- self.vector_store = PGVector(
- connection=settings.DATABASE_URL,
- embeddings=self.embeddings,
- collection_name=settings.VECTOR_COLLECTION_NAME,
- )
+ self.db = db
+ self.storage_service = S3Service()
+ self.current_user = current_user
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=settings.CHUNK_SIZE,
diff --git a/src/services/graph_service.py b/src/services/graph_service.py
index d1ca178..9ca5d44 100644
--- a/src/services/graph_service.py
+++ b/src/services/graph_service.py
@@ -1,17 +1,17 @@
-from typing import List, Dict, Any
+import re
+import uuid
import json
+from datetime import datetime
+from typing import List, Dict, Any
from neo4j import GraphDatabase, Session
from langchain.schema import HumanMessage, SystemMessage
from src.core.config import settings
-from src.models.graph import GraphKnowledge
from src.core.exceptions import ExternalServiceException
from src.core.logging import logger
-from src.utils.utils import clean_llm_response
+from src.models.graph import GraphKnowledge
from src.services.base_client import BaseLLMClient
-from datetime import datetime
-import re
-import uuid
+from src.utils.utils import clean_llm_response
class GraphService(BaseLLMClient):
diff --git a/src/services/s3_service.py b/src/services/s3_service.py
index 25b36a5..752a38b 100644
--- a/src/services/s3_service.py
+++ b/src/services/s3_service.py
@@ -1,7 +1,7 @@
import boto3
from botocore.exceptions import ClientError
from fastapi import HTTPException, status
-from typing import BinaryIO, Optional, Dict, Any, List
+from typing import Optional, Dict, Any, List
from concurrent.futures import ThreadPoolExecutor
from src.core.config import settings
diff --git a/src/services/session_service.py b/src/services/session_service.py
index 0fda8bd..8de0f34 100644
--- a/src/services/session_service.py
+++ b/src/services/session_service.py
@@ -3,7 +3,7 @@
from sqlalchemy.orm import Session, joinedload
import uuid
-from src.models.database import ChatSession, Message, User
+from src.models.database import ChatSession, Message
class SessionService:
diff --git a/uv.lock b/uv.lock
index 8c69be0..119e55c 100644
--- a/uv.lock
+++ b/uv.lock
@@ -919,6 +919,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a4/ed/1f1afb2e9e7f38a545d628f864d562a5ae64fe6f7a10e28ffb9b185b4e89/importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec", size = 37461 },
]
+[[package]]
+name = "iniconfig"
+version = "2.1.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050 },
+]
+
[[package]]
name = "jinja2"
version = "3.1.6"
@@ -1002,6 +1011,7 @@ dependencies = [
{ name = "pydantic-settings" },
{ name = "pymupdf" },
{ name = "pypdf2" },
+ { name = "pytest" },
{ name = "python-jose" },
{ name = "python-magic" },
{ name = "uvicorn" },
@@ -1027,11 +1037,15 @@ requires-dist = [
{ name = "pydantic-settings", specifier = ">=2.10.1" },
{ name = "pymupdf", specifier = ">=1.26.3" },
{ name = "pypdf2", specifier = ">=3.0.1" },
+ { name = "pytest", specifier = ">=8.4.1" },
{ name = "python-jose", specifier = ">=3.5.0" },
{ name = "python-magic", specifier = ">=0.4.27" },
{ name = "uvicorn", specifier = ">=0.35.0" },
]
+[package.metadata.requires-dev]
+dev = []
+
[[package]]
name = "langchain"
version = "0.3.26"
@@ -1491,6 +1505,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/fb/81/f457d6d361e04d061bef413749a6e1ab04d98cfeec6d8abcfe40184750f3/pgvector-0.3.6-py3-none-any.whl", hash = "sha256:f6c269b3c110ccb7496bac87202148ed18f34b390a0189c783e351062400a75a", size = 24880 },
]
+[[package]]
+name = "pluggy"
+version = "1.6.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538 },
+]
+
[[package]]
name = "propcache"
version = "0.3.2"
@@ -1869,6 +1892,22 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/8e/5e/c86a5643653825d3c913719e788e41386bee415c2b87b4f955432f2de6b2/pypdf2-3.0.1-py3-none-any.whl", hash = "sha256:d16e4205cfee272fbdc0568b68d82be796540b1537508cef59388f839c191928", size = 232572 },
]
+[[package]]
+name = "pytest"
+version = "8.4.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "colorama", marker = "sys_platform == 'win32'" },
+ { name = "iniconfig" },
+ { name = "packaging" },
+ { name = "pluggy" },
+ { name = "pygments" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/08/ba/45911d754e8eba3d5a841a5ce61a65a685ff1798421ac054f85aa8747dfb/pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c", size = 1517714 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474 },
+]
+
[[package]]
name = "python-dateutil"
version = "2.9.0.post0"