Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ TEST_DATABASE_URL=postgresql+asyncpg://postgres:postgres@db_test:5432/test_db
SECRET_KEY=insecure_secret_key_for_ci_and_dev
LLM_SERVICE_URL=http://llm-service:8001/api/v1/ask
MLOPS_SERVICE_URL=http://mlops-service:8002/api/v1/trigger_dag
ENCRYPTION_KEY="AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="
ENCRYPTION_KEY="AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="
REDIS_URL=redis://redis:6379/0
CACHE_TTL=3600
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
run: cp .env.example .env || touch .env

- name: Start containers
run: docker compose up -d --build --wait backend db db_test
run: docker compose up -d --build --wait backend db db_test redis

- name: Run Tests
run: docker compose exec -T backend pytest -v
Expand Down
14 changes: 14 additions & 0 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,24 @@ services:
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy

volumes:
- ./:/app

redis:
image: redis:7-alpine
container_name: gitlab-redis
restart: always
ports:
- "6379:6379"
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 5s
timeout: 3s
retries: 5

volumes:
postgres_data:
driver: local
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies = [
"greenlet>=3.2.4",
"pytest-mock>=3.15.1",
"dishka>=1.7.2",
"redis>=7.1.0",
]

[tool.ruff]
Expand Down
2 changes: 2 additions & 0 deletions src/api/routers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ async def get_chat_history(
)
async def send(
chat_id: UUID4,
repo_ids: List[UUID4],
message: MessageCreate,
service: FromDishka[ChatService],
current_user: User = Depends(get_current_user)
Expand All @@ -80,5 +81,6 @@ async def send(
return await service.ask_question(
user_id=current_user.id,
chat_id=chat_id,
repo_ids=repo_ids,
question=message.content
)
2 changes: 1 addition & 1 deletion src/api/schemas/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class Repository(BaseModel):
id: UUID
name: str
path_with_namespace: str = Field(..., example="group/my-awesome-project")
web_url: HttpUrl
url: HttpUrl


class SyncRequest(BaseModel): # разделить на два разных типа - одиночный и multi?
Expand Down
76 changes: 53 additions & 23 deletions src/application/services/chat_service.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import uuid
from datetime import datetime
from typing import List, Optional

from fastapi import HTTPException, status
from pydantic import UUID4

from src.domain.models.chat import Chat, Message
from src.domain.models.chat import Chat, Message, MessageRole, Source
from src.domain.repositories.cache_repo import ICacheRepository
from src.domain.repositories.chat_repo import IChatRepository


Expand All @@ -13,9 +16,11 @@ class ChatService:

def __init__(
self,
chat_repo: IChatRepository
chat_repo: IChatRepository,
cache_repo: ICacheRepository
):
self.chat_repo = chat_repo
self.cache_repo = cache_repo

async def create_chat(self, owner_id: UUID4, title: str) -> Chat:
"""Create a new chat with specified title for user by their id."""
Expand All @@ -32,6 +37,20 @@ async def get_chat_history(self, user_id: UUID4, chat_id: UUID4) -> Optional[Cha
2. Return chat history by its id.
"""
chat = await self.chat_repo.get_chat_full(chat_id)
await self._validate_chat_access(
user_id=user_id,
chat_id=chat_id,
chat=chat
)

return chat

async def _validate_chat_access(
self,
user_id: UUID4,
chat_id: UUID4,
chat: Optional[Chat] = None,
) -> Optional[Chat]:
if not chat:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
Expand All @@ -44,12 +63,11 @@ async def get_chat_history(self, user_id: UUID4, chat_id: UUID4) -> Optional[Cha
detail=f"User {user_id} doesn't have access to the chat {chat_id}."
)

return chat

async def ask_question(
self,
user_id: UUID4,
chat_id: UUID4,
repo_ids: List[UUID4],
question: str
) -> Message:
"""QnA iteration.
Expand All @@ -60,16 +78,37 @@ async def ask_question(
4. Save RAG answer.
"""
chat = await self.chat_repo.get_chat_full(chat_id)
if not chat:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Chat {chat_id} not found"
await self._validate_chat_access(
user_id=user_id,
chat_id=chat_id,
chat=chat
)

cache_key = self.cache_repo.construct_cache_key(
query=question,
repository_ids=repo_ids
)

message = await self.cache_repo.get_cached_value(cache_key)

if message is None:
message = Message(
id=uuid.uuid4(),
role=MessageRole.ASSISTANT,
content="Will be a real llm call later :)",
created_at=datetime.now(),
sources=[
Source(
title="README.md",
url="https://gitlab/mock_project/readme/",
quote="Mock quote"
)
]
)

if chat.owner_id != user_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"User {user_id} doesn't have access to the chat {chat_id}."
await self.cache_repo.put_cache_value(
key=cache_key,
message=message
)

await self.chat_repo.add_message(
Expand All @@ -78,20 +117,11 @@ async def ask_question(
content=question
)

mock_answer = "Will be a real llm call later :)"
mock_sources = [
{
"title": "README.md",
"url": "http://gitlab/mock_project/readme",
"quote": "Mock quote"
}
]

assistant_message = await self.chat_repo.add_message(
chat_id=chat_id,
role="assistant",
content=mock_answer,
sources=mock_sources
content=message.content,
sources=[source.model_dump(mode="json") for source in message.sources]
)

return assistant_message
2 changes: 2 additions & 0 deletions src/core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ class Settings(BaseSettings):

DATABASE_URL: SecretStr
TEST_DATABASE_URL: SecretStr
REDIS_URL: SecretStr

SECRET_KEY: SecretStr
ENCRYPTION_KEY: SecretStr
MLOPS_SERVICE_URL: SecretStr
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
CACHE_TTL: int = 300 # sec

model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")

Expand Down
2 changes: 2 additions & 0 deletions src/domain/models/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class Source(BaseModel):

"""Data structure for source."""

model_config = ConfigDict(from_attributes=True)

title: str
url: HttpUrl
quote: str
Expand Down
25 changes: 25 additions & 0 deletions src/domain/repositories/cache_repo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod
from typing import List, Optional
from uuid import UUID

from src.domain.models.chat import Message


class ICacheRepository(ABC):

"""Class sets the contract by which Application-layer connects with Infrastructure-layer."""

@abstractmethod
def construct_cache_key(self, query: str, repository_ids: List[UUID]) -> str:
"""Construct cache key."""
raise NotImplementedError

@abstractmethod
async def get_cached_value(self, key: str) -> Optional[Message]:
"""Get value from cache by given key. Return None if key isn't in cache yet."""
raise NotImplementedError

@abstractmethod
async def put_cache_value(self, key: str, message: Message) -> None:
"""Put value to key with given key."""
raise NotImplementedError
Empty file.
46 changes: 46 additions & 0 deletions src/infrastructure/cache/repositories/redis_cache_repo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import json
from typing import List, Optional
from uuid import UUID

from redis.asyncio import Redis

from src.core.settings import settings
from src.domain.models.chat import Message
from src.domain.repositories.cache_repo import ICacheRepository


class RedisCacheRepository(ICacheRepository):

"""Cache's repository realisation for Redis."""

def __init__(self, redis_client: Redis):
self.redis = redis_client

def construct_cache_key(self, query: str, repository_ids: List[UUID]) -> str:
"""Construct cache key."""
repository_ids = [str(repository_id) for repository_id in repository_ids]
sorted_ids = sorted(repository_ids)

return f"{';'.join(sorted_ids)}:{query}"

async def get_cached_value(self, key: str) -> Optional[Message]:
"""Get value from cache by given key. Return None if key isn't in cache yet."""
value = await self.redis.get(key)

if not value:
return None

try:
return Message.model_validate_json(value)
except json.JSONDecodeError as error:
raise ValueError("Invalid JSON format from cache.") from error

async def put_cache_value(self, key: str, message: Message) -> None:
"""Put value to key with given key."""
json_data = message.model_dump_json()

await self.redis.set(
name=key,
value=json_data,
ex=settings.CACHE_TTL
)
41 changes: 39 additions & 2 deletions src/infrastructure/di/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from dishka import Provider, Scope, provide
from fastapi import HTTPException, status
from redis import ConnectionPool, Redis, asyncio as aioredis
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine

Expand All @@ -15,6 +16,8 @@
IRoleRepository,
IUserRepository,
)
from src.domain.repositories.cache_repo import ICacheRepository
from src.infrastructure.cache.repositories.redis_cache_repo import RedisCacheRepository
from src.infrastructure.db.repositories import (
SqlAlchemyChatRepository,
SqlAlchemyGitLabRepository,
Expand Down Expand Up @@ -114,9 +117,13 @@ def get_auth_service(self, user_repo: IUserRepository) -> AuthService:
return AuthService(user_repo=user_repo)

@provide
def get_chat_service(self, chat_repo: IChatRepository) -> ChatService:
def get_chat_service(
self,
chat_repo: IChatRepository,
cache_repo: ICacheRepository
) -> ChatService:
"""Get chat's service."""
return ChatService(chat_repo=chat_repo)
return ChatService(chat_repo=chat_repo, cache_repo=cache_repo)

@provide
def get_index_service(
Expand All @@ -135,3 +142,33 @@ def get_admin_service(
) -> AdminService:
"""Get admin service."""
return AdminService(user_repo=user_repo, role_repo=role_repo)


class CacheProvider(Provider):

"""Provider for cache."""

@provide(scope=Scope.APP)
def get_redis_pool(self) -> ConnectionPool:
"""Get Redis connection pool."""
return aioredis.ConnectionPool.from_url(
settings.REDIS_URL.get_secret_value(),
decode_responses=True
)

@provide(scope=Scope.APP)
async def get_redis_client(self, pool: ConnectionPool) -> AsyncIterable[Redis]:
"""Get Redis client."""
client = aioredis.Redis(connection_pool=pool)
try:
yield client
finally:
await client.close()

@provide(scope=Scope.REQUEST)
async def get_cache_repository(
self,
client: Redis
) -> ICacheRepository:
"""Get Redis cache repository."""
return RedisCacheRepository(client)
2 changes: 2 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from api.routers import admin, auth, chat, indexing, repository
from src.infrastructure.di.providers import (
CacheProvider,
InfrastructureProvider,
RepositoryProvider,
SericeProvider,
Expand All @@ -18,6 +19,7 @@
InfrastructureProvider(),
RepositoryProvider(),
SericeProvider(),
CacheProvider()
)

setup_dishka(container, app)
Expand Down
Loading
Loading