diff --git a/at_client.py b/at_client.py index 8c0a5f0..35545e9 100644 --- a/at_client.py +++ b/at_client.py @@ -1,6 +1,7 @@ -from atproto import Client, IdResolver, client_utils +from atproto import Client, IdResolver from atproto_client.models.app.bsky.feed.post import ReplyRef from atproto_client.models.com.atproto.repo.strong_ref import Main +from atproto_client.utils import TextBuilder from config import ACCOUNT_HANDLE, ACCOUNT_PASSWORD from exceptions import DidResolveException, HandleResolveException @@ -13,14 +14,15 @@ def __init__(self) -> None: self.account_did = self.client.me.did self.id_resolver = IdResolver() - def post_reply(self, post: client_utils.TextBuilder, parent_cid: str, parent_uri: str) -> None: - parent = Main(cid=parent_cid, uri=parent_uri) - reply_to = ReplyRef(parent=parent, root=parent) + def post_reply(self, post: TextBuilder, cid: str, parent_uri: str, root_uri: str) -> None: + parent = Main(cid=cid, uri=parent_uri) + root = Main(cid=cid, uri=root_uri) + reply_to = ReplyRef(parent=parent, root=root) self.client.send_post(post, reply_to=reply_to) @staticmethod - def build_mention_post(handle: str, did: str, text: str) -> client_utils.TextBuilder: - post = client_utils.TextBuilder() + def build_mention_post(handle: str, did: str, text: str) -> TextBuilder: + post = TextBuilder() post.mention(f"@{handle}", did) post.text(text) return post diff --git a/error_handler.py b/error_handler.py index 99653d7..e9172d8 100644 --- a/error_handler.py +++ b/error_handler.py @@ -1,7 +1,5 @@ from logging import warning -from atproto_core.cid import CIDType as CID - from at_client import AtClient @@ -12,14 +10,16 @@ class ErrorHandler: def __init__(self, at_client: AtClient): self.at_client = at_client - def handle_no_run_at(self, did: str, parent_cid: CID, parent_uri: str) -> None: + def handle_no_run_at(self, did: str, parent_cid: str, parent_uri: str, root_uri: str) -> None: warning(f"No run at was parsed for post at URI: {parent_uri}") handle = self.at_client.resolve_handle(did) post = AtClient.build_mention_post(handle, did, self.NO_RUN_MSG) - self.at_client.post_reply(post, str(parent_cid), parent_uri) + self.at_client.post_reply(post, parent_cid, parent_uri, root_uri) - def handle_run_at_in_past(self, did: str, parent_cid: CID, parent_uri: str) -> None: + def handle_run_at_in_past( + self, did: str, parent_cid: str, parent_uri: str, root_uri: str + ) -> None: warning(f"Run at was parsed to be in the past for post at URI: {parent_uri}") handle = self.at_client.resolve_handle(did) post = AtClient.build_mention_post(handle, did, self.PAST_MSG) - self.at_client.post_reply(post, str(parent_cid), parent_uri) + self.at_client.post_reply(post, parent_cid, parent_uri, root_uri) diff --git a/mention_listener.py b/mention_listener.py index 64c72bf..7f72172 100644 --- a/mention_listener.py +++ b/mention_listener.py @@ -5,9 +5,9 @@ from typing import Any from atproto import CAR, AtUri, models +from atproto_client.models.app.bsky.feed.post import ReplyRef from atproto_client.models.app.bsky.richtext.facet import Mention from atproto_client.models.com.atproto.sync.subscribe_repos import Commit -from atproto_core.cid import CIDType as CID from atproto_firehose import FirehoseSubscribeReposClient, parse_subscribe_repos_message from atproto_firehose.models import MessageFrame @@ -24,7 +24,7 @@ def __init__(self, at_client: AtClient, error_handler: ErrorHandler): self.at_client = at_client self.error_handler = error_handler - def parse_create_op(self, commit: Commit) -> tuple[str, AtUri, CID | None] | None: + def parse_create_op(self, commit: Commit) -> tuple[str, AtUri, str, ReplyRef] | None: car = CAR.from_bytes(commit.blocks) for op in commit.ops: if op.action != "create" or not op.cid: @@ -36,21 +36,31 @@ def parse_create_op(self, commit: Commit) -> tuple[str, AtUri, CID | None] | Non if not blocks: continue record = models.get_or_create(blocks, strict=False) - if not record.facets: - continue - for facet in record.facets: - for feature in facet.features: - if isinstance(feature, Mention) and feature.did == self.at_client.account_did: - return record.text, uri, op.cid + if record.facets: + for facet in record.facets: + for feature in facet.features: + if ( + isinstance(feature, Mention) + and feature.did == self.at_client.account_did + ): + return record.text, uri, str(op.cid), record.reply + reply = getattr(record, "reply", None) + if reply: + parent_did = AtUri.from_str(reply.parent.uri).hostname + if parent_did == self.at_client.account_did: + return record.text, uri, str(op.cid), record.reply return None - def enqueue_reminder(self, did: str, run_at: datetime, post_cid: str, post_uri: str) -> None: + def enqueue_reminder( + self, did: str, run_at: datetime, cid: str, parent_uri: str, root_uri: str + ) -> None: handle = self.at_client.resolve_handle(did) task = { + "cid": cid, "did": did, "handle": handle, - "post_cid": post_cid, - "post_uri": post_uri, + "parent_uri": parent_uri, + "root_uri": root_uri, } redis.zadd("task_queue", {dumps(task): run_at.timestamp()}) @@ -70,14 +80,17 @@ def handle_firehose_event(self, message_frame: MessageFrame) -> None: result = self.parse_create_op(commit) if not result: return - message, uri, cid = result - post_uri = f"at://{commit.repo}/app.bsky.feed.post/{uri.rkey}" + message, uri, cid, reply = result + parent_uri = ( + reply.parent.uri if reply else f"at://{commit.repo}/app.bsky.feed.post/{uri.rkey}" + ) + root_uri = reply.root.uri if reply else parent_uri run_at = self.parse_run_at(message) if not run_at: - return self.error_handler.handle_no_run_at(commit.repo, cid, post_uri) + return self.error_handler.handle_no_run_at(commit.repo, cid, parent_uri, root_uri) if run_at <= datetime.now(): - return self.error_handler.handle_run_at_in_past(commit.repo, cid, post_uri) - self.enqueue_reminder(commit.repo, run_at, str(cid), post_uri) + return self.error_handler.handle_run_at_in_past(commit.repo, cid, parent_uri, root_uri) + self.enqueue_reminder(commit.repo, run_at, str(cid), parent_uri, root_uri) def run(self, stop_event: Event) -> None: client = FirehoseSubscribeReposClient() diff --git a/scheduler.py b/scheduler.py index 348ba06..776b8f4 100644 --- a/scheduler.py +++ b/scheduler.py @@ -10,8 +10,9 @@ class Task(TypedDict): did: str handle: str - post_cid: str - post_uri: str + cid: str + parent_uri: str + root_uri: str class Scheduler: @@ -21,10 +22,11 @@ def __init__(self, at_client: AtClient): def run_task(self, task: Task) -> None: handle = task["handle"] did = task["did"] - parent_cid = task["post_cid"] - parent_uri = task["post_uri"] + cid = task["cid"] + parent_uri = task["parent_uri"] + root_uri = task["root_uri"] post = AtClient.build_mention_post(handle, did, ", your reminder is ready!") - self.at_client.post_reply(post, parent_cid, parent_uri) + self.at_client.post_reply(post, cid, parent_uri, root_uri) def run(self, stop_event: Event) -> None: while not stop_event.is_set():