diff --git a/backend/community_manager/actions/chat.py b/backend/community_manager/actions/chat.py
index e808515..6da499c 100644
--- a/backend/community_manager/actions/chat.py
+++ b/backend/community_manager/actions/chat.py
@@ -1124,6 +1124,38 @@ def __init__(self, db_session: Session):
         self.authorization_action = AuthorizationAction(db_session)
         # self.bot_api_service = TelegramBotApiService()
 
+    async def check_chat_members_compliance(
+        self, chat_id: int, batch_size: int = 100
+    ) -> int:
+        """
+        Iterates over all members of a chat in batches and kicks ineligible members.
+
+        Members are fetched chunk by chunk via
+        ``telegram_chat_user_service.yield_all_for_chat`` rather than in one query.
+
+        :param chat_id: The ID of the chat to check.
+        :param batch_size: Maximum number of members fetched and checked per chunk.
+        :return: The total number of members processed.
+        """
+        logger.info(f"Starting to check chat members for chat {chat_id=!r}.")
+
+        total_processed = 0
+        for chat_members_chunk in self.telegram_chat_user_service.yield_all_for_chat(
+            chat_id=chat_id,
+            batch_size=batch_size,
+        ):
+            await self.kick_ineligible_chat_members(chat_members=chat_members_chunk)
+            total_processed += len(chat_members_chunk)
+            logger.info(
+                f"Processed chunk of {len(chat_members_chunk)} users for chat {chat_id=!r}. "
+                f"Total processed: {total_processed}"
+            )
+
+        logger.info(
+            f"Finished checking members for chat {chat_id=!r}. Total: {total_processed}"
+        )
+        return total_processed
+
     async def kick_chat_member(self, chat_member: TelegramChatUser) -> None:
         """
         Kicks a specified chat member from the chat. It ensures that the bot
diff --git a/backend/community_manager/tasks/chat.py b/backend/community_manager/tasks/chat.py
index 872143a..7d7390e 100644
--- a/backend/community_manager/tasks/chat.py
+++ b/backend/community_manager/tasks/chat.py
@@ -44,11 +44,7 @@ async def check_target_chat_members(chat_id: int) -> None:
     with DBService().db_session() as db_session:
         # BotAPI does not need a telethon client
         action = CommunityManagerUserChatAction(db_session)
-        chat_members = action.telegram_chat_user_service.get_all(
-            chat_ids=[chat_id], with_wallet_details=True
-        )
-        logger.info(f"Found {len(chat_members)} chat members for chat {chat_id=!r}.")
-        await action.kick_ineligible_chat_members(chat_members=chat_members)
+        await action.check_chat_members_compliance(chat_id=chat_id)
 
 
 @app.task(
diff --git a/backend/core/src/core/services/chat/user.py b/backend/core/src/core/services/chat/user.py
index bcf205c..3a27dd3 100644
--- a/backend/core/src/core/services/chat/user.py
+++ b/backend/core/src/core/services/chat/user.py
@@ -170,6 +170,50 @@ def get_all(
 
         return query.all()
 
+    def yield_all_for_chat(
+        self, chat_id: int, batch_size: int = 100
+    ) -> Iterable[list[TelegramChatUser]]:
+        """
+        Yields all users for a given chat in batches, using keyset pagination.
+        This is useful for processing large chats without loading all users into memory.
+
+        Batches are ordered by ascending ``user_id`` and every query resumes after the
+        last ``user_id`` already yielded.
+        NOTE(review): the cursor starts at 0, so rows with ``user_id <= 0`` would be
+        skipped - presumably user IDs are always positive; confirm.
+
+        :param chat_id: The ID of the chat whose users are fetched.
+        :param batch_size: Maximum number of users per yielded batch.
+        :return: An iterable of non-empty user batches, each user loaded together
+            with its wallet link and that link's wallet.
+        """
+        last_seen_user_id = 0
+        while True:
+            stmt = (
+                select(TelegramChatUser)
+                .where(
+                    TelegramChatUser.chat_id == chat_id,
+                    TelegramChatUser.user_id > last_seen_user_id,
+                )
+                .order_by(TelegramChatUser.user_id.asc())
+                .limit(batch_size)
+                .options(
+                    joinedload(TelegramChatUser.wallet_link).options(
+                        joinedload(TelegramChatUserWallet.wallet),
+                    )
+                )
+            )
+            users = self.db_session.execute(stmt).scalars().unique().all()
+
+            if not users:
+                break
+
+            # Advance the cursor before yielding: the consumer may delete or expire
+            # these rows (e.g. by committing) while it handles the batch, and reading
+            # ``user_id`` from such an instance afterwards could fail or hit the DB.
+            last_seen_user_id = users[-1].user_id
+            yield users
+
     def get_all_by_linked_wallet(self, addresses: list[str]) -> list[TelegramChatUser]:
         query = self.db_session.query(TelegramChatUser)
         query = query.join(
diff --git a/backend/tests/unit/core/services/chat/test_user.py b/backend/tests/unit/core/services/chat/test_user.py
new file mode 100644
index 0000000..e17916d
--- /dev/null
+++ b/backend/tests/unit/core/services/chat/test_user.py
@@ -0,0 +1,40 @@
+import pytest
+from sqlalchemy.orm import Session
+
+from core.models.chat import TelegramChatUser
+from core.services.chat.user import TelegramChatUserService
+from tests.factories import TelegramChatFactory, TelegramChatUserFactory, UserFactory
+
+
+@pytest.mark.asyncio
+async def test_yield_all_for_chat_batching(db_session: Session) -> None:
+    # Setup
+    chat = TelegramChatFactory.with_session(db_session).create()
+    service = TelegramChatUserService(db_session)
+
+    # Create 25 users
+    users = []
+    for i in range(25):
+        user = UserFactory.with_session(db_session).create(telegram_id=1000 + i)
+        chat_user = TelegramChatUserFactory.with_session(db_session).create(
+            chat=chat, user=user, is_admin=False, is_managed=True
+        )
+        users.append(chat_user)
+
+    # Test
+    batches: list[list[TelegramChatUser]] = []
+    for batch in service.yield_all_for_chat(chat.id, batch_size=10):
+        batches.append(batch)
+
+    # Verify
+    assert len(batches) == 3
+    assert len(batches[0]) == 10
+    assert len(batches[1]) == 10
+    assert len(batches[2]) == 5
+
+    all_yielded_users = [u for batch in batches for u in batch]
+    assert len(all_yielded_users) == 25
+
+    # Verify order
+    user_ids = [u.user_id for u in all_yielded_users]
+    assert user_ids == sorted(user_ids)