Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions at_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from atproto import Client, IdResolver, client_utils
from atproto import Client, IdResolver
from atproto_client.models.app.bsky.feed.post import ReplyRef
from atproto_client.models.com.atproto.repo.strong_ref import Main
from atproto_client.utils import TextBuilder

from config import ACCOUNT_HANDLE, ACCOUNT_PASSWORD
from exceptions import DidResolveException, HandleResolveException
Expand All @@ -13,14 +14,15 @@ def __init__(self) -> None:
self.account_did = self.client.me.did
self.id_resolver = IdResolver()

def post_reply(self, post: client_utils.TextBuilder, parent_cid: str, parent_uri: str) -> None:
parent = Main(cid=parent_cid, uri=parent_uri)
reply_to = ReplyRef(parent=parent, root=parent)
def post_reply(self, post: TextBuilder, cid: str, parent_uri: str, root_uri: str) -> None:
parent = Main(cid=cid, uri=parent_uri)
root = Main(cid=cid, uri=root_uri)
reply_to = ReplyRef(parent=parent, root=root)
self.client.send_post(post, reply_to=reply_to)

@staticmethod
def build_mention_post(handle: str, did: str, text: str) -> client_utils.TextBuilder:
post = client_utils.TextBuilder()
def build_mention_post(handle: str, did: str, text: str) -> TextBuilder:
post = TextBuilder()
post.mention(f"@{handle}", did)
post.text(text)
return post
Expand Down
12 changes: 6 additions & 6 deletions error_handler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from logging import warning

from atproto_core.cid import CIDType as CID

from at_client import AtClient


Expand All @@ -12,14 +10,16 @@ class ErrorHandler:
def __init__(self, at_client: AtClient):
self.at_client = at_client

def handle_no_run_at(self, did: str, parent_cid: CID, parent_uri: str) -> None:
def handle_no_run_at(self, did: str, parent_cid: str, parent_uri: str, root_uri: str) -> None:
warning(f"No run at was parsed for post at URI: {parent_uri}")
handle = self.at_client.resolve_handle(did)
post = AtClient.build_mention_post(handle, did, self.NO_RUN_MSG)
self.at_client.post_reply(post, str(parent_cid), parent_uri)
self.at_client.post_reply(post, parent_cid, parent_uri, root_uri)

def handle_run_at_in_past(self, did: str, parent_cid: CID, parent_uri: str) -> None:
def handle_run_at_in_past(
self, did: str, parent_cid: str, parent_uri: str, root_uri: str
) -> None:
warning(f"Run at was parsed to be in the past for post at URI: {parent_uri}")
handle = self.at_client.resolve_handle(did)
post = AtClient.build_mention_post(handle, did, self.PAST_MSG)
self.at_client.post_reply(post, str(parent_cid), parent_uri)
self.at_client.post_reply(post, parent_cid, parent_uri, root_uri)
45 changes: 29 additions & 16 deletions mention_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from typing import Any

from atproto import CAR, AtUri, models
from atproto_client.models.app.bsky.feed.post import ReplyRef
from atproto_client.models.app.bsky.richtext.facet import Mention
from atproto_client.models.com.atproto.sync.subscribe_repos import Commit
from atproto_core.cid import CIDType as CID
from atproto_firehose import FirehoseSubscribeReposClient, parse_subscribe_repos_message
from atproto_firehose.models import MessageFrame

Expand All @@ -24,7 +24,7 @@ def __init__(self, at_client: AtClient, error_handler: ErrorHandler):
self.at_client = at_client
self.error_handler = error_handler

def parse_create_op(self, commit: Commit) -> tuple[str, AtUri, CID | None] | None:
def parse_create_op(self, commit: Commit) -> tuple[str, AtUri, str, ReplyRef] | None:
car = CAR.from_bytes(commit.blocks)
for op in commit.ops:
if op.action != "create" or not op.cid:
Expand All @@ -36,21 +36,31 @@ def parse_create_op(self, commit: Commit) -> tuple[str, AtUri, CID | None] | Non
if not blocks:
continue
record = models.get_or_create(blocks, strict=False)
if not record.facets:
continue
for facet in record.facets:
for feature in facet.features:
if isinstance(feature, Mention) and feature.did == self.at_client.account_did:
return record.text, uri, op.cid
if record.facets:
for facet in record.facets:
for feature in facet.features:
if (
isinstance(feature, Mention)
and feature.did == self.at_client.account_did
):
return record.text, uri, str(op.cid), record.reply
reply = getattr(record, "reply", None)
if reply:
parent_did = AtUri.from_str(reply.parent.uri).hostname
if parent_did == self.at_client.account_did:
return record.text, uri, str(op.cid), record.reply
return None

def enqueue_reminder(self, did: str, run_at: datetime, post_cid: str, post_uri: str) -> None:
def enqueue_reminder(
self, did: str, run_at: datetime, cid: str, parent_uri: str, root_uri: str
) -> None:
handle = self.at_client.resolve_handle(did)
task = {
"cid": cid,
"did": did,
"handle": handle,
"post_cid": post_cid,
"post_uri": post_uri,
"parent_uri": parent_uri,
"root_uri": root_uri,
}
redis.zadd("task_queue", {dumps(task): run_at.timestamp()})

Expand All @@ -70,14 +80,17 @@ def handle_firehose_event(self, message_frame: MessageFrame) -> None:
result = self.parse_create_op(commit)
if not result:
return
message, uri, cid = result
post_uri = f"at://{commit.repo}/app.bsky.feed.post/{uri.rkey}"
message, uri, cid, reply = result
parent_uri = (
reply.parent.uri if reply else f"at://{commit.repo}/app.bsky.feed.post/{uri.rkey}"
)
root_uri = reply.root.uri if reply else parent_uri
run_at = self.parse_run_at(message)
if not run_at:
return self.error_handler.handle_no_run_at(commit.repo, cid, post_uri)
return self.error_handler.handle_no_run_at(commit.repo, cid, parent_uri, root_uri)
if run_at <= datetime.now():
return self.error_handler.handle_run_at_in_past(commit.repo, cid, post_uri)
self.enqueue_reminder(commit.repo, run_at, str(cid), post_uri)
return self.error_handler.handle_run_at_in_past(commit.repo, cid, parent_uri, root_uri)
self.enqueue_reminder(commit.repo, run_at, str(cid), parent_uri, root_uri)

def run(self, stop_event: Event) -> None:
client = FirehoseSubscribeReposClient()
Expand Down
12 changes: 7 additions & 5 deletions scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
class Task(TypedDict):
did: str
handle: str
post_cid: str
post_uri: str
cid: str
parent_uri: str
root_uri: str


class Scheduler:
Expand All @@ -21,10 +22,11 @@ def __init__(self, at_client: AtClient):
def run_task(self, task: Task) -> None:
handle = task["handle"]
did = task["did"]
parent_cid = task["post_cid"]
parent_uri = task["post_uri"]
cid = task["cid"]
parent_uri = task["parent_uri"]
root_uri = task["root_uri"]
post = AtClient.build_mention_post(handle, did, ", your reminder is ready!")
self.at_client.post_reply(post, parent_cid, parent_uri)
self.at_client.post_reply(post, cid, parent_uri, root_uri)

def run(self, stop_event: Event) -> None:
while not stop_event.is_set():
Expand Down