Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/tg_cli/cli/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,24 @@ def data_group():
@click.option("-f", "--format", "fmt", type=click.Choice(["text", "json", "yaml"]), default="text")
@click.option("-o", "--output", "output_file", help="Output file path")
@click.option("--hours", type=int, help="Only export last N hours")
def export(chat: str, fmt: str, output_file: str | None, hours: int | None):
@click.option("--topic", type=int, help="Filter by forum topic ID")
def export(chat: str, fmt: str, output_file: str | None, hours: int | None, topic: int | None):
"""Export messages from CHAT to text, JSON, or YAML."""
with MessageDB() as db:
chat_id = resolve_chat_id_or_print(db, chat)
if chat_id is None:
return

if hours:
msgs = db.get_recent(chat_id=chat_id, hours=hours, limit=100000)
msgs = db.get_recent(
chat_id=chat_id, hours=hours, limit=100000,
topic_id=topic,
)
else:
msgs = db.get_recent(chat_id=chat_id, hours=None, limit=100000)
msgs = db.get_recent(
chat_id=chat_id, hours=None, limit=100000,
topic_id=topic,
)

if not msgs:
structured_fmt = (
Expand Down
31 changes: 25 additions & 6 deletions src/tg_cli/cli/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def _maybe_sync_first(chat: str | None, sync_first: bool, sync_limit: int) -> No
@click.option("-c", "--chat", help="Filter by chat name")
@click.option("-s", "--sender", help="Filter by sender name")
@click.option("--hours", type=int, help="Only search messages within N hours")
@click.option("--topic", type=int, help="Filter by forum topic ID")
@click.option("--regex", "use_regex", is_flag=True, help="Treat KEYWORD as a regex pattern")
@click.option("--sync-first", is_flag=True, help="Refresh local cache before searching")
@click.option(
Expand All @@ -54,6 +55,7 @@ def search(
chat: str | None,
sender: str | None,
hours: int | None,
topic: int | None,
use_regex: bool,
sync_first: bool,
sync_limit: int,
Expand All @@ -73,7 +75,8 @@ def search(
try:
if use_regex:
results = db.search_regex(
keyword, chat_id=chat_id, sender=sender, hours=hours, limit=limit
keyword, chat_id=chat_id, sender=sender, hours=hours, limit=limit,
topic_id=topic,
)
else:
results = db.search(
Expand All @@ -82,6 +85,7 @@ def search(
sender=sender,
hours=hours,
limit=limit,
topic_id=topic,
)
except re.error as exc:
if emit_error("invalid_regex", f"Invalid regex pattern: {exc}"):
Expand Down Expand Up @@ -124,6 +128,7 @@ def search(
@click.option("-c", "--chat", help="Filter by chat name")
@click.option("-s", "--sender", help="Filter by sender name")
@click.option("--hours", type=int, default=24, show_default=True, help="Only show last N hours")
@click.option("--topic", type=int, help="Filter by forum topic ID")
@click.option(
"--sync-first",
is_flag=True,
Expand All @@ -141,6 +146,7 @@ def recent(
chat: str | None,
sender: str | None,
hours: int,
topic: int | None,
sync_first: bool,
sync_limit: int,
limit: int,
Expand All @@ -155,7 +161,10 @@ def recent(
chat_id = resolve_chat_id_or_print(db, chat)
if chat and chat_id is None:
return
msgs = db.get_recent(chat_id=chat_id, sender=sender, hours=hours, limit=limit)
msgs = db.get_recent(
chat_id=chat_id, sender=sender, hours=hours,
limit=limit, topic_id=topic,
)

if msgs and emit_structured(msgs, as_json=as_json, as_yaml=as_yaml):
return
Expand Down Expand Up @@ -340,6 +349,7 @@ def timeline(

@query_group.command("today")
@click.option("-c", "--chat", help="Filter by chat name")
@click.option("--topic", type=int, help="Filter by forum topic ID")
@click.option(
"--sync-first",
is_flag=True,
Expand All @@ -352,7 +362,11 @@ def timeline(
help="Max messages per chat when using --sync-first",
)
@structured_output_options
def today(chat: str | None, sync_first: bool, sync_limit: int, as_json: bool, as_yaml: bool):
def today(
chat: str | None, topic: int | None,
sync_first: bool, sync_limit: int,
as_json: bool, as_yaml: bool,
):
"""Show today's messages, grouped by chat."""
from datetime import datetime

Expand All @@ -362,7 +376,7 @@ def today(chat: str | None, sync_first: bool, sync_limit: int, as_json: bool, as
chat_id = resolve_chat_id_or_print(db, chat)
if chat and chat_id is None:
return
msgs = db.get_today(chat_id=chat_id)
msgs = db.get_today(chat_id=chat_id, topic_id=topic)
latest_ts = db.get_latest_timestamp(chat_id=chat_id)

if msgs and emit_structured(msgs, as_json=as_json, as_yaml=as_yaml):
Expand Down Expand Up @@ -403,6 +417,7 @@ def today(chat: str | None, sync_first: bool, sync_limit: int, as_json: bool, as
@click.argument("keywords")
@click.option("-c", "--chat", help="Filter by chat name")
@click.option("--hours", type=int, help="Only search last N hours (default: today)")
@click.option("--topic", type=int, help="Filter by forum topic ID")
@click.option("--sync-first", is_flag=True, help="Refresh local cache before filtering")
@click.option(
"--sync-limit",
Expand All @@ -415,6 +430,7 @@ def filter_msgs(
keywords: str,
chat: str | None,
hours: int | None,
topic: int | None,
sync_first: bool,
sync_limit: int,
as_json: bool,
Expand Down Expand Up @@ -444,9 +460,12 @@ def filter_msgs(
return

if hours:
msgs = db.get_recent(chat_id=chat_id, hours=hours, limit=100000)
msgs = db.get_recent(
chat_id=chat_id, hours=hours, limit=100000,
topic_id=topic,
)
else:
msgs = db.get_today(chat_id=chat_id)
msgs = db.get_today(chat_id=chat_id, topic_id=topic)

# Filter messages containing ANY of the keywords (case-insensitive)
pattern = re.compile("|".join(re.escape(k) for k in keyword_list), re.IGNORECASE)
Expand Down
23 changes: 23 additions & 0 deletions src/tg_cli/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,18 @@ async def fetch_history(
if ts and ts.tzinfo is None:
ts = ts.replace(tzinfo=timezone.utc)

# Extract forum topic ID from reply_to
topic_id = None
if msg.reply_to:
# reply_to_top_id is the topic root message ID in forum groups
topic_id = getattr(msg.reply_to, 'reply_to_top_id', None)
# For top-level topic messages, forum_topic flag is set
if topic_id is None and getattr(msg.reply_to, 'forum_topic', False):
topic_id = getattr(msg.reply_to, 'reply_to_msg_id', None)
# General topic (id=1): messages may lack reply_to entirely
if topic_id is None and getattr(entity, 'forum', False):
topic_id = 1

batch.append(
dict(
chat_id=chat_id,
Expand All @@ -216,6 +228,7 @@ async def fetch_history(
sender_name=sender_name,
content=content,
timestamp=ts or datetime.now(timezone.utc),
topic_id=topic_id,
)
)

Expand Down Expand Up @@ -348,6 +361,15 @@ async def handler(event):
if ts and ts.tzinfo is None:
ts = ts.replace(tzinfo=timezone.utc)

# Extract topic_id for listen path
topic_id = None
if msg.reply_to:
topic_id = getattr(msg.reply_to, 'reply_to_top_id', None)
if topic_id is None and getattr(msg.reply_to, 'forum_topic', False):
topic_id = getattr(msg.reply_to, 'reply_to_msg_id', None)
if topic_id is None and getattr(chat, 'forum', False):
topic_id = 1

db.insert_message(
chat_id=chat.id,
chat_name=chat_name,
Expand All @@ -356,6 +378,7 @@ async def handler(event):
sender_name=sender_name,
content=content,
timestamp=ts or datetime.now(timezone.utc),
topic_id=topic_id,
)

time_str = ts.strftime("%H:%M:%S") if ts else "??:??:??"
Expand Down
47 changes: 42 additions & 5 deletions src/tg_cli/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
content TEXT,
timestamp TEXT NOT NULL,
raw_json TEXT,
topic_id INTEGER,
UNIQUE(platform, chat_id, msg_id)
);
"""
Expand All @@ -34,6 +35,7 @@
CREATE INDEX IF NOT EXISTS idx_messages_chat_ts ON messages(chat_id, timestamp);
CREATE INDEX IF NOT EXISTS idx_messages_content ON messages(content);
CREATE INDEX IF NOT EXISTS idx_messages_sender ON messages(sender_name);
CREATE INDEX IF NOT EXISTS idx_messages_topic ON messages(chat_id, topic_id);
"""


Expand Down Expand Up @@ -63,7 +65,21 @@ def __init__(self, db_path: Path | str | None = None):
self.conn = sqlite3.connect(str(self.db_path))
self.conn.row_factory = sqlite3.Row
self.conn.execute("PRAGMA journal_mode=WAL")
self.conn.executescript(_CREATE_TABLE + _CREATE_INDEX)
self.conn.executescript(_CREATE_TABLE)
self._migrate()
self.conn.executescript(_CREATE_INDEX)

def _migrate(self) -> None:
    """Add columns introduced after the initial schema.

    Reads the current column names of the ``messages`` table and adds
    ``topic_id`` when it is missing, so databases created before that
    column existed keep working.

    The check and the ``ALTER TABLE`` are two separate statements, so a
    second process opening the same database at the same moment can add
    the column in between. That case is treated as success instead of
    crashing on start-up.

    Raises:
        sqlite3.OperationalError: if the ``ALTER TABLE`` fails for any
            reason other than the column already being present.
    """
    # Second field of every PRAGMA table_info row is the column name.
    table_info = self.conn.execute("PRAGMA table_info(messages)").fetchall()
    existing_columns = {row[1] for row in table_info}

    if "topic_id" not in existing_columns:
        try:
            self.conn.execute(
                "ALTER TABLE messages ADD COLUMN topic_id INTEGER"
            )
        except sqlite3.OperationalError as exc:
            # Another connection won the race and already added the column;
            # anything else is a genuine failure and must surface.
            if "duplicate column name" not in str(exc).lower():
                raise

    # Harmless when nothing was changed above.
    self.conn.commit()

def __enter__(self):
return self
Expand Down Expand Up @@ -114,6 +130,7 @@ def insert_message(
content: str | None,
timestamp: datetime,
raw_json: dict[str, Any] | None = None,
topic_id: int | None = None,
) -> bool:
"""Insert a message, returns True if inserted (not duplicate)."""
try:
Expand All @@ -128,9 +145,10 @@ def insert_message(
sender_name,
content,
timestamp,
raw_json
raw_json,
topic_id
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
(
platform,
chat_id,
Expand All @@ -141,6 +159,7 @@ def insert_message(
content,
timestamp.isoformat(),
json.dumps(raw_json, ensure_ascii=False) if raw_json else None,
topic_id,
),
)
self.conn.commit()
Expand Down Expand Up @@ -171,6 +190,7 @@ def insert_batch(self, messages: list[dict], platform: str = "telegram") -> int:
else m["timestamp"]
),
json.dumps(m["raw_json"], ensure_ascii=False) if m.get("raw_json") else None,
m.get("topic_id"),
)
for m in messages
]
Expand All @@ -187,9 +207,10 @@ def insert_batch(self, messages: list[dict], platform: str = "telegram") -> int:
sender_name,
content,
timestamp,
raw_json
raw_json,
topic_id
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
rows,
)
self.conn.commit()
Expand All @@ -205,6 +226,7 @@ def search(
sender: str | None = None,
hours: int | None = None,
limit: int = 50,
topic_id: int | None = None,
) -> list[dict]:
"""Search messages by keyword."""
query = "SELECT * FROM messages WHERE content LIKE ?"
Expand All @@ -219,6 +241,9 @@ def search(
cutoff = (datetime.now(timezone.utc) - timedelta(hours=hours)).isoformat()
query += " AND timestamp >= ?"
params.append(cutoff)
if topic_id is not None:
query += " AND topic_id = ?"
params.append(topic_id)
query += " ORDER BY timestamp DESC LIMIT ?"
params.append(limit)
rows = self.conn.execute(query, params).fetchall()
Expand All @@ -231,6 +256,7 @@ def search_regex(
sender: str | None = None,
hours: int | None = None,
limit: int = 50,
topic_id: int | None = None,
) -> list[dict]:
"""Search messages by regex pattern."""
regex = re.compile(pattern, re.IGNORECASE)
Expand All @@ -246,6 +272,9 @@ def search_regex(
cutoff = (datetime.now(timezone.utc) - timedelta(hours=hours)).isoformat()
query += " AND timestamp >= ?"
params.append(cutoff)
if topic_id is not None:
query += " AND topic_id = ?"
params.append(topic_id)
query += " ORDER BY timestamp DESC LIMIT ?"
params.append(limit * 10)

Expand All @@ -266,6 +295,7 @@ def get_recent(
sender: str | None = None,
hours: int | None = 24,
limit: int = 500,
topic_id: int | None = None,
) -> list[dict]:
"""Get the latest messages, returned in chronological order."""
if hours is not None:
Expand All @@ -281,6 +311,9 @@ def get_recent(
if sender is not None:
base_query += " AND sender_name LIKE ?"
params.append(f"%{sender}%")
if topic_id is not None:
base_query += " AND topic_id = ?"
params.append(topic_id)
query = (
f"SELECT * FROM ({base_query} ORDER BY timestamp DESC LIMIT ?) ORDER BY timestamp ASC"
)
Expand All @@ -292,6 +325,7 @@ def get_today(
chat_id: int | None = None,
tz_offset_hours: int | None = None,
limit: int = 5000,
topic_id: int | None = None,
) -> list[dict]:
"""Get today's messages (in local timezone).

Expand Down Expand Up @@ -319,6 +353,9 @@ def get_today(
if chat_id is not None:
query += " AND chat_id = ?"
params.append(chat_id)
if topic_id is not None:
query += " AND topic_id = ?"
params.append(topic_id)
query += " ORDER BY chat_name, timestamp ASC LIMIT ?"
params.append(limit)
rows = self.conn.execute(query, params).fetchall()
Expand Down
2 changes: 2 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
class FakeEntity:
id: int
title: str
forum: bool = False


@dataclass
Expand All @@ -38,6 +39,7 @@ class FakeMessage:
date: datetime
message: str | None = None
_sender: object = None
reply_to: object = None

def __post_init__(self):
if self._sender is None:
Expand Down
Loading