diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..2ed403b --- /dev/null +++ b/.dockerignore @@ -0,0 +1,91 @@ +# Git +.git +.gitignore +.gitattributes + +# CI +.codeclimate.yml +.travis.yml +.taskcluster.yml + +# Docker +docker-compose.yml +Dockerfile +.docker +.dockerignore + +# Byte-compiled / optimized / DLL files +**/__pycache__/ +**/*.py[cod] + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +data/ +logs/ +saved_models/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.cache +nosetests.xml +coverage.xml + +# Translations +*.mo +*.pot + +# Django stuff: +*.log + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Virtual environment +.env +.venv/ +venv/ + +# PyCharm +.idea + +# Python mode for VIM +.ropeproject +**/.ropeproject + +# Vim swap files +**/*.swp + +# VS Code +.vscode/ \ No newline at end of file diff --git a/.github/workflows/deploy-to-ecs.yml b/.github/workflows/deploy-to-ecs.yml new file mode 100644 index 0000000..ebcdfcd --- /dev/null +++ b/.github/workflows/deploy-to-ecs.yml @@ -0,0 +1,60 @@ +name: Build & Deploy to ECS + +on: + push: + branches: + - dev + +env: + AWS_REGION: ${{ vars.AWS_REGION }} + ECR_REPOSITORY: ${{ vars.ECR_REPOSITORY }} + ECS_CLUSTER: ${{ vars.ECS_CLUSTER }} + ECS_SERVICE: ${{ vars.ECS_SERVICE }} + CONTAINER_NAME: ${{ vars.CONTAINER_NAME }} + ECS_TASK_DEF: ${{ vars.ECS_TASK_DEF }} + +jobs: + deploy: + runs-on: ubuntu-latest + environment: production + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v2 + with: + aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} + aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + aws-region: ${{ env.AWS_REGION }} + + - name: Login to Amazon ECR + id: login-ecr + uses: aws-actions/amazon-ecr-login@v1 + + - name: Build & Push Docker Image + id: build-image + env: + ECR_REGISTRY: ${{ steps.login-ecr.outputs.registry }} + IMAGE_TAG: ${{ github.sha }} + run: | + docker build -f docker/Dockerfile.backend -t $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG . + docker push $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG + echo "image=$ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG" >> $GITHUB_OUTPUT + + - name: Render Updated ECS Task Definition + id: render-task-def + uses: aws-actions/amazon-ecs-render-task-definition@v1 + with: + task-definition-family: ${{ env.ECS_TASK_DEF }} + container-name: ${{ env.CONTAINER_NAME }} + image: ${{ steps.build-image.outputs.image }} + + - name: Deploy to Amazon ECS + uses: aws-actions/amazon-ecs-deploy-task-definition@v1 + with: + task-definition: ${{ steps.render-task-def.outputs.task-definition }} + service: ${{ env.ECS_SERVICE }} + cluster: ${{ env.ECS_CLUSTER }} + wait-for-service-stability: true diff --git a/README.md b/README.md index 6d80416..7d415b2 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,13 @@ -# KNOWFLOW +# KnowFlow +--- -![Gemini_Generated_Image_j5414ij5414ij541](https://github.com/user-attachments/assets/032a9bc6-b0fb-437d-93cd-a9f063220e03) +# Deployment +--- +diagram-export-7-20-2025-6_03_33-PM + +# AI Infra +--- +diagram-export-7-20-2025-6_57_03-PM KnowFlow is a powerful hybrid Retrieval-Augmented Generation (RAG) system that combines semantic search with knowledge graph capabilities for intelligent document processing and querying. diff --git a/docker/Dockerfile.backend b/docker/Dockerfile.backend index 3cda427..e4d8064 100644 --- a/docker/Dockerfile.backend +++ b/docker/Dockerfile.backend @@ -13,7 +13,7 @@ RUN sh /uv-installer.sh && rm /uv-installer.sh ENV PATH="/root/.local/bin/:$PATH" -COPY .env requirements.txt pyproject.toml uv.lock .python-version ./ +COPY requirements.txt pyproject.toml uv.lock .python-version ./ RUN uv sync --locked diff --git a/docker/Dockerfile.neo4j b/docker/Dockerfile.neo4j index bf72468..be2abd5 100644 --- a/docker/Dockerfile.neo4j +++ b/docker/Dockerfile.neo4j @@ -3,6 +3,7 @@ FROM neo4j:5.15.0 ENV NEO4J_AUTH=neo4j/Pstm!tr0ae#123 ENV NEO4J_PLUGINS=["apoc"] ENV NEO4J_dbms_security_procedures_unrestricted=apoc.* +ENV NEO4J_dbms_default__listen__address=0.0.0.0 ENV NEO4J_dbms_memory_heap_initial__size=512m ENV NEO4J_dbms_memory_heap_max__size=2G diff --git a/docs/FLOWCHART.md b/docs/FLOWCHART.md new file mode 100644 index 0000000..3f3dd00 --- /dev/null +++ b/docs/FLOWCHART.md @@ -0,0 +1,267 @@ +# KnowFlow Technical Architecture + +## Document Processing Pipeline + +### 1. Document Upload & Initial Processing + +- **Supported Formats**: PDF, DOCX, CSV, TXT +- **Document Loaders**: + - PDF: `PyMuPDFLoader` + - DOCX: `Docx2txtLoader` + - CSV: `CSVLoader` + - TXT: `TextLoader` + - Fallback: `UnstructuredFileLoader` + +### 2. Text Processing + +- **Text Splitter**: `RecursiveCharacterTextSplitter` + ```python + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=1000, + chunk_overlap=200, + separators=["\n\n", "\n", " ", ""], + keep_separator=True + ) + ``` + +### 3. Vector Generation + +- **Embedding Model**: Gemini Embedding Model +- **Vector Store**: PGVector (PostgreSQL) + ```python + vector_store = PGVector( + connection=DATABASE_URL, + embeddings=embeddings, + collection_name=VECTOR_COLLECTION_NAME + ) + ``` + +### 4. Knowledge Graph Construction + +- **Database**: Neo4j +- **Node Types**: + - Document: Main document node + - Section: Document sections + - Entity: Named entities + - Concept: Key ideas/terms + - Tag: Categories +- **Relationship Types**: + - CONTAINS: Hierarchy + - RELATED_TO: General connections + - MENTIONS: References + - HAS_TAG: Classifications + +## Query Processing Pipeline + +### 1. Query Analysis + +- **Query Decomposition**: + ```python + messages = [ + SystemMessage(content="Query decomposition prompt"), + HumanMessage(content=query) + ] + sub_questions = llm.invoke(messages) + ``` + +### 2. Retrieval Process + +- **Vector Search**: + ```python + docs_and_scores = vector_store.similarity_search_with_score_by_vector( + embedding=query_embedding, + k=TOP_K_RESULTS, + filter=filter_dict + ) + ``` +- **Graph Search**: + - Cypher query generation + - Pattern matching + - Context retrieval + +### 3. Retrieval Evaluation + +- **Quality Metrics**: + ```json + { + "chunk_scores": [ + {"chunk": "text", "relevance_score": 0-10, "reasoning": "explanation"} + ], + "missing_aspects": ["list of missing information"], + "redundant_information": ["list of redundancies"], + "suggested_improvements": { + "additional_info_needed": ["missing info"], + "alternative_search_terms": ["suggested terms"] + }, + "overall_quality_score": 0-10 + } + ``` + +### 4. Response Generation + +- **Context Assembly**: + + ```python + def _merge_results(vector_results, graph_results): + graph_texts = [] + for result in graph_results: + text = f"Type: {result.get('type')}\n" + text += f"Properties: {result.get('properties')}\n" + text += f"Relationships: {result.get('relationships')}" + graph_texts.append(text) + + all_texts = vector_results[:3] + graph_texts[:3] + return "\n\n".join(all_texts) + ``` + +- **LLM Prompt Structure**: + ```python + system_prompt = f""" + You are a helpful, reasoning assistant. Answer based on: + Context: {context} + Guidelines: + - Rephrase/summarize from context + - Use reasoning to clarify + - No fabrication + - Indicate if answer unavailable + """ + ``` + +## Storage Architecture + +### 1. Vector Store (PostgreSQL) + +- **Schema**: + ```sql + CREATE TABLE document_chunks ( + id SERIAL PRIMARY KEY, + document_id INTEGER, + chunk_index INTEGER, + content TEXT, + embedding_vector vector(768), + metadata JSONB + ); + ``` + +### 2. Graph Database (Neo4j) + +- **Node Properties**: + ```json + { + "id": "unique_id", + "type": "node_type", + "name": "node_name", + "content": "text_content", + "created_at": "timestamp" + } + ``` +- **Relationship Properties**: + ```json + { + "type": "relationship_type", + "context": "relationship_context", + "confidence": 0.0-1.0, + "created_at": "timestamp" + } + ``` + +### 3. Document Storage (S3) + +- **Structure**: + ``` + s3://bucket/ + └── users/ + └── {user_id}/ + └── documents/ + └── {doc_id}.{extension} + ``` + +## Mermaid Technical Flowchart + +```mermaid +flowchart TB + %% Document Processing Pipeline + subgraph Document_Processing["Document Processing Pipeline"] + direction TB + Upload["Document Upload
Supported: PDF, DOCX, CSV, TXT"] + + subgraph Processing["Document Processing"] + Loader["Document Loaders
PyMuPDF/Docx2txt/CSV/Text"] + Splitter["Text Splitter
RecursiveCharacterTextSplitter
chunk_size=1000, overlap=200"] + VectorGen["Vector Generation
Gemini Embedding Model"] + end + + subgraph Storage["Storage Layer"] + S3Store["S3 Storage
Raw Documents"] + PGVector["PGVector Store
- document_id
- chunk_index
- embedding_vector
- metadata"] + Neo4j["Neo4j Graph DB
Node Types: Document, Section,
Entity, Concept, Tag
Relations: CONTAINS, RELATED_TO,
MENTIONS, HAS_TAG"] + end + + Upload --> Loader + Loader --> Splitter + Splitter --> VectorGen + VectorGen --> PGVector + Upload --> S3Store + Splitter --> Neo4j + end + + %% Query Processing Pipeline + subgraph Query_Processing["Query Processing Pipeline"] + direction TB + Query["User Query"] + + subgraph Query_Analysis["Query Analysis"] + QDecomp["Query Decomposition
Complex Query → Sub-questions
Using Gemini Pro"] + QEmbed["Query Embedding
Gemini Embedding Model"] + end + + subgraph Retrieval["Retrieval Layer"] + VecSearch["Vector Search
Similarity Search with Scores
TOP_K=5"] + GraphSearch["Graph Search
Cypher Query Generation
Pattern Matching"] + end + + subgraph Evaluation["Retrieval Evaluation"] + RelCheck["Relevance Check
Score: 0-10"] + ImproveSuggestions["Improvement Suggestions
- Missing Aspects
- Alternative Terms"] + end + + Query --> QDecomp + QDecomp --> QEmbed + QEmbed --> VecSearch + Query --> GraphSearch + VecSearch --> RelCheck + GraphSearch --> RelCheck + RelCheck --> ImproveSuggestions + ImproveSuggestions -.-> QEmbed + end + + %% Response Generation + subgraph Response_Gen["Response Generation"] + direction TB + Context["Context Assembly
Vector + Graph Results"] + LLMPrompt["LLM Prompt Construction
System + Context + Query"] + Response["Response Generation
Gemini Pro"] + + Context --> LLMPrompt + LLMPrompt --> Response + end + + %% Data Connections + PGVector --> VecSearch + Neo4j --> GraphSearch + VecSearch --> Context + GraphSearch --> Context + + %% External Services + subgraph External["External Services"] + direction LR + Gemini["Google Gemini Pro
- Chat Completion
- Embeddings"] + AWS["AWS S3
Document Storage"] + end + + %% Service Connections + VectorGen --> Gemini + QEmbed --> Gemini + Response --> Gemini + S3Store --> AWS +``` diff --git a/ecs/task-definition.json b/ecs/task-definition.json new file mode 100644 index 0000000..c40bf2b --- /dev/null +++ b/ecs/task-definition.json @@ -0,0 +1,94 @@ +{ + "taskDefinition": { + "taskDefinitionArn": "arn:aws:ecs:ap-south-1:953685791553:task-definition/knowflow-backend-task:1", + "containerDefinitions": [ + { + "name": "knowflow-backend", + "image": "953685791553.dkr.ecr.ap-south-1.amazonaws.com/knowflow-backend:latest", + "cpu": 0, + "portMappings": [ + { + "containerPort": 8000, + "hostPort": 8000, + "protocol": "tcp", + "name": "knowflow-backend-8000-tcp", + "appProtocol": "http" + } + ], + "essential": true, + "secrets": [ + { + "name": "AWS_ACCESS_KEY_ID", + "valueFrom": "arn:aws:secretsmanager:ap-south-1:953685791553:secret:knowflow/app-secrets:AWS_ACCESS_KEY_ID::" + }, + { + "name": "AWS_SECRET_ACCESS_KEY", + "valueFrom": "arn:aws:secretsmanager:ap-south-1:953685791553:secret:knowflow/app-secrets:AWS_SECRET_ACCESS_KEY::" + } + ], + "environmentFiles": [], + "mountPoints": [], + "volumesFrom": [], + "ulimits": [], + "healthCheck": { + "command": [ + "CMD-SHELL", + "curl -f http://localhost:8000/health || exit 1" + ], + "interval": 10, + "timeout": 5, + "retries": 3 + }, + "systemControls": [] + } + ], + "family": "knowflow-backend-task", + "taskRoleArn": "arn:aws:iam::953685791553:role/ECSTaskRole", + "executionRoleArn": "arn:aws:iam::953685791553:role/ecsTaskExecutionRole", + "networkMode": "awsvpc", + "revision": 1, + "volumes": [], + "status": "ACTIVE", + "requiresAttributes": [ + { + "name": "com.amazonaws.ecs.capability.docker-remote-api.1.24" + }, + { + "name": "com.amazonaws.ecs.capability.ecr-auth" + }, + { + "name": "com.amazonaws.ecs.capability.task-iam-role" + }, + { + "name": "ecs.capability.container-health-check" + }, + { + "name": "ecs.capability.execution-role-ecr-pull" + }, + { + "name": "com.amazonaws.ecs.capability.docker-remote-api.1.18" + }, + { + "name": "ecs.capability.task-eni" + } + ], + "placementConstraints": [], + "compatibilities": [ + "EC2", + "FARGATE" + ], + "runtimePlatform": { + "cpuArchitecture": "X86_64", + "operatingSystemFamily": "LINUX" + }, + "requiresCompatibilities": [ + "FARGATE" + ], + "cpu": "2048", + "memory": "8192", + "registeredAt": "2025-07-20T08:14:39.290000+05:30", + "registeredBy": "arn:aws:iam::953685791553:root", + "enableFaultInjection": false + }, + "tags": [] +} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 7b97a7b..64fe16c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,4 +26,8 @@ dependencies = [ "constructs>=10.4.2", "psycopg2-binary>=2.9.10", "pymupdf>=1.26.3", + "pytest>=8.4.1", ] + +[dependency-groups] +dev = [] diff --git a/src/core/auth.py b/src/core/auth.py index 73083d1..d144f00 100644 --- a/src/core/auth.py +++ b/src/core/auth.py @@ -1,11 +1,11 @@ from typing import Annotated -from fastapi import Depends, HTTPException, status +from fastapi import Depends from fastapi.security import OAuth2PasswordBearer from sqlalchemy.orm import Session from src.core.database import get_db -from src.services.auth_service import AuthService from src.models.database import User +from src.services.auth_service import AuthService oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth/token") diff --git a/src/core/config.py b/src/core/config.py index 5a59fad..275bcaa 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -1,8 +1,8 @@ +import json +import boto3 from functools import lru_cache from pathlib import Path from typing import Optional, List -import json -import boto3 from botocore.exceptions import ClientError from pydantic import Field from pydantic_settings import BaseSettings @@ -65,12 +65,14 @@ class Settings(BaseSettings): # AI Models GOOGLE_API_KEY: str = Field(default="") - GEMINI_EMBEDDING_MODEL: str = Field(default="", env="GEMINI_EMBEDDING_MODEL") - GEMINI_MODEL_NAME: str = Field(default="", env="GEMINI_MODEL_NAME") + GEMINI_EMBEDDING_MODEL: str = Field( + default="models/embedding-001", env="GEMINI_EMBEDDING_MODEL" + ) + GEMINI_MODEL_NAME: str = Field(default="gemini-2.0-flash", env="GEMINI_MODEL_NAME") # Vector Store VECTOR_COLLECTION_NAME: str = Field( - default="knowflow_vecotr_db", env="VECTOR_COLLECTION_NAME" + default="knowflow_vector_db", env="VECTOR_COLLECTION_NAME" ) CHUNK_SIZE: int = Field(default=700) CHUNK_OVERLAP: int = Field(default=80) diff --git a/src/core/database.py b/src/core/database.py index a22dc1c..62ec7d0 100644 --- a/src/core/database.py +++ b/src/core/database.py @@ -1,9 +1,7 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker, declarative_base -from neo4j import GraphDatabase from src.core.config import settings -from src.core.logging import logger Base = declarative_base() diff --git a/src/core/exceptions.py b/src/core/exceptions.py index d607ad9..0d1efcd 100644 --- a/src/core/exceptions.py +++ b/src/core/exceptions.py @@ -1,5 +1,5 @@ -from typing import Any, Dict, Optional from fastapi import status +from typing import Any, Dict, Optional class AppException(Exception): diff --git a/src/core/logging.py b/src/core/logging.py index 21c2e5d..93f2f54 100644 --- a/src/core/logging.py +++ b/src/core/logging.py @@ -1,6 +1,6 @@ +import sys import logging import logging.handlers -import sys from pathlib import Path from typing import Optional diff --git a/src/core/middleware.py b/src/core/middleware.py index 0b80f14..cb71bf4 100644 --- a/src/core/middleware.py +++ b/src/core/middleware.py @@ -1,13 +1,10 @@ -from datetime import datetime, timezone import time import uuid -from typing import Callable, Optional - +from datetime import datetime, timezone from fastapi import FastAPI, Request, Response from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint -from starlette.types import Message from src.core.config import settings from src.core.logging import logger diff --git a/src/main.py b/src/main.py index dd758f2..a2a2fce 100644 --- a/src/main.py +++ b/src/main.py @@ -2,15 +2,15 @@ from fastapi import FastAPI from src.core.config import settings +from src.core.middleware import setup_middleware +from src.core.database import init_db +from src.core.logging import logger from src.routes import ( auth_routes, chat_routes, document_routes, session_routes, ) -from src.core.middleware import setup_middleware -from src.core.database import init_db -from src.core.logging import logger @asynccontextmanager diff --git a/src/models/database.py b/src/models/database.py index fa37ec7..de1b2b6 100644 --- a/src/models/database.py +++ b/src/models/database.py @@ -1,9 +1,10 @@ +import enum +import sqlalchemy from datetime import datetime, timezone from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, JSON, Text, Enum from sqlalchemy.orm import relationship -import enum -import sqlalchemy from sqlalchemy.sql import func + from src.core.database import Base diff --git a/src/models/graph.py b/src/models/graph.py index f306c6d..cba5f4d 100644 --- a/src/models/graph.py +++ b/src/models/graph.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Any, Optional +from typing import List, Optional from pydantic import BaseModel, Field, field_validator diff --git a/src/models/request.py b/src/models/request.py index cd76d38..cacff9d 100644 --- a/src/models/request.py +++ b/src/models/request.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, EmailStr, Field, constr +from pydantic import BaseModel, EmailStr, Field from typing import Optional, Dict, Any, List from datetime import datetime diff --git a/src/models/response.py b/src/models/response.py index bbd2f67..1ee337c 100644 --- a/src/models/response.py +++ b/src/models/response.py @@ -1,6 +1,7 @@ from datetime import datetime from typing import List, Optional, Dict, Any from pydantic import BaseModel, Field + from src.models.database import DocumentStatus diff --git a/src/routes/auth_routes.py b/src/routes/auth_routes.py index 681ab22..18830b6 100644 --- a/src/routes/auth_routes.py +++ b/src/routes/auth_routes.py @@ -1,11 +1,10 @@ from datetime import timedelta from typing import Annotated from fastapi import APIRouter, Depends, HTTPException, status -from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm -from sqlalchemy.orm import Session +from fastapi.security import OAuth2PasswordRequestForm +from src.core.auth import get_current_user, get_auth_service from src.core.config import settings -from src.core.database import get_db from src.models.request import UserLogin, UserRegister from src.models.response import ( TokenResponse, @@ -13,9 +12,8 @@ RegisterResponse, MessageResponse, ) -from src.services.auth_service import AuthService from src.models.database import User -from src.core.auth import get_current_user, get_auth_service +from src.services.auth_service import AuthService router = APIRouter() diff --git a/src/routes/chat_routes.py b/src/routes/chat_routes.py index 3466b8c..632526f 100644 --- a/src/routes/chat_routes.py +++ b/src/routes/chat_routes.py @@ -1,6 +1,8 @@ from fastapi import APIRouter, Depends from sqlalchemy.orm import Session -from src.services.chat.chat_service import ChatService + +from src.core.auth import get_current_user +from src.core.database import get_db from src.core.exceptions import ExternalServiceException from src.core.logging import logger from src.models.request import ChatRequest, FollowUpChatRequest, RenameChatRequest @@ -10,9 +12,8 @@ RenameChatResponse, DeleteChatResponse, ) -from src.core.auth import get_current_user from src.models.database import User -from src.core.database import get_db +from src.services.chat.chat_service import ChatService from src.services.session_service import SessionService diff --git a/src/routes/document_routes.py b/src/routes/document_routes.py index d6730ad..b6e3d7f 100644 --- a/src/routes/document_routes.py +++ b/src/routes/document_routes.py @@ -2,13 +2,13 @@ from fastapi import APIRouter, Depends, UploadFile, BackgroundTasks, Query from src.core.auth import get_current_user -from src.services.document_service import DocumentService from src.models.database import User from src.models.request import DocumentIndexRequest from src.models.response import ( DocumentIndexResponse, MultiDocumentUploadResponse, ) +from src.services.document_service import DocumentService router = APIRouter() diff --git a/src/routes/session_routes.py b/src/routes/session_routes.py index 3e401f3..1db1dd1 100644 --- a/src/routes/session_routes.py +++ b/src/routes/session_routes.py @@ -1,17 +1,16 @@ -from typing import List, Optional +from typing import List from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session from src.core.database import get_db +from src.core.auth import get_current_user +from src.models.database import User from src.models.request import CreateSessionRequest, SendMessageRequest from src.models.response import ( ChatSessionResponse, ChatSessionListResponse, - MessageResponse, ) from src.services.session_service import SessionService -from src.core.auth import get_current_user -from src.models.database import User router = APIRouter() diff --git a/src/services/auth_service.py b/src/services/auth_service.py index c9d1957..1b5e547 100644 --- a/src/services/auth_service.py +++ b/src/services/auth_service.py @@ -7,7 +7,7 @@ from fastapi.security import OAuth2PasswordBearer from src.core.config import settings -from src.models.database import User, Document, DocumentChunk, DocumentStatus +from src.models.database import User pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") diff --git a/src/services/base_client.py b/src/services/base_client.py index eb5d187..fdbea09 100644 --- a/src/services/base_client.py +++ b/src/services/base_client.py @@ -1,4 +1,6 @@ -from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings +from langchain_postgres import PGVector + from src.core.config import settings from src.core.exceptions import ExternalServiceException from src.core.logging import logger @@ -13,6 +15,18 @@ def __init__(self, service_name: str): model=settings.GEMINI_MODEL_NAME, convert_system_message_to_human=True, ) + + self.embeddings = GoogleGenerativeAIEmbeddings( + model=settings.GEMINI_EMBEDDING_MODEL, + google_api_key=settings.GOOGLE_API_KEY, + ) + + self.vector_store = PGVector( + connection=settings.DATABASE_URL, + embeddings=self.embeddings, + collection_name=settings.VECTOR_COLLECTION_NAME, + ) + logger.info(f"{service_name} initialized successfully") except Exception as e: logger.error( diff --git a/src/services/chat/chat_service.py b/src/services/chat/chat_service.py index 0a5b3cc..3295951 100644 --- a/src/services/chat/chat_service.py +++ b/src/services/chat/chat_service.py @@ -1,31 +1,23 @@ -import json from typing import List, Dict, Any, Optional from fastapi import HTTPException, status -from langchain_google_genai import GoogleGenerativeAIEmbeddings -from langchain_postgres import PGVector -from langchain_google_genai import ChatGoogleGenerativeAI from langchain.schema import HumanMessage, SystemMessage from sqlalchemy.orm import Session from datetime import datetime, timezone from neo4j import GraphDatabase +from src.core.database import get_db from src.core.config import settings from src.core.exceptions import ExternalServiceException from src.core.logging import logger +from src.models.request import FollowUpChatRequest +from src.models.response import FollowUpChatResponse +from src.models.database import ChatSession +from src.models.database import Message from src.services.graph_service import GraphService from src.services.auth_service import AuthService from src.services.base_client import BaseLLMClient from src.services.chat.query_decomposition import QueryDecompositionService from src.services.chat.retrieval_evaluation import RetrievalEvaluationService -from src.models.database import Document -from src.models.database import DocumentChunk -from src.models.database import DocumentStatus -from src.models.database import Message -from src.core.database import get_db -from src.models.request import FollowUpChatRequest -from src.models.response import FollowUpChatResponse -from src.models.database import ChatSession -from src.utils.utils import clean_llm_response class ChatService(BaseLLMClient): @@ -35,17 +27,6 @@ def __init__(self, db: Session = None): self.db = db or next(get_db()) self.auth_service = AuthService(self.db) - self.embeddings = GoogleGenerativeAIEmbeddings( - model=settings.GEMINI_EMBEDDING_MODEL, - google_api_key=settings.GOOGLE_API_KEY, - ) - - self.vector_store = PGVector( - connection=settings.DATABASE_URL, - embeddings=self.embeddings, - collection_name=settings.VECTOR_COLLECTION_NAME, - ) - self.driver = GraphDatabase.driver( settings.NEO4J_URI, auth=(settings.NEO4J_USER, settings.NEO4J_PASSWORD) ) diff --git a/src/services/document_service.py b/src/services/document_service.py index 98c6605..9fb66eb 100644 --- a/src/services/document_service.py +++ b/src/services/document_service.py @@ -1,3 +1,4 @@ +import tempfile import os import asyncio from typing import List, Optional, Dict, Any @@ -13,21 +14,19 @@ Docx2txtLoader, ) from langchain.text_splitter import RecursiveCharacterTextSplitter -import tempfile from src.core.database import get_db -from src.models.database import Document, DocumentStatus, DocumentChunk, User -from src.services.s3_service import S3Service from src.core.config import settings from src.core.exceptions import ExternalServiceException from src.core.logging import logger -from langchain_google_genai import GoogleGenerativeAIEmbeddings -from langchain_postgres import PGVector +from src.models.database import Document, DocumentStatus, DocumentChunk, User +from src.services.s3_service import S3Service from src.services.graph_service import GraphService +from src.services.base_client import BaseLLMClient from src.utils.utils import clean_whitespaes -class DocumentService: +class DocumentService(BaseLLMClient): SUPPORTED_MIMETYPES = { "application/pdf": ".pdf", "application/msword": ".doc", @@ -41,9 +40,7 @@ class DocumentService: def __init__( self, db: Session = next(get_db()), current_user: Optional[User] = None ): - self.db = db - self.storage_service = S3Service() - self.current_user = current_user + super().__init__("DocumentService") try: try: loop = asyncio.get_event_loop() @@ -51,16 +48,9 @@ def __init__( loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - self.embeddings = GoogleGenerativeAIEmbeddings( - model=settings.GEMINI_EMBEDDING_MODEL, - google_api_key=settings.GOOGLE_API_KEY, - ) - - self.vector_store = PGVector( - connection=settings.DATABASE_URL, - embeddings=self.embeddings, - collection_name=settings.VECTOR_COLLECTION_NAME, - ) + self.db = db + self.storage_service = S3Service() + self.current_user = current_user self.text_splitter = RecursiveCharacterTextSplitter( chunk_size=settings.CHUNK_SIZE, diff --git a/src/services/graph_service.py b/src/services/graph_service.py index d1ca178..9ca5d44 100644 --- a/src/services/graph_service.py +++ b/src/services/graph_service.py @@ -1,17 +1,17 @@ -from typing import List, Dict, Any +import re +import uuid import json +from datetime import datetime +from typing import List, Dict, Any from neo4j import GraphDatabase, Session from langchain.schema import HumanMessage, SystemMessage from src.core.config import settings -from src.models.graph import GraphKnowledge from src.core.exceptions import ExternalServiceException from src.core.logging import logger -from src.utils.utils import clean_llm_response +from src.models.graph import GraphKnowledge from src.services.base_client import BaseLLMClient -from datetime import datetime -import re -import uuid +from src.utils.utils import clean_llm_response class GraphService(BaseLLMClient): diff --git a/src/services/s3_service.py b/src/services/s3_service.py index 25b36a5..752a38b 100644 --- a/src/services/s3_service.py +++ b/src/services/s3_service.py @@ -1,7 +1,7 @@ import boto3 from botocore.exceptions import ClientError from fastapi import HTTPException, status -from typing import BinaryIO, Optional, Dict, Any, List +from typing import Optional, Dict, Any, List from concurrent.futures import ThreadPoolExecutor from src.core.config import settings diff --git a/src/services/session_service.py b/src/services/session_service.py index 0fda8bd..8de0f34 100644 --- a/src/services/session_service.py +++ b/src/services/session_service.py @@ -3,7 +3,7 @@ from sqlalchemy.orm import Session, joinedload import uuid -from src.models.database import ChatSession, Message, User +from src.models.database import ChatSession, Message class SessionService: diff --git a/uv.lock b/uv.lock index 8c69be0..119e55c 100644 --- a/uv.lock +++ b/uv.lock @@ -919,6 +919,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a4/ed/1f1afb2e9e7f38a545d628f864d562a5ae64fe6f7a10e28ffb9b185b4e89/importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec", size = 37461 }, ] +[[package]] +name = "iniconfig" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050 }, +] + [[package]] name = "jinja2" version = "3.1.6" @@ -1002,6 +1011,7 @@ dependencies = [ { name = "pydantic-settings" }, { name = "pymupdf" }, { name = "pypdf2" }, + { name = "pytest" }, { name = "python-jose" }, { name = "python-magic" }, { name = "uvicorn" }, @@ -1027,11 +1037,15 @@ requires-dist = [ { name = "pydantic-settings", specifier = ">=2.10.1" }, { name = "pymupdf", specifier = ">=1.26.3" }, { name = "pypdf2", specifier = ">=3.0.1" }, + { name = "pytest", specifier = ">=8.4.1" }, { name = "python-jose", specifier = ">=3.5.0" }, { name = "python-magic", specifier = ">=0.4.27" }, { name = "uvicorn", specifier = ">=0.35.0" }, ] +[package.metadata.requires-dev] +dev = [] + [[package]] name = "langchain" version = "0.3.26" @@ -1491,6 +1505,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/81/f457d6d361e04d061bef413749a6e1ab04d98cfeec6d8abcfe40184750f3/pgvector-0.3.6-py3-none-any.whl", hash = "sha256:f6c269b3c110ccb7496bac87202148ed18f34b390a0189c783e351062400a75a", size = 24880 }, ] +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538 }, +] + [[package]] name = "propcache" version = "0.3.2" @@ -1869,6 +1892,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8e/5e/c86a5643653825d3c913719e788e41386bee415c2b87b4f955432f2de6b2/pypdf2-3.0.1-py3-none-any.whl", hash = "sha256:d16e4205cfee272fbdc0568b68d82be796540b1537508cef59388f839c191928", size = 232572 }, ] +[[package]] +name = "pytest" +version = "8.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/08/ba/45911d754e8eba3d5a841a5ce61a65a685ff1798421ac054f85aa8747dfb/pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c", size = 1517714 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474 }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0"