From b191cb13d413c20af7e1ad8ee42e0b86ad1c0a6e Mon Sep 17 00:00:00 2001 From: imbeer Date: Thu, 29 May 2025 13:01:12 +0300 Subject: [PATCH 1/2] #131: refactor and move suggestion service to another directory --- backend/backend/urls.py | 3 +-- .../service.py} | 7 +++++ backend/suggestion/urls.py | 7 +++++ .../views.py} | 7 +---- .../taskbench/services/statistics_service.py | 17 ++++++------ backend/taskbench/tests/test_suggestion.py | 17 ++++++------ backend/taskbench/views/statistics_views.py | 26 +++++-------------- 7 files changed, 40 insertions(+), 44 deletions(-) rename backend/{taskbench/services/suggestion_service.py => suggestion/service.py} (94%) create mode 100644 backend/suggestion/urls.py rename backend/{taskbench/views/suggestion_views.py => suggestion/views.py} (91%) diff --git a/backend/backend/urls.py b/backend/backend/urls.py index af48171b..ca23e00f 100644 --- a/backend/backend/urls.py +++ b/backend/backend/urls.py @@ -6,7 +6,6 @@ SubtaskCreateView, SubtaskDetailView ) -from taskbench.views.suggestion_views import SuggestionView from taskbench.views.task_views import ( TaskListView, TaskDetailView, @@ -23,6 +22,7 @@ urlpatterns = [ path("", include("dashboard.urls")), path("", include("subscription.urls")), + path("", include("suggestion.urls")), path('tasks/', TaskListView.as_view(), name='task_list'), path('tasks//', TaskDetailView.as_view(), name='task_detail'), path('subtasks/', SubtaskCreateView.as_view(), name='subtask_create'), @@ -35,6 +35,5 @@ path('user/password/', ChangePasswordView.as_view(), name='change_password'), path('token/refresh/', TokenRefreshView.as_view(), name="token_refresh"), path('statistics/', StatisticsView.as_view(), name='statistics'), - path('ai/suggestions/', SuggestionView.as_view(), name="ai_suggestions"), ] diff --git a/backend/taskbench/services/suggestion_service.py b/backend/suggestion/service.py similarity index 94% rename from backend/taskbench/services/suggestion_service.py rename to backend/suggestion/service.py index e9c2b73a..beb41b7f 100644 --- a/backend/taskbench/services/suggestion_service.py +++ b/backend/suggestion/service.py @@ -15,6 +15,13 @@ logger = logging.getLogger(__name__) +SUBTASK_SYSTEM_PROMPT = """ +Ты — специализированный декомпозитор задач. +Твоя ЕДИНСТВЕННАЯ функция — анализировать пользовательскую задачу и разбивать её на элементарные подзадачи. +Каждая подзадача должна +""" + + @singleton class SuggestionService: diff --git a/backend/suggestion/urls.py b/backend/suggestion/urls.py new file mode 100644 index 00000000..e8e3f8f5 --- /dev/null +++ b/backend/suggestion/urls.py @@ -0,0 +1,7 @@ +from django.urls import path + +from suggestion.views import SuggestionView + +urlpatterns = [ + path('ai/suggestions/', SuggestionView.as_view(), name="ai_suggestions"), +] diff --git a/backend/taskbench/views/suggestion_views.py b/backend/suggestion/views.py similarity index 91% rename from backend/taskbench/views/suggestion_views.py rename to backend/suggestion/views.py index 2beac158..125d31ca 100644 --- a/backend/taskbench/views/suggestion_views.py +++ b/backend/suggestion/views.py @@ -1,19 +1,14 @@ import json from django.http import JsonResponse -from rest_framework import status from rest_framework.response import Response from rest_framework.views import APIView -from backend import settings -from backend.settings import DEBUG +from suggestion.service import SuggestionService from taskbench.models.models import Category from taskbench.serializers.task_serializers import TaskDPCtoFlatSerializer from taskbench.serializers.user_serializers import JwtSerializer from taskbench.services.user_service import get_token -# from taskbench.serializers.task_serializers import TaskSerializer -from taskbench.services.suggestion_service import SuggestionService - class SuggestionView(APIView): diff --git a/backend/taskbench/services/statistics_service.py b/backend/taskbench/services/statistics_service.py index 868121d6..c333c4fd 100644 --- a/backend/taskbench/services/statistics_service.py +++ b/backend/taskbench/services/statistics_service.py @@ -1,12 +1,15 @@ import logging from datetime import timedelta -from django.utils import timezone + from django.db.models import Count +from django.utils import timezone + from taskbench.models.models import Task -from taskbench.utils.exceptions import AuthenticationError +from taskbench.services.user_service import get_user logger = logging.getLogger(__name__) + def get_statistics(token): """ Возвращает статистику продуктивности для пользователя: @@ -14,12 +17,8 @@ def get_statistics(token): - max_done: максимальное количество задач за день в текущей неделе - weekly: массив из 7 значений (float 0.0-1.0) с понедельника по воскресенье """ - from taskbench.services.user_service import get_user - try: - user = get_user(token) - except AuthenticationError as e: - logger.error(f"Authentication failed: {str(e)}") - raise + + user = get_user(token) # Определяем начало текущей недели (понедельник) today = timezone.now().date() @@ -70,4 +69,4 @@ def get_statistics(token): 'done_today': done_today, 'max_done': max_done, 'weekly': weekly - } \ No newline at end of file + } diff --git a/backend/taskbench/tests/test_suggestion.py b/backend/taskbench/tests/test_suggestion.py index 9bdc50b1..236d6af9 100644 --- a/backend/taskbench/tests/test_suggestion.py +++ b/backend/taskbench/tests/test_suggestion.py @@ -1,12 +1,14 @@ -from datetime import datetime, timezone, UTC +from datetime import datetime, timezone + from django.test import SimpleTestCase, TestCase +from django.urls import reverse +from rest_framework import status +from rest_framework.test import APIClient from rest_framework_simplejwt.tokens import RefreshToken +from suggestion.service import SuggestionService from taskbench.models.models import User, Category -from ..services.suggestion_service import SuggestionService -from rest_framework.test import APIClient -from django.urls import reverse -from rest_framework import status + class SuggestionServiceTestCase(SimpleTestCase): def __init__(self, method_name: str = "runTest"): @@ -19,7 +21,7 @@ def setUp(self): def test_deadline_suggestion(self): text = 'Не забыть, что завтра в 3 часа дня созвон' now_time = datetime(2025, 4, 24, 12, 00, 0).replace(tzinfo=None) - supposed_time = datetime(2025,4,25,15,00,0).replace(tzinfo=None) + supposed_time = datetime(2025, 4, 25, 15, 00, 0).replace(tzinfo=None) result = SuggestionService().suggest_deadline(text, now=now_time) print(result) self.assertEqual(result, supposed_time) @@ -102,7 +104,7 @@ def test_suggestion_api(self): "priority": None, "deadline": None, }, - "title": "Не забыть, что завтра в 3 часа дня созвон", + "title": "Не забыть, что завтра в 3 часа дня созвон", "timestamp": now_time, }, HTTP_AUTHORIZATION=f'Bearer {token}', @@ -118,4 +120,3 @@ def test_suggestion_api(self): # self.assertIsNotNone(category_id) self.assertIsNotNone(deadline) self.assertTrue(0 <= priority <= 1) - diff --git a/backend/taskbench/views/statistics_views.py b/backend/taskbench/views/statistics_views.py index 2bce8bdf..6783e8a5 100644 --- a/backend/taskbench/views/statistics_views.py +++ b/backend/taskbench/views/statistics_views.py @@ -1,23 +1,11 @@ -from django.utils import timezone -from datetime import datetime, timedelta -from django.db.models import Count -from rest_framework import status -from rest_framework.views import APIView -from rest_framework.response import Response - -from ..models.models import Task -from ..serializers.statistics_serializers import StatisticsSerializer -from ..serializers.user_serializers import JwtSerializer -from ..services.suggestion_service import logger -from ..services.user_service import get_token - - from django.http import JsonResponse -from rest_framework.views import APIView from rest_framework.exceptions import ValidationError +from rest_framework.views import APIView + +from taskbench.serializers.statistics_serializers import statistics_response from taskbench.services.statistics_service import get_statistics from taskbench.services.user_service import get_token, AuthenticationError -from taskbench.serializers.statistics_serializers import statistics_response + class StatisticsView(APIView): """ @@ -33,11 +21,11 @@ def get(self, request, *args, **kwargs): statistics = get_statistics(token) return statistics_response(statistics) except AuthenticationError as e: - logger.error(f"Authentication error: {str(e)}") + # logger.error(f"Authentication error: {str(e)}") return JsonResponse({'error': str(e)}, status=401) except ValidationError as e: - logger.error(f"Validation error: {str(e)}") + # logger.error(f"Validation error: {str(e)}") return JsonResponse({'error': str(e)}, status=400) except Exception as e: - logger.error(f"Unexpected error in StatisticsView: {str(e)}", exc_info=True) + # logger.error(f"Unexpected error in StatisticsView: {str(e)}", exc_info=True) return JsonResponse({'error': str(e)}, status=500) \ No newline at end of file From f84e87905b5624facfd3a71dda1567f9506d34b0 Mon Sep 17 00:00:00 2001 From: imbeer Date: Thu, 29 May 2025 23:48:50 +0300 Subject: [PATCH 2/2] #131: improve system prompts and add subscription check --- backend/suggestion/service.py | 173 +++++++++++++++++++-- backend/suggestion/views.py | 46 +----- backend/taskbench/tests/test_suggestion.py | 91 ++++++++--- 3 files changed, 230 insertions(+), 80 deletions(-) diff --git a/backend/suggestion/service.py b/backend/suggestion/service.py index beb41b7f..14ce000b 100644 --- a/backend/suggestion/service.py +++ b/backend/suggestion/service.py @@ -8,7 +8,13 @@ import dateparser.search from gigachat import GigaChat +from gigachat.models import Chat, Messages, MessagesRole +from pydantic import ValidationError +from subscription.service import is_user_subscribed +from taskbench.models.models import Category +from taskbench.serializers.task_serializers import TaskDPCtoFlatSerializer +from taskbench.services.user_service import get_user from taskbench.utils.decorators import singleton GIGACHAT_API_SAFETY_GAP = 60 @@ -16,16 +22,102 @@ logger = logging.getLogger(__name__) SUBTASK_SYSTEM_PROMPT = """ -Ты — специализированный декомпозитор задач. -Твоя ЕДИНСТВЕННАЯ функция — анализировать пользовательскую задачу и разбивать её на элементарные подзадачи. -Каждая подзадача должна +Разбей введенную пользователем задачу на несколько более мелких подзадач, состоящие не более чем из четырех слов. +Каждая подзадача должна быть короткой и представлена на отдельной строке. +Не используй знаки препинания или обозначения списка, просто пиши только подзадачи с новой строки. """ +SUBTASK_SYSTEM_PROMPT_V2 = """ +Предложи несколько мелких подзадач к введенной пользователем задаче, состоящих не более чем из четырех слов. +Каждая подзадача должна быть короткой и представлена на отдельной строке. +Не пиши заголовок. +Не пиши нумерацию подзадачи, пиши ТОЛЬКО текст подзадач с новой строки. +""" + +TIME_SYSTEM_PROMPT = """ +Предложи предположительную дату и время, соответствующие сроку введенной пользователем задачи. +Если время указано относительно, например словами 'завтра' или 'в следующую среду', в качестве точки отсчета используй текущее время. +В приоритете всегда предполагай время из будущего. +Если и время и дата не указаны, отправь ТОЛЬКО ОДИН СИМВОЛ: "-". +Если известно только время, считай датой сегодня. +Если известна только дата, считай время таким же как сейчас. +Отправь ТОЛЬКО дату и время в формате YYYY:MM:DD hh:mm. Не добавляй никакого другого текста или пояснений. +Например: 2024:03:15 10:30 +""" + +CATEGORY_SYSTEM_PROMPT = """ +Соотнеси пользовательский текст с одной из категорий, соответствующих следующему списку. Напиши только одно слово - название категории. Список категорий:\n +""" + + +def get_subtask_prompt(): + return SUBTASK_SYSTEM_PROMPT_V2 + + +def get_time_system_prompt(user_datetime): + return TIME_SYSTEM_PROMPT + "\nТекущее время (точка отсчета): " + user_datetime.isoformat(timespec='minutes') + + +def get_category_system_prompt(category_names: list): + return CATEGORY_SYSTEM_PROMPT + ', '.join(category_names) + + +def suggest(token, data): + """ + + :param token: + :param data: + :return: subtasks, category names (list), category_id, deadline + """ + + user = get_user(token) + serializer = TaskDPCtoFlatSerializer(data=data) + if not serializer.is_valid(): + return ValidationError(serializer.errors) + input_data = serializer.validated_data + deadline = input_data.get('deadline') + title = input_data.get('title') + category_id = input_data.get('category_id') + timestamp = input_data.get('timestamp') + + service = SuggestionService(debug=False) + + subscribed = is_user_subscribed(user) + # subscribed = True + + if deadline is None: + deadline = service.suggest_deadline_local( + title, now=timestamp) if not subscribed else service.suggest_deadline( + title, now=timestamp) + + """ + Проверка пользователя на подписку. + """ + if not subscribed: + return None, None, None, deadline + + if category_id is None: + categories = Category.objects.filter(user=user) + category_names = [c.name for c in categories] + category_index = service.suggest_category(title, category_names) + category_name = '' + if category_index < 0 or category_index >= len(categories): + category_id = None + else: + category_id = categories[category_index].category_id + category_name = categories[category_index].name + else: + category_name = Category.objects.get(category_id=category_id).name + + subtasks = service.suggest_subtasks(title) + + return subtasks, category_name, category_id, deadline + @singleton class SuggestionService: - def __init__(self, debug:bool=False): + def __init__(self, debug: bool = False): self.giga = GigaChat( credentials=os.getenv('GIGACHAT_AUTH_KEY'), verify_ssl_certs=False @@ -55,6 +147,27 @@ def update_token(self): self.access_token = response.access_token self.expires_at = datetime.fromtimestamp(response.expires_at / 1000, tz=timezone.utc) + def send_message_with_system_prompt(self, system_prompt: str, user_text: str): + if self.debug: return None + + self.update_token() + result = self.giga.chat( + Chat( + messages=[ + Messages( + role=MessagesRole.SYSTEM, + content=system_prompt + ), + Messages( + role=MessagesRole.USER, + content=user_text + ) + ] + ) + ) + print(result.choices[0].message.content) + return result + def suggest_subtasks(self, text: str) -> list: """ Предлагает подзадачи. @@ -63,13 +176,12 @@ def suggest_subtasks(self, text: str) -> list: if self.debug: return ["1. Начать делать задачу", "2. Продолжить делать задачу", "3. Закончить делать задачу"] self.update_token() - payload = 'Разбей данную задачу на список максимально коротких подзадач. Каждый элемент начинай с новой строки без иных символов и нумерации. ' + text - result = self.giga.chat(payload) + result = self.send_message_with_system_prompt(get_subtask_prompt(), text) subtasks = [ match.group(1).strip().lower() for line in result.choices[0].message.content.split('\n') - if (match := re.match(r'^(?:\d+\.\s*|-\s*)?([^.]+)(?:\.?)$', line.strip()))] + if (match := re.match(r'^(?:\d+\.\s*|-\s*)?([^.]+)\.?$', line.strip()))] return subtasks @@ -86,10 +198,10 @@ def suggest_category(self, text: str, category_names: list) -> int | None: if self.debug: return 0 - # names = [c.name for c in categories] self.update_token() - payload = "Выбери из списка категорию, которая больше всего подходит тексту. Напиши только выбранное. Список:" +', '.join(category_names) + " Текст:" + text - result = self.giga.chat(payload).choices[0].message.content + result = self.send_message_with_system_prompt( + get_category_system_prompt(category_names), + text).choices[0].message.content for i in range(len(category_names)): if self._equal_ignore_space_case(category_names[i], result): @@ -115,7 +227,42 @@ def suggest_priority(self, text: str) -> int: def suggest_deadline(self, text: str, *, now: datetime | None = None) -> datetime | None: """ - Анализирует текст с естественным языком и ищет даты. + Анализирует текст с использованием gigachat и ищет даты. + :param text: анализируемый текст + :param now: время, которое считается за текущее. + """ + + if self.debug: + return self.suggest_deadline_local(text, now=now) + + now = now or datetime.now().replace(tzinfo=None) + + result = self.send_message_with_system_prompt( + get_time_system_prompt(now), + text).choices[0].message.content + + cleaned_text = result.strip() + if cleaned_text == "-": + return self.suggest_deadline_local(text, now=now) + + expected_format = "%Y:%m:%d %H:%M" + + try: + dt_object = datetime.strptime(cleaned_text, expected_format) + print(dt_object.isoformat()) + return dt_object + except ValueError: + local_suggest = self.suggest_deadline_local(cleaned_text, now=now) + if local_suggest is not None: + print(local_suggest.isoformat()) + return local_suggest + local_suggest = self.suggest_deadline_local(text, now=now) + print(local_suggest.isoformat()) + return local_suggest + + def suggest_deadline_local(self, text: str, *, now: datetime | None = None) -> datetime | None: + """ + Анализирует текст локально с естественным языком и ищет даты. Выбирает либо последнюю из прошедших дат, либо ближайшую из будущих. :param text: анализируемый текст :param now: время, которое считается за текущее. @@ -135,7 +282,6 @@ def suggest_deadline(self, text: str, *, now: datetime | None = None) -> datetim if not found: return None - # found -> список кортежей (фрагмент, datetime) datetimes = [dt for _, dt in found] future = [d.replace(tzinfo=None) for d in datetimes if d.replace(tzinfo=None) > now.replace(tzinfo=None)] @@ -152,7 +298,6 @@ def _equal_ignore_space_case(a: Union[str, bytes], b: Union[str, bytes]) -> bool if isinstance(b, bytes): b = b.decode() - # убираем всё, что считается пробельным в Unicode (\s = [ \t\n\r\f\v] + другие) normalize = lambda s: re.sub(r'\s+', '', s).casefold() - return normalize(a) == normalize(b) \ No newline at end of file + return normalize(a) == normalize(b) diff --git a/backend/suggestion/views.py b/backend/suggestion/views.py index 125d31ca..30cd637f 100644 --- a/backend/suggestion/views.py +++ b/backend/suggestion/views.py @@ -1,55 +1,19 @@ import json from django.http import JsonResponse -from rest_framework.response import Response from rest_framework.views import APIView -from suggestion.service import SuggestionService -from taskbench.models.models import Category -from taskbench.serializers.task_serializers import TaskDPCtoFlatSerializer -from taskbench.serializers.user_serializers import JwtSerializer +from suggestion.service import suggest from taskbench.services.user_service import get_token class SuggestionView(APIView): def post(self, request): - data = json.loads(request.body) - serializer = TaskDPCtoFlatSerializer(data=data) - if not serializer.is_valid(): - return JsonResponse(serializer.errors, status=400) - user_serializer = JwtSerializer(data=get_token(request)) - if not user_serializer.is_valid(): - return Response("Invalid token", status=401) - user_id = user_serializer.validated_data['user'].user_id - - input_data = serializer.validated_data - deadline = input_data.get('deadline') - title = input_data.get('title') - priority = input_data.get('priority') - category_id = input_data.get('category_id') - timestamp = input_data.get('timestamp') - service = SuggestionService(debug=False) - - if deadline is None: - deadline = service.suggest_deadline(title, now=timestamp) - # priority = service.suggest_priority(title) - - if category_id is None: - categories = Category.objects.filter(user_id = user_id) - category_names = [c.name for c in categories] - category_index = service.suggest_category(title, category_names) - category_name = '' - if category_index < 0 or category_index >= len(categories): - category_id = None - else: - category_id = categories[category_index].category_id - category_name = categories[category_index].name - else: - category_name = Category.objects.get(category_id=category_id).name + token = get_token(request) - subtasks = service.suggest_subtasks(title) + subtasks, category_name, category_id, deadline = suggest(token, data) return JsonResponse({ "suggested_dpc": { @@ -58,5 +22,5 @@ def post(self, request): "category_id": category_id if category_name is not None else '', "category_name": category_name if category_name is not None else '', }, - "suggestions": subtasks - }) \ No newline at end of file + "suggestions": subtasks if subtasks is not None else [], + }) diff --git a/backend/taskbench/tests/test_suggestion.py b/backend/taskbench/tests/test_suggestion.py index 236d6af9..926a70ca 100644 --- a/backend/taskbench/tests/test_suggestion.py +++ b/backend/taskbench/tests/test_suggestion.py @@ -7,7 +7,7 @@ from rest_framework_simplejwt.tokens import RefreshToken from suggestion.service import SuggestionService -from taskbench.models.models import User, Category +from taskbench.models.models import User, Category, Subscription class SuggestionServiceTestCase(SimpleTestCase): @@ -22,7 +22,7 @@ def test_deadline_suggestion(self): text = 'Не забыть, что завтра в 3 часа дня созвон' now_time = datetime(2025, 4, 24, 12, 00, 0).replace(tzinfo=None) supposed_time = datetime(2025, 4, 25, 15, 00, 0).replace(tzinfo=None) - result = SuggestionService().suggest_deadline(text, now=now_time) + result = SuggestionService().suggest_deadline_local(text, now=now_time) print(result) self.assertEqual(result, supposed_time) @@ -61,11 +61,24 @@ def setUp(self): user_id=1002, email='testuser1002@mail.com' ) + self.user.set_password('test_password') + self.user.save() + + self.user2 = User.objects.create( + user_id=145, + email='yet_another_user@mail.com' + ) + self.user2.set_password('test_password') + self.user2.save() + + self.sub = Subscription.objects.create( + user=self.user, + ) + self.sub.activate(0) self.access_token = RefreshToken.for_user(self.user).access_token + self.access_token2 = RefreshToken.for_user(self.user2).access_token - self.user.set_password('test_password') - self.user.save() category1 = Category.objects.create( category_id=1001, @@ -89,34 +102,62 @@ def setUp(self): def test_suggestion_api(self): url = reverse('login') - response = self.client.post(url, - data={ - "email": "testuser1002@mail.com", - "password": "test_password" - }, format='json') + response = self.client.post( + url, + data={ + "email": "testuser1002@mail.com", + "password": "test_password" + }, format='json') token = response.json().get('access') now_time = datetime(2025, 4, 24, 12, 00, 0, tzinfo=timezone.utc) url = reverse('ai_suggestions') - response = self.client.post(url, - data={ - "dpc": { - "category_id": None, - "priority": None, - "deadline": None, - }, - "title": "Не забыть, что завтра в 3 часа дня созвон", - "timestamp": now_time, - }, - HTTP_AUTHORIZATION=f'Bearer {token}', - format='json') + response = self.client.post( + url, + data={ + "dpc": { + "category_id": None, + "priority": None, + "deadline": None, + }, + "title": "Не забыть, что завтра в 3 часа дня созвон", + "timestamp": now_time, + }, + HTTP_AUTHORIZATION=f'Bearer {token}', + format='json') self.assertEqual(response.status_code, status.HTTP_200_OK) subtasks = response.json().get('suggestions') self.assertTrue(len(subtasks) > 0) deadline = response.json().get('suggested_dpc').get('deadline') priority = int(response.json().get('suggested_dpc').get('priority')) - # category_id = int(response.json().get('suggested_dpc').get('category_id')) может быть null, если не подошла ни одна категория. - # category = response.json().get('suggested_dpc').get('category_name') - # self.assertIsNotNone(category) - # self.assertIsNotNone(category_id) self.assertIsNotNone(deadline) self.assertTrue(0 <= priority <= 1) + + def test_suggestion_unsubscribed_api(self): + url = reverse('login') + response = self.client.post( + url, + data={ + "email": "yet_another_user@mail.com", + "password": "test_password" + }, format='json') + token = response.json().get('access') + now_time = datetime(2025, 4, 24, 12, 00, 0, tzinfo=timezone.utc) + url = reverse('ai_suggestions') + response = self.client.post( + url, + data={ + "dpc": { + "category_id": None, + "priority": None, + "deadline": None, + }, + "title": "Не забыть, что завтра в 3 часа дня созвон", + "timestamp": now_time, + }, + HTTP_AUTHORIZATION=f'Bearer {token}', + format='json') + self.assertEqual(response.status_code, status.HTTP_200_OK) + subtasks = response.json().get('suggestions') + self.assertTrue(len(subtasks) == 0) + deadline = response.json().get('suggested_dpc').get('deadline') + self.assertIsNotNone(deadline) \ No newline at end of file