diff --git a/cmd/msgvault/cmd/build_cache.go b/cmd/msgvault/cmd/build_cache.go index aee2ebb8..94d8f13d 100644 --- a/cmd/msgvault/cmd/build_cache.go +++ b/cmd/msgvault/cmd/build_cache.go @@ -17,6 +17,7 @@ import ( "github.com/spf13/cobra" "github.com/wesm/msgvault/internal/config" "github.com/wesm/msgvault/internal/query" + "github.com/wesm/msgvault/internal/store" ) var fullRebuild bool @@ -27,10 +28,17 @@ var fullRebuild bool // files (_last_sync.json, parquet directories) can corrupt the cache. var buildCacheMu sync.Mutex +// cacheSchemaVersion tracks the Parquet schema layout. Bump this whenever +// columns are added/removed/renamed in the COPY queries below so that +// incremental builds automatically trigger a full rebuild instead of +// producing Parquet files with mismatched schemas. +const cacheSchemaVersion = 3 // v3: schema migration adds phone_number etc. to existing DBs; force Parquet rebuild + // syncState tracks the last exported message ID for incremental updates. type syncState struct { LastMessageID int64 `json:"last_message_id"` LastSyncAt time.Time `json:"last_sync_at"` + SchemaVersion int `json:"schema_version,omitempty"` } var buildCacheCmd = &cobra.Command{ @@ -62,6 +70,20 @@ Use --full-rebuild to recreate all cache files from scratch.`, return fmt.Errorf("database not found: %s\nRun 'msgvault init-db' first", dbPath) } + // Ensure schema is up to date before building cache. + // Legacy databases may be missing columns (e.g. attachment_count, + // sender_id, message_type, phone_number) that the export queries + // reference. Running migrations first adds them. + s, err := store.Open(dbPath) + if err != nil { + return fmt.Errorf("open database: %w", err) + } + if err := s.InitSchema(); err != nil { + s.Close() + return fmt.Errorf("init schema: %w", err) + } + s.Close() + result, err := buildCache(dbPath, analyticsDir, fullRebuild) if err != nil { return err @@ -101,8 +123,16 @@ func buildCache(dbPath, analyticsDir string, fullRebuild bool) (*buildResult, er if data, err := os.ReadFile(stateFile); err == nil { var state syncState if json.Unmarshal(data, &state) == nil { - lastMessageID = state.LastMessageID - fmt.Printf("Incremental export from message_id > %d\n", lastMessageID) + if state.SchemaVersion != cacheSchemaVersion { + // Schema has changed — force a full rebuild. + fmt.Printf("Cache schema version mismatch (have v%d, need v%d). Forcing full rebuild.\n", + state.SchemaVersion, cacheSchemaVersion) + fullRebuild = true + lastMessageID = 0 + } else { + lastMessageID = state.LastMessageID + fmt.Printf("Incremental export from message_id > %d\n", lastMessageID) + } } } } @@ -231,7 +261,10 @@ func buildCache(dbPath, analyticsDir string, fullRebuild bool) (*buildResult, er m.sent_at, m.size_estimate, m.has_attachments, + COALESCE(TRY_CAST(m.attachment_count AS INTEGER), 0) as attachment_count, m.deleted_from_source_at, + m.sender_id, + COALESCE(TRY_CAST(m.message_type AS VARCHAR), '') as message_type, CAST(EXTRACT(YEAR FROM m.sent_at) AS INTEGER) as year, CAST(EXTRACT(MONTH FROM m.sent_at) AS INTEGER) as month FROM sqlite_db.messages m @@ -321,7 +354,8 @@ func buildCache(dbPath, analyticsDir string, fullRebuild bool) (*buildResult, er id, COALESCE(TRY_CAST(email_address AS VARCHAR), '') as email_address, COALESCE(TRY_CAST(domain AS VARCHAR), '') as domain, - COALESCE(TRY_CAST(display_name AS VARCHAR), '') as display_name + COALESCE(TRY_CAST(display_name AS VARCHAR), '') as display_name, + COALESCE(TRY_CAST(phone_number AS VARCHAR), '') as phone_number FROM sqlite_db.participants ) TO '%s/participants.parquet' ( FORMAT PARQUET, @@ -372,7 +406,8 @@ func buildCache(dbPath, analyticsDir string, fullRebuild bool) (*buildResult, er COPY ( SELECT id, - COALESCE(TRY_CAST(source_conversation_id AS VARCHAR), '') as source_conversation_id + COALESCE(TRY_CAST(source_conversation_id AS VARCHAR), '') as source_conversation_id, + COALESCE(TRY_CAST(title AS VARCHAR), '') as title FROM sqlite_db.conversations ) TO '%s/conversations.parquet' ( FORMAT PARQUET, @@ -391,10 +426,11 @@ func buildCache(dbPath, analyticsDir string, fullRebuild bool) (*buildResult, er exportedCount = 0 } - // Save sync state + // Save sync state with schema version for compatibility detection. state := syncState{ LastMessageID: maxID, LastSyncAt: time.Now(), + SchemaVersion: cacheSchemaVersion, } stateData, _ := json.Marshal(state) if err := os.WriteFile(stateFile, stateData, 0644); err != nil { @@ -592,15 +628,15 @@ func setupSQLiteSource(duckDB *sql.DB, dbPath string) (cleanup func(), err error query string typeOverrides string // DuckDB types parameter for read_csv_auto (empty = infer all) }{ - {"messages", "SELECT id, source_id, source_message_id, conversation_id, subject, snippet, sent_at, size_estimate, has_attachments, deleted_from_source_at FROM messages WHERE sent_at IS NOT NULL", + {"messages", "SELECT id, source_id, source_message_id, conversation_id, subject, snippet, sent_at, size_estimate, has_attachments, attachment_count, deleted_from_source_at, sender_id, message_type FROM messages WHERE sent_at IS NOT NULL", "types={'sent_at': 'TIMESTAMP', 'deleted_from_source_at': 'TIMESTAMP'}"}, {"message_recipients", "SELECT message_id, participant_id, recipient_type, display_name FROM message_recipients", ""}, {"message_labels", "SELECT message_id, label_id FROM message_labels", ""}, {"attachments", "SELECT message_id, size, filename FROM attachments", ""}, - {"participants", "SELECT id, email_address, domain, display_name FROM participants", ""}, + {"participants", "SELECT id, email_address, domain, display_name, phone_number FROM participants", ""}, {"labels", "SELECT id, name FROM labels", ""}, {"sources", "SELECT id, identifier FROM sources", ""}, - {"conversations", "SELECT id, source_conversation_id FROM conversations", ""}, + {"conversations", "SELECT id, source_conversation_id, title FROM conversations", ""}, } for _, t := range tables { diff --git a/cmd/msgvault/cmd/build_cache_test.go b/cmd/msgvault/cmd/build_cache_test.go index fe0e0a34..323b0861 100644 --- a/cmd/msgvault/cmd/build_cache_test.go +++ b/cmd/msgvault/cmd/build_cache_test.go @@ -51,7 +51,10 @@ func setupTestSQLite(t *testing.T) (string, func()) { received_at TIMESTAMP, size_estimate INTEGER, has_attachments BOOLEAN DEFAULT FALSE, + attachment_count INTEGER DEFAULT 0, deleted_from_source_at TIMESTAMP, + sender_id INTEGER, + message_type TEXT NOT NULL DEFAULT 'email', UNIQUE(source_id, source_message_id) ); @@ -59,7 +62,8 @@ func setupTestSQLite(t *testing.T) (string, func()) { id INTEGER PRIMARY KEY, email_address TEXT NOT NULL UNIQUE, domain TEXT, - display_name TEXT + display_name TEXT, + phone_number TEXT ); CREATE TABLE message_recipients ( @@ -1128,13 +1132,13 @@ func TestBuildCache_EmptyDatabase(t *testing.T) { db, _ := sql.Open("sqlite3", dbPath) _, _ = db.Exec(` CREATE TABLE sources (id INTEGER PRIMARY KEY, identifier TEXT); - CREATE TABLE messages (id INTEGER PRIMARY KEY, source_id INTEGER, source_message_id TEXT, sent_at TIMESTAMP, size_estimate INTEGER, has_attachments BOOLEAN, subject TEXT, snippet TEXT, conversation_id INTEGER, deleted_from_source_at TIMESTAMP); - CREATE TABLE participants (id INTEGER PRIMARY KEY, email_address TEXT, domain TEXT, display_name TEXT); + CREATE TABLE messages (id INTEGER PRIMARY KEY, source_id INTEGER, source_message_id TEXT, sent_at TIMESTAMP, size_estimate INTEGER, has_attachments BOOLEAN, subject TEXT, snippet TEXT, conversation_id INTEGER, deleted_from_source_at TIMESTAMP, attachment_count INTEGER DEFAULT 0, sender_id INTEGER, message_type TEXT NOT NULL DEFAULT 'email'); + CREATE TABLE participants (id INTEGER PRIMARY KEY, email_address TEXT, domain TEXT, display_name TEXT, phone_number TEXT); CREATE TABLE message_recipients (message_id INTEGER, participant_id INTEGER, recipient_type TEXT, display_name TEXT); CREATE TABLE labels (id INTEGER PRIMARY KEY, name TEXT); CREATE TABLE message_labels (message_id INTEGER, label_id INTEGER); CREATE TABLE attachments (message_id INTEGER, size INTEGER, filename TEXT); - CREATE TABLE conversations (id INTEGER PRIMARY KEY, source_conversation_id TEXT); + CREATE TABLE conversations (id INTEGER PRIMARY KEY, source_conversation_id TEXT, title TEXT); `) db.Close() @@ -1328,13 +1332,13 @@ func BenchmarkBuildCache(b *testing.B) { // Create schema _, _ = db.Exec(` CREATE TABLE sources (id INTEGER PRIMARY KEY, identifier TEXT); - CREATE TABLE messages (id INTEGER PRIMARY KEY, source_id INTEGER, source_message_id TEXT, sent_at TIMESTAMP, size_estimate INTEGER, has_attachments BOOLEAN, subject TEXT, snippet TEXT, conversation_id INTEGER, deleted_from_source_at TIMESTAMP); - CREATE TABLE participants (id INTEGER PRIMARY KEY, email_address TEXT UNIQUE, domain TEXT, display_name TEXT); + CREATE TABLE messages (id INTEGER PRIMARY KEY, source_id INTEGER, source_message_id TEXT, sent_at TIMESTAMP, size_estimate INTEGER, has_attachments BOOLEAN, subject TEXT, snippet TEXT, conversation_id INTEGER, deleted_from_source_at TIMESTAMP, attachment_count INTEGER DEFAULT 0, sender_id INTEGER, message_type TEXT NOT NULL DEFAULT 'email'); + CREATE TABLE participants (id INTEGER PRIMARY KEY, email_address TEXT UNIQUE, domain TEXT, display_name TEXT, phone_number TEXT); CREATE TABLE message_recipients (message_id INTEGER, participant_id INTEGER, recipient_type TEXT, display_name TEXT); CREATE TABLE labels (id INTEGER PRIMARY KEY, name TEXT); CREATE TABLE message_labels (message_id INTEGER, label_id INTEGER); CREATE TABLE attachments (message_id INTEGER, size INTEGER, filename TEXT); - CREATE TABLE conversations (id INTEGER PRIMARY KEY, source_conversation_id TEXT); + CREATE TABLE conversations (id INTEGER PRIMARY KEY, source_conversation_id TEXT, title TEXT); INSERT INTO sources VALUES (1, 'test@gmail.com'); INSERT INTO labels VALUES (1, 'INBOX'), (2, 'Work'); `) @@ -1418,14 +1422,18 @@ func setupTestSQLiteEmpty(t *testing.T) (string, func()) { received_at TIMESTAMP, size_estimate INTEGER, has_attachments BOOLEAN DEFAULT FALSE, + attachment_count INTEGER DEFAULT 0, deleted_from_source_at TIMESTAMP, + sender_id INTEGER, + message_type TEXT NOT NULL DEFAULT 'email', UNIQUE(source_id, source_message_id) ); CREATE TABLE participants ( id INTEGER PRIMARY KEY, email_address TEXT NOT NULL UNIQUE, domain TEXT, - display_name TEXT + display_name TEXT, + phone_number TEXT ); CREATE TABLE message_recipients ( id INTEGER PRIMARY KEY, @@ -1757,17 +1765,17 @@ func BenchmarkBuildCacheIncremental(b *testing.B) { // Create schema and initial data (10000 messages) _, _ = db.Exec(` CREATE TABLE sources (id INTEGER PRIMARY KEY, identifier TEXT); - CREATE TABLE messages (id INTEGER PRIMARY KEY, source_id INTEGER, source_message_id TEXT, sent_at TIMESTAMP, size_estimate INTEGER, has_attachments BOOLEAN, subject TEXT, snippet TEXT, conversation_id INTEGER, deleted_from_source_at TIMESTAMP); - CREATE TABLE participants (id INTEGER PRIMARY KEY, email_address TEXT UNIQUE, domain TEXT, display_name TEXT); + CREATE TABLE messages (id INTEGER PRIMARY KEY, source_id INTEGER, source_message_id TEXT, sent_at TIMESTAMP, size_estimate INTEGER, has_attachments BOOLEAN, subject TEXT, snippet TEXT, conversation_id INTEGER, deleted_from_source_at TIMESTAMP, attachment_count INTEGER DEFAULT 0, sender_id INTEGER, message_type TEXT NOT NULL DEFAULT 'email'); + CREATE TABLE participants (id INTEGER PRIMARY KEY, email_address TEXT UNIQUE, domain TEXT, display_name TEXT, phone_number TEXT); CREATE TABLE message_recipients (message_id INTEGER, participant_id INTEGER, recipient_type TEXT, display_name TEXT); CREATE TABLE labels (id INTEGER PRIMARY KEY, name TEXT); CREATE TABLE message_labels (message_id INTEGER, label_id INTEGER); CREATE TABLE attachments (message_id INTEGER, size INTEGER, filename TEXT); - CREATE TABLE conversations (id INTEGER PRIMARY KEY, source_conversation_id TEXT); + CREATE TABLE conversations (id INTEGER PRIMARY KEY, source_conversation_id TEXT, title TEXT); INSERT INTO sources VALUES (1, 'test@gmail.com'); INSERT INTO labels VALUES (1, 'INBOX'); - INSERT INTO participants VALUES (1, 'alice@example.com', 'example.com', 'Alice'); - INSERT INTO participants VALUES (2, 'bob@example.com', 'example.com', 'Bob'); + INSERT INTO participants VALUES (1, 'alice@example.com', 'example.com', 'Alice', NULL); + INSERT INTO participants VALUES (2, 'bob@example.com', 'example.com', 'Bob', NULL); `) // Insert conversations to match messages diff --git a/cmd/msgvault/cmd/export_attachments.go b/cmd/msgvault/cmd/export_attachments.go index 272b4431..18f0a04c 100644 --- a/cmd/msgvault/cmd/export_attachments.go +++ b/cmd/msgvault/cmd/export_attachments.go @@ -42,6 +42,10 @@ func runExportAttachments(cmd *cobra.Command, args []string) error { } defer s.Close() + if err := s.InitSchema(); err != nil { + return fmt.Errorf("init schema: %w", err) + } + engine := query.NewSQLiteEngine(s.DB()) // Resolve message ID — try numeric first, fallback to Gmail ID diff --git a/cmd/msgvault/cmd/export_eml.go b/cmd/msgvault/cmd/export_eml.go index 08c5479b..88a471a4 100644 --- a/cmd/msgvault/cmd/export_eml.go +++ b/cmd/msgvault/cmd/export_eml.go @@ -86,6 +86,10 @@ func runExportEML(cmd *cobra.Command, messageRef, outputPath string) error { } defer s.Close() + if err := s.InitSchema(); err != nil { + return fmt.Errorf("init schema: %w", err) + } + engine := query.NewSQLiteEngine(s.DB()) resolved, err := resolveMessage(engine, cmd, messageRef) diff --git a/cmd/msgvault/cmd/import.go b/cmd/msgvault/cmd/import.go new file mode 100644 index 00000000..39ff99e8 --- /dev/null +++ b/cmd/msgvault/cmd/import.go @@ -0,0 +1,237 @@ +package cmd + +import ( + "context" + "fmt" + "os" + "os/signal" + "strings" + "syscall" + "time" + + "github.com/spf13/cobra" + "github.com/wesm/msgvault/internal/store" + "github.com/wesm/msgvault/internal/textutil" + "github.com/wesm/msgvault/internal/whatsapp" +) + +var ( + importType string + importPhone string + importMediaDir string + importContacts string + importLimit int + importDisplayName string +) + +var importCmd = &cobra.Command{ + Use: "import [path]", + Short: "Import messages from external sources", + Long: `Import messages from external message databases. + +Currently supported types: + whatsapp Import from a decrypted WhatsApp msgstore.db + +Examples: + msgvault import --type whatsapp --phone "+447700900000" /path/to/msgstore.db + msgvault import --type whatsapp --phone "+447700900000" --contacts ~/contacts.vcf /path/to/msgstore.db + msgvault import --type whatsapp --phone "+447700900000" --media-dir /path/to/Media /path/to/msgstore.db`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + if err := MustBeLocal("import"); err != nil { + return err + } + + sourcePath := args[0] + + // Validate source file exists. + if _, err := os.Stat(sourcePath); err != nil { + return fmt.Errorf("source file not found: %w", err) + } + + switch strings.ToLower(importType) { + case "whatsapp": + return runWhatsAppImport(cmd, sourcePath) + default: + return fmt.Errorf("unsupported import type %q (supported: whatsapp)", importType) + } + }, +} + +func runWhatsAppImport(cmd *cobra.Command, sourcePath string) error { + // Validate phone number. + if importPhone == "" { + return fmt.Errorf("--phone is required for WhatsApp import (E.164 format, e.g., +447700900000)") + } + if !strings.HasPrefix(importPhone, "+") { + return fmt.Errorf("phone number must be in E.164 format (starting with +), got %q", importPhone) + } + + // Validate media dir if provided. + if importMediaDir != "" { + if info, err := os.Stat(importMediaDir); err != nil || !info.IsDir() { + return fmt.Errorf("media directory not found or not a directory: %s", importMediaDir) + } + } + + // Open database. + dbPath := cfg.DatabaseDSN() + s, err := store.Open(dbPath) + if err != nil { + return fmt.Errorf("open database: %w", err) + } + defer s.Close() + + if err := s.InitSchema(); err != nil { + return fmt.Errorf("init schema: %w", err) + } + + // Set up context with cancellation. + ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() + + // Handle Ctrl+C gracefully. + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-sigChan + fmt.Println("\nInterrupted. Saving checkpoint...") + cancel() + }() + + // Build import options. + opts := whatsapp.DefaultOptions() + opts.Phone = importPhone + opts.DisplayName = importDisplayName + opts.MediaDir = importMediaDir + opts.AttachmentsDir = cfg.AttachmentsDir() + opts.Limit = importLimit + + // Create importer with CLI progress. + progress := &ImportCLIProgress{} + importer := whatsapp.NewImporter(s, progress) + + fmt.Printf("Importing WhatsApp messages from %s\n", sourcePath) + fmt.Printf("Phone: %s\n", importPhone) + if importMediaDir != "" { + fmt.Printf("Media: %s\n", importMediaDir) + } + if importLimit > 0 { + fmt.Printf("Limit: %d messages\n", importLimit) + } + fmt.Println() + + summary, err := importer.Import(ctx, sourcePath, opts) + if err != nil { + if ctx.Err() != nil { + fmt.Println("\nImport interrupted. Run again to continue.") + return nil + } + return fmt.Errorf("import failed: %w", err) + } + + // Import contacts if provided. + if importContacts != "" { + fmt.Printf("\nImporting contacts from %s...\n", importContacts) + matched, total, err := whatsapp.ImportContacts(s, importContacts) + if err != nil { + return fmt.Errorf("contact import: %w", err) + } else { + fmt.Printf(" Contacts: %d in file, %d phone numbers matched to participants\n", total, matched) + } + } + + // Print summary. + fmt.Println() + fmt.Println("Import complete!") + fmt.Printf(" Duration: %s\n", summary.Duration.Round(time.Second)) + fmt.Printf(" Chats: %d\n", summary.ChatsProcessed) + fmt.Printf(" Messages: %d processed, %d added, %d skipped\n", + summary.MessagesProcessed, summary.MessagesAdded, summary.MessagesSkipped) + fmt.Printf(" Participants: %d\n", summary.Participants) + fmt.Printf(" Reactions: %d\n", summary.ReactionsAdded) + fmt.Printf(" Attachments: %d found", summary.AttachmentsFound) + if summary.MediaCopied > 0 { + fmt.Printf(", %d files copied", summary.MediaCopied) + } + fmt.Println() + if summary.Errors > 0 { + fmt.Printf(" Errors: %d\n", summary.Errors) + } + + if summary.MessagesAdded > 0 { + rate := float64(summary.MessagesAdded) / summary.Duration.Seconds() + fmt.Printf(" Rate: %.0f messages/sec\n", rate) + } + + return nil +} + +// ImportCLIProgress implements whatsapp.ImportProgress for terminal output. +type ImportCLIProgress struct { + startTime time.Time + lastPrint time.Time + currentChat string +} + +func (p *ImportCLIProgress) OnStart() { + p.startTime = time.Now() + p.lastPrint = time.Now() +} + +func (p *ImportCLIProgress) OnChatStart(chatJID, chatTitle string, messageCount int) { + p.currentChat = chatTitle + // Don't print every chat start — too noisy for 13k+ chats. +} + +func (p *ImportCLIProgress) OnProgress(processed, added, skipped int64) { + // Throttle output to every 2 seconds. + if time.Since(p.lastPrint) < 2*time.Second { + return + } + p.lastPrint = time.Now() + + elapsed := time.Since(p.startTime) + rate := 0.0 + if elapsed.Seconds() >= 1 { + rate = float64(added) / elapsed.Seconds() + } + + elapsedStr := formatDuration(elapsed) + + chatStr := "" + if p.currentChat != "" { + // Truncate long chat names and sanitize to prevent terminal injection. + name := textutil.SanitizeTerminal(p.currentChat) + if len(name) > 30 { + name = name[:27] + "..." + } + chatStr = fmt.Sprintf(" | Chat: %s", name) + } + + fmt.Printf("\r Processed: %d | Added: %d | Skipped: %d | Rate: %.0f/s | Elapsed: %s%s ", + processed, added, skipped, rate, elapsedStr, chatStr) +} + +func (p *ImportCLIProgress) OnChatComplete(chatJID string, messagesAdded int64) { + // Quiet — progress line shows the aggregate. +} + +func (p *ImportCLIProgress) OnComplete(summary *whatsapp.ImportSummary) { + fmt.Println() // Clear the progress line. +} + +func (p *ImportCLIProgress) OnError(err error) { + fmt.Printf("\nWarning: %s\n", textutil.SanitizeTerminal(err.Error())) +} + +func init() { + importCmd.Flags().StringVar(&importType, "type", "", "import source type (required: whatsapp)") + importCmd.Flags().StringVar(&importPhone, "phone", "", "your phone number in E.164 format (required for whatsapp)") + importCmd.Flags().StringVar(&importMediaDir, "media-dir", "", "path to decrypted Media folder (optional)") + importCmd.Flags().StringVar(&importContacts, "contacts", "", "path to contacts .vcf file for name resolution (optional)") + importCmd.Flags().IntVar(&importLimit, "limit", 0, "limit number of messages (for testing)") + importCmd.Flags().StringVar(&importDisplayName, "display-name", "", "display name for the phone owner") + _ = importCmd.MarkFlagRequired("type") + rootCmd.AddCommand(importCmd) +} diff --git a/cmd/msgvault/cmd/list_domains.go b/cmd/msgvault/cmd/list_domains.go index ae0fbf63..e9b9f24d 100644 --- a/cmd/msgvault/cmd/list_domains.go +++ b/cmd/msgvault/cmd/list_domains.go @@ -34,6 +34,10 @@ Examples: } defer s.Close() + if err := s.InitSchema(); err != nil { + return fmt.Errorf("init schema: %w", err) + } + // Create query engine engine := query.NewSQLiteEngine(s.DB()) diff --git a/cmd/msgvault/cmd/list_labels.go b/cmd/msgvault/cmd/list_labels.go index 775949a0..343934b9 100644 --- a/cmd/msgvault/cmd/list_labels.go +++ b/cmd/msgvault/cmd/list_labels.go @@ -34,6 +34,10 @@ Examples: } defer s.Close() + if err := s.InitSchema(); err != nil { + return fmt.Errorf("init schema: %w", err) + } + // Create query engine engine := query.NewSQLiteEngine(s.DB()) diff --git a/cmd/msgvault/cmd/list_senders.go b/cmd/msgvault/cmd/list_senders.go index f18a7229..98ea3c92 100644 --- a/cmd/msgvault/cmd/list_senders.go +++ b/cmd/msgvault/cmd/list_senders.go @@ -34,6 +34,10 @@ Examples: } defer s.Close() + if err := s.InitSchema(); err != nil { + return fmt.Errorf("init schema: %w", err) + } + // Create query engine engine := query.NewSQLiteEngine(s.DB()) diff --git a/cmd/msgvault/cmd/repair_encoding.go b/cmd/msgvault/cmd/repair_encoding.go index 0ca655d6..b1dc99d8 100644 --- a/cmd/msgvault/cmd/repair_encoding.go +++ b/cmd/msgvault/cmd/repair_encoding.go @@ -44,6 +44,10 @@ charset detection issues in the MIME parser.`, } defer s.Close() + if err := s.InitSchema(); err != nil { + return fmt.Errorf("init schema: %w", err) + } + return repairEncoding(s) }, } diff --git a/cmd/msgvault/cmd/show_message.go b/cmd/msgvault/cmd/show_message.go index f7feb326..f64dae4c 100644 --- a/cmd/msgvault/cmd/show_message.go +++ b/cmd/msgvault/cmd/show_message.go @@ -82,6 +82,10 @@ func showLocalMessage(cmd *cobra.Command, idStr string) error { } defer s.Close() + if err := s.InitSchema(); err != nil { + return fmt.Errorf("init schema: %w", err) + } + // Create query engine engine := query.NewSQLiteEngine(s.DB()) diff --git a/cmd/msgvault/cmd/update_account.go b/cmd/msgvault/cmd/update_account.go index 435dcb6c..8783974d 100644 --- a/cmd/msgvault/cmd/update_account.go +++ b/cmd/msgvault/cmd/update_account.go @@ -34,6 +34,10 @@ Examples: } defer s.Close() + if err := s.InitSchema(); err != nil { + return fmt.Errorf("init schema: %w", err) + } + source, err := s.GetSourceByIdentifier(email) if err != nil { return fmt.Errorf("get account: %w", err) diff --git a/cmd/msgvault/cmd/verify.go b/cmd/msgvault/cmd/verify.go index 4427584f..92d55a65 100644 --- a/cmd/msgvault/cmd/verify.go +++ b/cmd/msgvault/cmd/verify.go @@ -49,6 +49,10 @@ Examples: } defer s.Close() + if err := s.InitSchema(); err != nil { + return fmt.Errorf("init schema: %w", err) + } + // Create OAuth manager and get token source oauthMgr, err := oauth.NewManager(cfg.OAuth.ClientSecrets, cfg.TokensDir(), logger) if err != nil { diff --git a/docs/plans/2026-02-17-multi-source-messaging.md b/docs/plans/2026-02-17-multi-source-messaging.md new file mode 100644 index 00000000..48b7b415 --- /dev/null +++ b/docs/plans/2026-02-17-multi-source-messaging.md @@ -0,0 +1,164 @@ +# Multi-Source Messaging Support + +**Issue:** [wesm/msgvault#136](https://github.com/wesm/msgvault/issues/136) +**Author:** Ed Dowding +**Date:** 2026-02-17 +**Status:** Draft for review + +## Goal + +Make msgvault a universal message archive — not just Gmail. Starting with WhatsApp, but ensuring the design works for iMessage, Telegram, SMS, and other chat platforms. + +## Good News: The Schema Is Already Ready + +The existing schema was designed for this. Key fields already in place: + +| Table | Multi-source fields | +|-------|-------------------| +| `sources` | `source_type` ('gmail', 'whatsapp', 'apple_messages', 'google_messages'), `identifier` (email or phone), `sync_cursor` (platform-agnostic) | +| `messages` | `message_type` ('email', 'imessage', 'sms', 'mms', 'rcs', 'whatsapp'), `is_edited`, `is_forwarded`, `delivered_at`, `read_at` | +| `conversations` | `conversation_type` ('email_thread', 'group_chat', 'direct_chat', 'channel') | +| `participants` | `phone_number` (E.164), `canonical_id` (cross-platform dedup) | +| `participant_identifiers` | `identifier_type` ('email', 'phone', 'apple_id', 'whatsapp') | +| `attachments` | `media_type` ('image', 'video', 'audio', 'sticker', 'gif', 'voice_note') | +| `reactions` | `reaction_type` ('tapback', 'emoji', 'like') | +| `message_raw` | `raw_format` ('mime', 'imessage_archive', 'whatsapp_json', 'rcs_json') | + +**No schema migrations needed.** The store layer (`UpsertMessage`, `GetOrCreateSource`, etc.) is already generic — it accepts any `source_type` and `message_type`. The tight coupling to Gmail is only in the sync pipeline and CLI commands. + +## CLI Design + +Per Wes's feedback, use `--type` not `--whatsapp`: + +```bash +# Add accounts +msgvault add-account user@gmail.com # default: --type gmail +msgvault add-account --type whatsapp "+447700900000" # WhatsApp via phone +msgvault add-account --type imessage # no identifier needed (local DB) + +# Sync +msgvault sync-full # all sources +msgvault sync-full user@gmail.com # specific account +msgvault sync-full "+447700900000" # auto-detects type from sources table +msgvault sync-full --type whatsapp # all WhatsApp accounts +msgvault sync-incremental # incremental for all sources +``` + +**Account identifiers** use E.164 phone numbers for phone-based sources (`+447700900000`), email addresses for email-based sources. The existing `UNIQUE(source_type, identifier)` constraint means the same phone number can be both a WhatsApp and iMessage account. + +## How Each Platform Syncs + +The fundamental difference: Gmail is pull-based (fetch any message anytime), most chat platforms are push-based (stream messages in real time). Each platform gets its own package under `internal/` that knows how to sync into the shared store. + +| Platform | Sync model | History access | Auth | Identifier | +|----------|-----------|---------------|------|------------| +| **Gmail** | Pull via API | Full random access | OAuth2 (browser or device flow) | Email address | +| **WhatsApp** | Connect + stream | One-time dump at pairing, then forward-only | QR code or phone pairing code | E.164 phone | +| **iMessage** | Read local SQLite | Full (reads `~/Library/Messages/chat.db`) | macOS Full Disk Access | None (local) | +| **Telegram** | Pull via TDLib | Full history via API | Phone + code | E.164 phone | +| **SMS/Android** | Read local SQLite | Full (reads `mmssms.db` from backup) | File access | E.164 phone | + +No abstract `Provider` interface up front — just build each platform's sync as a standalone package, and extract common patterns once we have two working. YAGNI. + +## WhatsApp Specifics (Phase 1) + +### Library: whatsmeow + +[whatsmeow](https://github.com/tulir/whatsmeow) is a pure Go implementation of the WhatsApp Web multi-device protocol. Production-grade — it powers the [mautrix-whatsapp](https://github.com/mautrix/whatsapp) Matrix bridge (2,200+ stars). Actively maintained (last commit: Feb 2026). + +### Auth Flow + +1. User runs `msgvault add-account --type whatsapp "+447700900000"` +2. Terminal displays QR code (or pairing code with `--headless`) +3. User scans with WhatsApp on their phone +4. Session credentials stored in SQLite (alongside msgvault's main DB) +5. Session persists across restarts — no re-scanning needed + +Session expires if the primary phone doesn't connect to internet for 14 days, or after ~30 days of inactivity. + +### Sync Model + +**Critical constraint:** WhatsApp history is a one-time dump, not an on-demand API. + +``` +First sync: + connect → receive history dump (HistorySync event) → stream until caught up → disconnect + +Subsequent syncs: + connect → stream new messages since last cursor → disconnect +``` + +On-demand historical backfill exists (`BuildHistorySyncRequest`) but is documented as unreliable, especially for groups. Design accordingly: treat initial history as best-effort, then reliably capture everything going forward. + +### Media Must Be Downloaded Immediately + +WhatsApp media URLs expire after ~30 days. Unlike Gmail where you can fetch any attachment anytime, WhatsApp media must be downloaded and stored locally at sync time. The existing content-addressed attachment storage (SHA-256 dedup) works perfectly for this. + +### Message Type Mapping + +| WhatsApp | msgvault field | Value | +|----------|---------------|-------| +| Text message | `messages.message_type` | `'whatsapp'` | +| Image/Video/Audio | `attachments.media_type` | `'image'`, `'video'`, `'audio'` | +| Voice note | `attachments.media_type` | `'voice_note'` | +| Sticker | `attachments.media_type` | `'sticker'` | +| Document | `attachments.media_type` | `'document'` | +| Reaction (emoji) | `reactions.reaction_type` | `'emoji'` | +| Reply/Quote | `messages.reply_to_message_id` | FK to parent message | +| Forwarded | `messages.is_forwarded` | `true` | +| Edited | `messages.is_edited` | `true` | +| Read receipt | `messages.read_at` | Timestamp | +| Delivery receipt | `messages.delivered_at` | Timestamp | +| Group chat | `conversations.conversation_type` | `'group_chat'` | +| 1:1 chat | `conversations.conversation_type` | `'direct_chat'` | +| Sender JID | `participant_identifiers.identifier_type` | `'whatsapp'`, value = `447700900000@s.whatsapp.net` | +| Sender phone | `participants.phone_number` | `+447700900000` (E.164) | +| Raw protobuf | `message_raw.raw_format` | `'whatsapp_protobuf'` | + +### What Changes in Existing Code + +**New package:** `internal/whatsapp/` — self-contained, no changes to existing Gmail code. + +**Small changes needed:** +- `cmd/msgvault/cmd/addaccount.go`: Add `--type` flag, dispatch to WhatsApp auth when type is `"whatsapp"` +- `cmd/msgvault/cmd/syncfull.go`: Currently hardcodes `ListSources("gmail")` — change to `ListSources("")` (all types) with a type-based dispatcher +- `internal/store/`: Add `EnsureParticipantByPhone()` method (currently only handles email-based participants) +- `internal/store/`: Add `'member'` as a valid `recipient_type` for group chat participants + +**No changes to:** schema, query engine, TUI, MCP server, HTTP API, or any consumer. Messages from WhatsApp will appear in search, aggregation, and all views automatically because consumers operate on the generic `messages` table. + +## Risks + +| Risk | Severity | Mitigation | +|------|----------|------------| +| **Account ban/warning** | High | WhatsApp TOS prohibits unofficial clients. Read-only archival is lower risk than bots, but not zero. Document prominently. Recommend a dedicated/secondary number for testing. | +| **History dump is incomplete** | Medium | WhatsApp server decides how much history to send at pairing. Design as "best effort snapshot + reliable stream forward." | +| **whatsmeow protocol breakage** | Medium | WhatsApp changes their protocol regularly. Pin whatsmeow version, expect occasional breakage, track upstream releases. | +| **Media URL expiration** | Low | Download everything at sync time. Already mitigated by design. | +| **Phone must be online every 14 days** | Low | Document requirement. Could add a warning in `sync` output if session is stale. | + +## How Other Platforms Would Plug In Later + +Each gets its own `internal//` package that syncs into the store. Brief notes on feasibility: + +**iMessage** (macOS only): Read `~/Library/Messages/chat.db` directly. Full history available. Timestamps use Apple epoch (nanoseconds since 2001-01-01). Tapbacks stored as separate messages referencing parent via `associated_message_guid` — would map to `reactions` table. Requires Full Disk Access permission. No network needed. + +**Telegram**: TDLib (official C++ library with Go bindings) or import from Desktop export JSON. Full history available via API. Unique features: channels, supergroups, forums, scheduled messages, silent messages. User IDs are numeric (not phone-based) but phone is the auth method. + +**SMS/Android**: Import from `mmssms.db` backup. Simple data model (phone, timestamp, body). MMS attachments in `part` table. No reactions, no threading, no edits. + +**Signal**: Hardest. Desktop DB is SQLCipher-encrypted. Schema changes frequently (215+ migration versions). No official export API. Feasible but fragile. + +## Implementation Phases + +**Phase 1 — CLI + dispatcher (no new platforms):** +Add `--type` flag. Change sync dispatch from Gmail-only to type-based. All existing behavior unchanged. + +**Phase 2 — WhatsApp sync:** +`internal/whatsapp/` package. QR pairing. History dump. Forward streaming. Media download. Phone participant handling. + +**Phase 3 — WhatsApp features:** +Reactions, replies, groups with metadata, voice notes, stickers, read receipts. + +**Phase 4 — Next platform (iMessage or Telegram):** +By this point we'll have two implementations and can extract common patterns if they emerge naturally. Not before. diff --git a/internal/mcp/handlers.go b/internal/mcp/handlers.go index 1491c466..b6e4a3c2 100644 --- a/internal/mcp/handlers.go +++ b/internal/mcp/handlers.go @@ -323,7 +323,12 @@ func (h *handlers) listMessages(ctx context.Context, req mcp.CallToolRequest) (* } if v, ok := args["from"].(string); ok && v != "" { - filter.Sender = v + // If it looks like an email address, filter by email; otherwise by display name. + if strings.Contains(v, "@") || strings.HasPrefix(v, "+") { + filter.Sender = v + } else { + filter.SenderName = v + } } if v, ok := args["to"].(string); ok && v != "" { filter.Recipient = v diff --git a/internal/query/duckdb.go b/internal/query/duckdb.go index d8465195..fadb8a80 100644 --- a/internal/query/duckdb.go +++ b/internal/query/duckdb.go @@ -40,6 +40,12 @@ type DuckDBEngine struct { hasSQLiteScanner bool // true if DuckDB's sqlite extension is loaded tempTableSeq atomic.Uint64 // Unique suffix for temp tables to avoid concurrent collisions + // optionalCols tracks which columns exist in each Parquet table's schema. + // Used to gracefully handle stale cache files that lack newer columns + // (e.g. phone_number, attachment_count, sender_id, message_type added in PR #160). + // Map: table_name -> column_name -> exists_in_parquet + optionalCols map[string]map[string]bool + // Search result cache: keeps the materialized temp table alive across // pagination calls for the same search query, avoiding repeated Parquet scans. searchCacheMu sync.Mutex // protects cache fields from concurrent goroutines @@ -122,14 +128,39 @@ func NewDuckDBEngine(analyticsDir string, sqlitePath string, sqliteDB *sql.DB, o sqliteEngine = NewSQLiteEngine(sqliteDB) } - return &DuckDBEngine{ + engine := &DuckDBEngine{ db: db, analyticsDir: analyticsDir, sqlitePath: sqlitePath, sqliteDB: sqliteDB, sqliteEngine: sqliteEngine, hasSQLiteScanner: hasSQLiteScanner, - }, nil + } + + // Probe Parquet schemas for optional columns added in PR #160 (WhatsApp import). + // Old cache files may lack these columns; we'll supply defaults in parquetCTEs(). + engine.optionalCols = map[string]map[string]bool{ + "participants": engine.probeParquetColumns(engine.parquetPath("participants"), false), + "messages": engine.probeParquetColumns(engine.parquetGlob(), true), + "conversations": engine.probeParquetColumns(engine.parquetPath("conversations"), false), + } + var missing []string + for _, col := range []struct{ table, col string }{ + {"participants", "phone_number"}, + {"messages", "attachment_count"}, + {"messages", "sender_id"}, + {"messages", "message_type"}, + {"conversations", "title"}, + } { + if !engine.optionalCols[col.table][col.col] { + missing = append(missing, col.table+"."+col.col) + } + } + if len(missing) > 0 { + log.Printf("[warn] Parquet cache missing columns %v — run 'msgvault build-cache --full-rebuild' to update", missing) + } + + return engine, nil } // Close releases DuckDB resources, including any cached search temp table. @@ -156,6 +187,48 @@ func (e *DuckDBEngine) parquetPath(table string) string { return filepath.Join(e.analyticsDir, table, "*.parquet") } +// probeParquetColumns checks which columns exist in a Parquet table's files. +// Returns a map of column_name -> true for columns that exist. +// On any error (files missing, unreadable, etc.), returns an empty map — callers +// should treat absent keys as "column does not exist" and supply defaults. +func (e *DuckDBEngine) probeParquetColumns(pathPattern string, hivePartitioning bool) map[string]bool { + cols := make(map[string]bool) + hiveOpt := "" + if hivePartitioning { + hiveOpt = ", hive_partitioning=true" + } + escapedPath := strings.ReplaceAll(pathPattern, "'", "''") + query := fmt.Sprintf("DESCRIBE SELECT * FROM read_parquet('%s'%s)", escapedPath, hiveOpt) + rows, err := e.db.Query(query) + if err != nil { + // No Parquet files or unreadable — treat all optional cols as missing. + return cols + } + defer rows.Close() + for rows.Next() { + var colName, colType, isNull, key, dflt, extra sql.NullString + if err := rows.Scan(&colName, &colType, &isNull, &key, &dflt, &extra); err != nil { + continue + } + if colName.Valid { + cols[colName.String] = true + } + } + return cols +} + +// hasCol returns true if the named column exists in the Parquet schema for the given table. +func (e *DuckDBEngine) hasCol(table, col string) bool { + if e.optionalCols == nil { + return true // no probe data — assume present (backwards compatible) + } + tbl, ok := e.optionalCols[table] + if !ok { + return true // table not probed — assume present + } + return tbl[col] +} + // parquetCTEs returns common CTEs for reading all Parquet tables. // This is used by aggregate queries that need to join across tables. // parquetCTEs returns the WITH clause body that defines CTEs for all Parquet @@ -163,19 +236,83 @@ func (e *DuckDBEngine) parquetPath(table string) string { // REPLACE syntax, because Parquet schema inference from SQLite can store // integer/boolean columns as VARCHAR, causing type mismatch errors in JOINs // and COALESCE expressions. +// +// Optional columns (phone_number, attachment_count, sender_id, message_type) +// are handled gracefully: if the Parquet file predates their addition, they +// are synthesised with sensible defaults instead of causing a binder error. func (e *DuckDBEngine) parquetCTEs() string { + // --- messages CTE --- + msgReplace := []string{ + "CAST(id AS BIGINT) AS id", + "CAST(source_id AS BIGINT) AS source_id", + "CAST(source_message_id AS VARCHAR) AS source_message_id", + "CAST(conversation_id AS BIGINT) AS conversation_id", + "CAST(subject AS VARCHAR) AS subject", + "CAST(snippet AS VARCHAR) AS snippet", + "CAST(size_estimate AS BIGINT) AS size_estimate", + "COALESCE(TRY_CAST(has_attachments AS BOOLEAN), false) AS has_attachments", + } + var msgExtra []string + if e.hasCol("messages", "attachment_count") { + msgReplace = append(msgReplace, "COALESCE(TRY_CAST(attachment_count AS INTEGER), 0) AS attachment_count") + } else { + msgExtra = append(msgExtra, "0 AS attachment_count") + } + if e.hasCol("messages", "sender_id") { + msgReplace = append(msgReplace, "TRY_CAST(sender_id AS BIGINT) AS sender_id") + } else { + msgExtra = append(msgExtra, "NULL::BIGINT AS sender_id") + } + if e.hasCol("messages", "message_type") { + msgReplace = append(msgReplace, "COALESCE(CAST(message_type AS VARCHAR), '') AS message_type") + } else { + msgExtra = append(msgExtra, "'' AS message_type") + } + msgCTE := fmt.Sprintf("SELECT * REPLACE (\n\t\t\t\t%s\n\t\t\t)", strings.Join(msgReplace, ",\n\t\t\t\t")) + if len(msgExtra) > 0 { + msgCTE += ", " + strings.Join(msgExtra, ", ") + } + msgCTE += fmt.Sprintf(" FROM read_parquet('%s', hive_partitioning=true, union_by_name=true)", e.parquetGlob()) + + // --- participants CTE --- + pReplace := []string{ + "CAST(id AS BIGINT) AS id", + "CAST(email_address AS VARCHAR) AS email_address", + "CAST(domain AS VARCHAR) AS domain", + "CAST(display_name AS VARCHAR) AS display_name", + } + var pExtra []string + if e.hasCol("participants", "phone_number") { + pReplace = append(pReplace, "COALESCE(CAST(phone_number AS VARCHAR), '') AS phone_number") + } else { + pExtra = append(pExtra, "'' AS phone_number") + } + pCTE := fmt.Sprintf("SELECT * REPLACE (\n\t\t\t\t%s\n\t\t\t)", strings.Join(pReplace, ",\n\t\t\t\t")) + if len(pExtra) > 0 { + pCTE += ", " + strings.Join(pExtra, ", ") + } + pCTE += fmt.Sprintf(" FROM read_parquet('%s')", e.parquetPath("participants")) + + // --- conversations CTE --- + convReplace := []string{ + "CAST(id AS BIGINT) AS id", + "CAST(source_conversation_id AS VARCHAR) AS source_conversation_id", + } + var convExtra []string + if e.hasCol("conversations", "title") { + convReplace = append(convReplace, "COALESCE(CAST(title AS VARCHAR), '') AS title") + } else { + convExtra = append(convExtra, "'' AS title") + } + convCTE := fmt.Sprintf("SELECT * REPLACE (\n\t\t\t\t%s\n\t\t\t)", strings.Join(convReplace, ",\n\t\t\t\t")) + if len(convExtra) > 0 { + convCTE += ", " + strings.Join(convExtra, ", ") + } + convCTE += fmt.Sprintf(" FROM read_parquet('%s')", e.parquetPath("conversations")) + return fmt.Sprintf(` msg AS ( - SELECT * REPLACE ( - CAST(id AS BIGINT) AS id, - CAST(source_id AS BIGINT) AS source_id, - CAST(source_message_id AS VARCHAR) AS source_message_id, - CAST(conversation_id AS BIGINT) AS conversation_id, - CAST(subject AS VARCHAR) AS subject, - CAST(snippet AS VARCHAR) AS snippet, - CAST(size_estimate AS BIGINT) AS size_estimate, - COALESCE(TRY_CAST(has_attachments AS BOOLEAN), false) AS has_attachments - ) FROM read_parquet('%s', hive_partitioning=true) + %s ), mr AS ( SELECT * REPLACE ( @@ -186,12 +323,7 @@ func (e *DuckDBEngine) parquetCTEs() string { ) FROM read_parquet('%s') ), p AS ( - SELECT * REPLACE ( - CAST(id AS BIGINT) AS id, - CAST(email_address AS VARCHAR) AS email_address, - CAST(domain AS VARCHAR) AS domain, - CAST(display_name AS VARCHAR) AS display_name - ) FROM read_parquet('%s') + %s ), lbl AS ( SELECT * REPLACE ( @@ -218,19 +350,16 @@ func (e *DuckDBEngine) parquetCTEs() string { ) FROM read_parquet('%s') ), conv AS ( - SELECT * REPLACE ( - CAST(id AS BIGINT) AS id, - CAST(source_conversation_id AS VARCHAR) AS source_conversation_id - ) FROM read_parquet('%s') + %s ) - `, e.parquetGlob(), + `, msgCTE, e.parquetPath("message_recipients"), - e.parquetPath("participants"), + pCTE, e.parquetPath("labels"), e.parquetPath("message_labels"), e.parquetPath("attachments"), e.parquetPath("sources"), - e.parquetPath("conversations")) + convCTE) } // escapeILIKE escapes ILIKE wildcard characters (% and _) in user input. @@ -686,45 +815,62 @@ func (e *DuckDBEngine) buildFilterConditions(filter MessageFilter) (string, []in conditions = append(conditions, "msg.deleted_from_source_at IS NULL") } - // Sender filter - use EXISTS subquery (becomes semi-join) + // Sender filter - check both message_recipients (email) and direct sender_id (WhatsApp/chat) + // Also checks phone_number for phone-based lookups (e.g., from:+447...) if filter.Sender != "" { - conditions = append(conditions, `EXISTS ( + conditions = append(conditions, `(EXISTS ( SELECT 1 FROM mr JOIN p ON p.id = mr.participant_id WHERE mr.message_id = msg.id AND mr.recipient_type = 'from' - AND p.email_address = ? - )`) - args = append(args, filter.Sender) + AND (p.email_address = ? OR p.phone_number = ?) + ) OR EXISTS ( + SELECT 1 FROM p + WHERE p.id = msg.sender_id + AND (p.email_address = ? OR p.phone_number = ?) + ))`) + args = append(args, filter.Sender, filter.Sender, filter.Sender, filter.Sender) } else if filter.MatchesEmpty(ViewSenders) { - conditions = append(conditions, `NOT EXISTS ( + // A message has an "empty sender" only if it has no from-recipient AND no direct sender_id. + conditions = append(conditions, `(NOT EXISTS ( SELECT 1 FROM mr JOIN p ON p.id = mr.participant_id WHERE mr.message_id = msg.id AND mr.recipient_type = 'from' - AND p.email_address IS NOT NULL - AND p.email_address != '' - )`) + AND ( + (p.email_address IS NOT NULL AND p.email_address != '') OR + (p.phone_number IS NOT NULL AND p.phone_number != '') + ) + ) AND msg.sender_id IS NULL)`) } - // Sender name filter - use EXISTS subquery (becomes semi-join) + // Sender name filter - check both message_recipients (email) and direct sender_id (WhatsApp/chat) if filter.SenderName != "" { - conditions = append(conditions, `EXISTS ( + conditions = append(conditions, `(EXISTS ( SELECT 1 FROM mr JOIN p ON p.id = mr.participant_id WHERE mr.message_id = msg.id AND mr.recipient_type = 'from' AND COALESCE(NULLIF(TRIM(p.display_name), ''), p.email_address) = ? - )`) - args = append(args, filter.SenderName) + ) OR EXISTS ( + SELECT 1 FROM p + WHERE p.id = msg.sender_id + AND COALESCE(NULLIF(TRIM(p.display_name), ''), p.email_address) = ? + ))`) + args = append(args, filter.SenderName, filter.SenderName) } else if filter.MatchesEmpty(ViewSenderNames) { - conditions = append(conditions, `NOT EXISTS ( + // A message has an "empty sender name" only if it has no from-recipient name AND no direct sender_id with a name. + conditions = append(conditions, `(NOT EXISTS ( SELECT 1 FROM mr JOIN p ON p.id = mr.participant_id WHERE mr.message_id = msg.id AND mr.recipient_type = 'from' AND COALESCE(NULLIF(TRIM(p.display_name), ''), p.email_address) IS NOT NULL - )`) + ) AND NOT EXISTS ( + SELECT 1 FROM p + WHERE p.id = msg.sender_id + AND COALESCE(NULLIF(TRIM(p.display_name), ''), p.email_address) IS NOT NULL + ))`) } // Recipient filter - use EXISTS subquery (becomes semi-join) @@ -1053,12 +1199,24 @@ func (e *DuckDBEngine) ListMessages(ctx context.Context, filter MessageFilter) ( msg_sender AS ( SELECT mr.message_id, FIRST(p.email_address) as from_email, - FIRST(COALESCE(mr.display_name, p.display_name, '')) as from_name + FIRST(COALESCE(mr.display_name, p.display_name, '')) as from_name, + FIRST(COALESCE(p.phone_number, '')) as from_phone FROM mr JOIN p ON p.id = mr.participant_id WHERE mr.recipient_type = 'from' AND mr.message_id IN (SELECT id FROM filtered_msgs) GROUP BY mr.message_id + ), + direct_sender AS ( + SELECT msg.id as message_id, + COALESCE(p.email_address, '') as from_email, + COALESCE(p.display_name, '') as from_name, + COALESCE(p.phone_number, '') as from_phone + FROM msg + JOIN filtered_msgs fm ON fm.id = msg.id + JOIN p ON p.id = msg.sender_id + WHERE msg.sender_id IS NOT NULL + AND msg.id NOT IN (SELECT message_id FROM msg_sender) ) SELECT msg.id, @@ -1067,15 +1225,20 @@ func (e *DuckDBEngine) ListMessages(ctx context.Context, filter MessageFilter) ( COALESCE(c.source_conversation_id, '') as source_conversation_id, COALESCE(msg.subject, '') as subject, COALESCE(msg.snippet, '') as snippet, - COALESCE(ms.from_email, '') as from_email, - COALESCE(ms.from_name, '') as from_name, + COALESCE(ms.from_email, ds.from_email, '') as from_email, + COALESCE(ms.from_name, ds.from_name, '') as from_name, + COALESCE(ms.from_phone, ds.from_phone, '') as from_phone, msg.sent_at, COALESCE(msg.size_estimate, 0) as size_estimate, COALESCE(msg.has_attachments, false) as has_attachments, - msg.deleted_from_source_at + COALESCE(msg.attachment_count, 0) as attachment_count, + msg.deleted_from_source_at, + COALESCE(msg.message_type, '') as message_type, + COALESCE(c.title, '') as conv_title FROM msg JOIN filtered_msgs fm ON fm.id = msg.id LEFT JOIN msg_sender ms ON ms.message_id = msg.id + LEFT JOIN direct_sender ds ON ds.message_id = msg.id LEFT JOIN conv c ON c.id = msg.conversation_id ORDER BY %s `, e.parquetCTEs(), where, orderBy, orderBy) @@ -1102,10 +1265,14 @@ func (e *DuckDBEngine) ListMessages(ctx context.Context, filter MessageFilter) ( &msg.Snippet, &msg.FromEmail, &msg.FromName, + &msg.FromPhone, &sentAt, &msg.SizeEstimate, &msg.HasAttachments, + &msg.AttachmentCount, &deletedAt, + &msg.MessageType, + &msg.ConversationTitle, ); err != nil { return nil, fmt.Errorf("scan message: %w", err) } @@ -1425,6 +1592,11 @@ func (e *DuckDBEngine) GetGmailIDsByFilter(ctx context.Context, filter MessageFi // Always exclude deleted messages conditions = append(conditions, "msg.deleted_from_source_at IS NULL") + // Scope to Gmail messages only — this function is used for Gmail-specific + // deletion/staging workflows and must not return WhatsApp or other source IDs. + // In the Parquet fallback, we filter by message_type since sources aren't in the cache. + conditions = append(conditions, "(msg.message_type = '' OR msg.message_type = 'email' OR msg.message_type IS NULL)") + if filter.SourceID != nil { conditions = append(conditions, "msg.source_id = ?") args = append(args, *filter.SourceID) @@ -1432,25 +1604,33 @@ func (e *DuckDBEngine) GetGmailIDsByFilter(ctx context.Context, filter MessageFi // Use EXISTS subqueries for filtering (becomes semi-joins, no duplicates) if filter.Sender != "" { - conditions = append(conditions, `EXISTS ( + conditions = append(conditions, `(EXISTS ( SELECT 1 FROM mr JOIN p ON p.id = mr.participant_id WHERE mr.message_id = msg.id AND mr.recipient_type = 'from' - AND p.email_address = ? - )`) - args = append(args, filter.Sender) + AND (p.email_address = ? OR p.phone_number = ?) + ) OR EXISTS ( + SELECT 1 FROM p + WHERE p.id = msg.sender_id + AND (p.email_address = ? OR p.phone_number = ?) + ))`) + args = append(args, filter.Sender, filter.Sender, filter.Sender, filter.Sender) } if filter.SenderName != "" { - conditions = append(conditions, `EXISTS ( + conditions = append(conditions, `(EXISTS ( SELECT 1 FROM mr JOIN p ON p.id = mr.participant_id WHERE mr.message_id = msg.id AND mr.recipient_type = 'from' AND COALESCE(NULLIF(TRIM(p.display_name), ''), p.email_address) = ? - )`) - args = append(args, filter.SenderName) + ) OR EXISTS ( + SELECT 1 FROM p + WHERE p.id = msg.sender_id + AND COALESCE(NULLIF(TRIM(p.display_name), ''), p.email_address) = ? + ))`) + args = append(args, filter.SenderName, filter.SenderName) } if filter.Recipient != "" { @@ -1612,11 +1792,22 @@ func (e *DuckDBEngine) SearchFast(ctx context.Context, q *search.Query, filter M msg_sender AS ( SELECT mr.message_id, FIRST(p.email_address) as from_email, - FIRST(COALESCE(mr.display_name, p.display_name, '')) as from_name + FIRST(COALESCE(mr.display_name, p.display_name, '')) as from_name, + FIRST(COALESCE(p.phone_number, '')) as from_phone FROM mr JOIN p ON p.id = mr.participant_id WHERE mr.recipient_type = 'from' GROUP BY mr.message_id + ), + direct_sender AS ( + SELECT msg.id as message_id, + COALESCE(p.email_address, '') as from_email, + COALESCE(p.display_name, '') as from_name, + COALESCE(p.phone_number, '') as from_phone + FROM msg + JOIN p ON p.id = msg.sender_id + WHERE msg.sender_id IS NOT NULL + AND msg.id NOT IN (SELECT message_id FROM msg_sender) ) SELECT COALESCE(msg.id, 0) as id, @@ -1625,8 +1816,9 @@ func (e *DuckDBEngine) SearchFast(ctx context.Context, q *search.Query, filter M COALESCE(c.source_conversation_id, '') as source_conversation_id, COALESCE(msg.subject, '') as subject, COALESCE(msg.snippet, '') as snippet, - COALESCE(ms.from_email, '') as from_email, - COALESCE(ms.from_name, '') as from_name, + COALESCE(ms.from_email, ds.from_email, '') as from_email, + COALESCE(ms.from_name, ds.from_name, '') as from_name, + COALESCE(ms.from_phone, ds.from_phone, '') as from_phone, msg.sent_at, COALESCE(msg.size_estimate, 0) as size_estimate, COALESCE(msg.has_attachments, false) as has_attachments, @@ -1635,6 +1827,7 @@ func (e *DuckDBEngine) SearchFast(ctx context.Context, q *search.Query, filter M msg.deleted_from_source_at FROM msg LEFT JOIN msg_sender ms ON ms.message_id = msg.id + LEFT JOIN direct_sender ds ON ds.message_id = msg.id LEFT JOIN att ON att.message_id = msg.id LEFT JOIN msg_labels mlbl ON mlbl.message_id = msg.id LEFT JOIN conv c ON c.id = msg.conversation_id @@ -1666,6 +1859,7 @@ func (e *DuckDBEngine) SearchFast(ctx context.Context, q *search.Query, filter M &msg.Snippet, &msg.FromEmail, &msg.FromName, + &msg.FromPhone, &sentAt, &msg.SizeEstimate, &msg.HasAttachments, @@ -1702,15 +1896,29 @@ func (e *DuckDBEngine) SearchFastCount(ctx context.Context, q *search.Query, fil query := fmt.Sprintf(` WITH %s, msg_sender AS ( - SELECT mr.message_id, FIRST(p.email_address) as from_email, FIRST(p.display_name) as from_name + SELECT mr.message_id, + FIRST(p.email_address) as from_email, + FIRST(COALESCE(mr.display_name, p.display_name, '')) as from_name, + FIRST(COALESCE(p.phone_number, '')) as from_phone FROM mr JOIN p ON p.id = mr.participant_id WHERE mr.recipient_type = 'from' GROUP BY mr.message_id + ), + direct_sender AS ( + SELECT msg.id as message_id, + COALESCE(p.email_address, '') as from_email, + COALESCE(p.display_name, '') as from_name, + COALESCE(p.phone_number, '') as from_phone + FROM msg + JOIN p ON p.id = msg.sender_id + WHERE msg.sender_id IS NOT NULL + AND msg.id NOT IN (SELECT message_id FROM msg_sender) ) SELECT COUNT(*) as cnt FROM msg LEFT JOIN msg_sender ms ON ms.message_id = msg.id + LEFT JOIN direct_sender ds ON ds.message_id = msg.id WHERE %s `, e.parquetCTEs(), strings.Join(conditions, " AND ")) @@ -1779,6 +1987,7 @@ func (e *DuckDBEngine) searchPageFromCache(ctx context.Context, limit, offset in sm.snippet, sm.from_email, sm.from_name, + COALESCE(sm.from_phone, '') as from_phone, sm.sent_at, sm.size_estimate, sm.has_attachments, @@ -1825,6 +2034,7 @@ func (e *DuckDBEngine) searchPageFromCache(ctx context.Context, limit, offset in &msg.Snippet, &msg.FromEmail, &msg.FromName, + &msg.FromPhone, &sentAt, &msg.SizeEstimate, &msg.HasAttachments, @@ -1946,11 +2156,22 @@ func (e *DuckDBEngine) SearchFastWithStats(ctx context.Context, q *search.Query, msg_sender AS ( SELECT mr.message_id, FIRST(p.email_address) as from_email, - FIRST(COALESCE(mr.display_name, p.display_name, '')) as from_name + FIRST(COALESCE(mr.display_name, p.display_name, '')) as from_name, + FIRST(COALESCE(p.phone_number, '')) as from_phone FROM mr JOIN p ON p.id = mr.participant_id WHERE mr.recipient_type = 'from' GROUP BY mr.message_id + ), + direct_sender AS ( + SELECT msg.id as message_id, + COALESCE(p.email_address, '') as from_email, + COALESCE(p.display_name, '') as from_name, + COALESCE(p.phone_number, '') as from_phone + FROM msg + JOIN p ON p.id = msg.sender_id + WHERE msg.sender_id IS NOT NULL + AND msg.id NOT IN (SELECT message_id FROM msg_sender) ) SELECT msg.id, @@ -1958,8 +2179,9 @@ func (e *DuckDBEngine) SearchFastWithStats(ctx context.Context, q *search.Query, COALESCE(msg.conversation_id, 0) as conversation_id, COALESCE(msg.subject, '') as subject, COALESCE(msg.snippet, '') as snippet, - COALESCE(ms.from_email, '') as from_email, - COALESCE(ms.from_name, '') as from_name, + COALESCE(ms.from_email, ds.from_email, '') as from_email, + COALESCE(ms.from_name, ds.from_name, '') as from_name, + COALESCE(ms.from_phone, ds.from_phone, '') as from_phone, msg.sent_at, COALESCE(CAST(msg.size_estimate AS BIGINT), 0) as size_estimate, COALESCE(msg.has_attachments, false) as has_attachments, @@ -1967,6 +2189,7 @@ func (e *DuckDBEngine) SearchFastWithStats(ctx context.Context, q *search.Query, CAST(msg.source_id AS BIGINT) as source_id FROM msg LEFT JOIN msg_sender ms ON ms.message_id = msg.id + LEFT JOIN direct_sender ds ON ds.message_id = msg.id WHERE %s `, tempTable, e.parquetCTEs(), strings.Join(conditions, " AND ")) @@ -2022,24 +2245,35 @@ func (e *DuckDBEngine) buildSearchConditions(q *search.Query, filter MessageFilt if filter.HideDeletedFromSource { conditions = append(conditions, "msg.deleted_from_source_at IS NULL") } + // Sender filter - check both message_recipients (email/phone) and direct sender_id (WhatsApp/chat) if filter.Sender != "" { - conditions = append(conditions, "ms.from_email = ?") - args = append(args, filter.Sender) + conditions = append(conditions, `(EXISTS ( + SELECT 1 FROM mr + JOIN p ON p.id = mr.participant_id + WHERE mr.message_id = msg.id + AND mr.recipient_type = 'from' + AND (p.email_address = ? OR p.phone_number = ?) + ) OR EXISTS ( + SELECT 1 FROM p + WHERE p.id = msg.sender_id + AND (p.email_address = ? OR p.phone_number = ?) + ))`) + args = append(args, filter.Sender, filter.Sender, filter.Sender, filter.Sender) } if filter.Domain != "" { conditions = append(conditions, "ms.from_email ILIKE ?") args = append(args, "%@"+filter.Domain) } - // Recipient filter - use EXISTS subquery for drill-down context + // Recipient filter - use EXISTS subquery for drill-down context (checks email and phone) if filter.Recipient != "" { conditions = append(conditions, `EXISTS ( SELECT 1 FROM mr JOIN p ON p.id = mr.participant_id WHERE mr.message_id = msg.id AND mr.recipient_type IN ('to', 'cc', 'bcc') - AND p.email_address = ? + AND (p.email_address = ? OR p.phone_number = ?) )`) - args = append(args, filter.Recipient) + args = append(args, filter.Recipient, filter.Recipient) } // Label filter - use EXISTS subquery for drill-down context if filter.Label != "" { @@ -2063,31 +2297,44 @@ func (e *DuckDBEngine) buildSearchConditions(q *search.Query, filter MessageFilt termPattern := "%" + escapeILIKE(term) + "%" conditions = append(conditions, `( msg.subject ILIKE ? ESCAPE '\' OR - ms.from_email ILIKE ? ESCAPE '\' OR - ms.from_name ILIKE ? ESCAPE '\' + COALESCE(ms.from_email, ds.from_email, '') ILIKE ? ESCAPE '\' OR + COALESCE(ms.from_name, ds.from_name, '') ILIKE ? ESCAPE '\' OR + COALESCE(ms.from_phone, ds.from_phone, '') ILIKE ? ESCAPE '\' )`) - args = append(args, termPattern, termPattern, termPattern) + args = append(args, termPattern, termPattern, termPattern, termPattern) } } - // From filter + // From filter - check email, phone, display name via message_recipients and direct sender_id if len(q.FromAddrs) > 0 { for _, addr := range q.FromAddrs { - conditions = append(conditions, "ms.from_email ILIKE ? ESCAPE '\\'") - args = append(args, "%"+escapeILIKE(addr)+"%") + pattern := "%" + escapeILIKE(addr) + "%" + conditions = append(conditions, `(EXISTS ( + SELECT 1 FROM mr + JOIN p ON p.id = mr.participant_id + WHERE mr.message_id = msg.id + AND mr.recipient_type = 'from' + AND (p.email_address ILIKE ? ESCAPE '\' OR p.phone_number ILIKE ? ESCAPE '\' OR p.display_name ILIKE ? ESCAPE '\') + ) OR EXISTS ( + SELECT 1 FROM p + WHERE p.id = msg.sender_id + AND (p.email_address ILIKE ? ESCAPE '\' OR p.phone_number ILIKE ? ESCAPE '\' OR p.display_name ILIKE ? ESCAPE '\') + ))`) + args = append(args, pattern, pattern, pattern, pattern, pattern, pattern) } } - // To filter - use EXISTS subquery to check recipients + // To filter - use EXISTS subquery to check recipients (email and phone) if len(q.ToAddrs) > 0 { for _, addr := range q.ToAddrs { + pattern := "%" + escapeILIKE(addr) + "%" conditions = append(conditions, `EXISTS ( SELECT 1 FROM mr JOIN p ON p.id = mr.participant_id WHERE mr.message_id = msg.id AND mr.recipient_type IN ('to', 'cc', 'bcc') - AND p.email_address ILIKE ? ESCAPE '\' + AND (p.email_address ILIKE ? ESCAPE '\' OR p.phone_number ILIKE ? ESCAPE '\') )`) - args = append(args, "%"+escapeILIKE(addr)+"%") + args = append(args, pattern, pattern) } } diff --git a/internal/query/duckdb_test.go b/internal/query/duckdb_test.go index bc57c25b..b4fb4d4c 100644 --- a/internal/query/duckdb_test.go +++ b/internal/query/duckdb_test.go @@ -2030,7 +2030,7 @@ func TestBuildSearchConditions_EscapedWildcards(t *testing.T) { query: &search.Query{ FromAddrs: []string{"test_user%"}, }, - wantClauses: []string{"ms.from_email ILIKE", "ESCAPE"}, + wantClauses: []string{"p.email_address ILIKE", "ESCAPE"}, wantInArgs: []string{"test\\_user\\%"}, }, { @@ -2156,16 +2156,16 @@ func TestDuckDBEngine_AggregateByRecipientName_EmptyStringFallback(t *testing.T) // Build Parquet data with empty-string and whitespace display_names on recipients engine := createEngineFromBuilder(t, newParquetBuilder(t). addTable("messages", "messages/year=2024", "data.parquet", messagesCols, ` - (1::BIGINT, 1::BIGINT, 'msg1', 100::BIGINT, 'Hello', 'Snippet', TIMESTAMP '2024-01-15 10:00:00', 1000::BIGINT, false, NULL::TIMESTAMP, 2024, 1), - (2::BIGINT, 1::BIGINT, 'msg2', 101::BIGINT, 'World', 'Snippet', TIMESTAMP '2024-01-16 10:00:00', 1000::BIGINT, false, NULL::TIMESTAMP, 2024, 1) + (1::BIGINT, 1::BIGINT, 'msg1', 100::BIGINT, 'Hello', 'Snippet', TIMESTAMP '2024-01-15 10:00:00', 1000::BIGINT, false, 0, NULL::TIMESTAMP, NULL::BIGINT, 'email', 2024, 1), + (2::BIGINT, 1::BIGINT, 'msg2', 101::BIGINT, 'World', 'Snippet', TIMESTAMP '2024-01-16 10:00:00', 1000::BIGINT, false, 0, NULL::TIMESTAMP, NULL::BIGINT, 'email', 2024, 1) `). addTable("sources", "sources", "sources.parquet", sourcesCols, ` (1::BIGINT, 'test@gmail.com') `). addTable("participants", "participants", "participants.parquet", participantsCols, ` - (1::BIGINT, 'sender@test.com', 'test.com', 'Sender'), - (2::BIGINT, 'empty@test.com', 'test.com', ''), - (3::BIGINT, 'spaces@test.com', 'test.com', ' ') + (1::BIGINT, 'sender@test.com', 'test.com', 'Sender', ''), + (2::BIGINT, 'empty@test.com', 'test.com', '', ''), + (3::BIGINT, 'spaces@test.com', 'test.com', ' ', '') `). addTable("message_recipients", "message_recipients", "message_recipients.parquet", messageRecipientsCols, ` (1::BIGINT, 1::BIGINT, 'from', 'Sender'), @@ -2177,8 +2177,8 @@ func TestDuckDBEngine_AggregateByRecipientName_EmptyStringFallback(t *testing.T) addEmptyTable("message_labels", "message_labels", "message_labels.parquet", messageLabelsCols, `(1::BIGINT, 1::BIGINT)`). addEmptyTable("attachments", "attachments", "attachments.parquet", attachmentsCols, `(1::BIGINT, 100::BIGINT, 'x')`). addTable("conversations", "conversations", "conversations.parquet", conversationsCols, ` - (100::BIGINT, 'thread100'), - (101::BIGINT, 'thread101') + (100::BIGINT, 'thread100', ''), + (101::BIGINT, 'thread101', '') `)) ctx := context.Background() @@ -2208,15 +2208,15 @@ func TestDuckDBEngine_ListMessages_MatchEmptyRecipientName(t *testing.T) { // Build Parquet data with a message that has no recipients engine := createEngineFromBuilder(t, newParquetBuilder(t). addTable("messages", "messages/year=2024", "data.parquet", messagesCols, ` - (1::BIGINT, 1::BIGINT, 'msg1', 100::BIGINT, 'Has Recipient', 'Snippet', TIMESTAMP '2024-01-15 10:00:00', 1000::BIGINT, false, NULL::TIMESTAMP, 2024, 1), - (2::BIGINT, 1::BIGINT, 'msg2', 101::BIGINT, 'No Recipient', 'Snippet', TIMESTAMP '2024-01-16 10:00:00', 1000::BIGINT, false, NULL::TIMESTAMP, 2024, 1) + (1::BIGINT, 1::BIGINT, 'msg1', 100::BIGINT, 'Has Recipient', 'Snippet', TIMESTAMP '2024-01-15 10:00:00', 1000::BIGINT, false, 0, NULL::TIMESTAMP, NULL::BIGINT, 'email', 2024, 1), + (2::BIGINT, 1::BIGINT, 'msg2', 101::BIGINT, 'No Recipient', 'Snippet', TIMESTAMP '2024-01-16 10:00:00', 1000::BIGINT, false, 0, NULL::TIMESTAMP, NULL::BIGINT, 'email', 2024, 1) `). addTable("sources", "sources", "sources.parquet", sourcesCols, ` (1::BIGINT, 'test@gmail.com') `). addTable("participants", "participants", "participants.parquet", participantsCols, ` - (1::BIGINT, 'alice@test.com', 'test.com', 'Alice'), - (2::BIGINT, 'bob@test.com', 'test.com', 'Bob') + (1::BIGINT, 'alice@test.com', 'test.com', 'Alice', ''), + (2::BIGINT, 'bob@test.com', 'test.com', 'Bob', '') `). addTable("message_recipients", "message_recipients", "message_recipients.parquet", messageRecipientsCols, ` (1::BIGINT, 1::BIGINT, 'from', 'Alice'), @@ -2226,8 +2226,8 @@ func TestDuckDBEngine_ListMessages_MatchEmptyRecipientName(t *testing.T) { addEmptyTable("message_labels", "message_labels", "message_labels.parquet", messageLabelsCols, `(1::BIGINT, 1::BIGINT)`). addEmptyTable("attachments", "attachments", "attachments.parquet", attachmentsCols, `(1::BIGINT, 100::BIGINT, 'x')`). addTable("conversations", "conversations", "conversations.parquet", conversationsCols, ` - (100::BIGINT, 'thread100'), - (101::BIGINT, 'thread101') + (100::BIGINT, 'thread100', ''), + (101::BIGINT, 'thread101', '') `)) ctx := context.Background() @@ -2805,14 +2805,14 @@ func TestDuckDBEngine_VARCHARParquetColumns(t *testing.T) { // string, to reproduce type mismatches in COALESCE, JOINs, and TRY_CAST paths. engine := createEngineFromBuilder(t, newParquetBuilder(t). addTable("messages", "messages/year=2024", "data.parquet", messagesCols, ` - (1::BIGINT, 1::BIGINT, 'msg1', '100', 'Hello World', 'snippet1', TIMESTAMP '2024-01-15 10:00:00', '1000', '0', NULL::TIMESTAMP, 2024, 1), - (2::BIGINT, 1::BIGINT, 'msg2', '101', 'Goodbye', 'snippet2', TIMESTAMP '2024-01-16 10:00:00', '2000', '1', NULL::TIMESTAMP, 2024, 1) + (1::BIGINT, 1::BIGINT, 'msg1', '100', 'Hello World', 'snippet1', TIMESTAMP '2024-01-15 10:00:00', '1000', '0', '0', NULL::TIMESTAMP, NULL::BIGINT, 'email', 2024, 1), + (2::BIGINT, 1::BIGINT, 'msg2', '101', 'Goodbye', 'snippet2', TIMESTAMP '2024-01-16 10:00:00', '2000', '1', '0', NULL::TIMESTAMP, NULL::BIGINT, 'email', 2024, 1) `). addTable("sources", "sources", "sources.parquet", sourcesCols, ` (1::BIGINT, 'test@gmail.com') `). addTable("participants", "participants", "participants.parquet", participantsCols, ` - (1::BIGINT, 'alice@test.com', 'test.com', 'Alice') + (1::BIGINT, 'alice@test.com', 'test.com', 'Alice', '') `). addTable("message_recipients", "message_recipients", "message_recipients.parquet", messageRecipientsCols, ` (1::BIGINT, 1::BIGINT, 'from', 'Alice'), @@ -2822,8 +2822,8 @@ func TestDuckDBEngine_VARCHARParquetColumns(t *testing.T) { addEmptyTable("message_labels", "message_labels", "message_labels.parquet", messageLabelsCols, `(1::BIGINT, 1::BIGINT)`). addEmptyTable("attachments", "attachments", "attachments.parquet", attachmentsCols, `(1::BIGINT, '100', 'x')`). addTable("conversations", "conversations", "conversations.parquet", conversationsCols, ` - (100::BIGINT, 'thread100'), - (101::BIGINT, 'thread101') + (100::BIGINT, 'thread100', ''), + (101::BIGINT, 'thread101', '') `)) ctx := context.Background() @@ -3208,3 +3208,109 @@ func TestDuckDBEngine_HideDeletedFromSource(t *testing.T) { t.Errorf("GetTotalStats with hide-deleted: expected 2 messages, got %d", stats.MessageCount) } } + +// TestDuckDBEngine_StaleParquetSchema verifies that a DuckDB engine can query +// Parquet files written BEFORE PR #160 added phone_number, attachment_count, +// sender_id, and message_type columns. The engine should synthesise sensible +// defaults instead of failing with a binder error. +func TestDuckDBEngine_StaleParquetSchema(t *testing.T) { + // Old-style column definitions (pre-WhatsApp). + const oldMessagesCols = "id, source_id, source_message_id, conversation_id, subject, snippet, sent_at, size_estimate, has_attachments, deleted_from_source_at, year, month" + const oldParticipantsCols = "id, email_address, domain, display_name" + const oldConversationsCols = "id, source_conversation_id" + + engine := createEngineFromBuilder(t, newParquetBuilder(t). + addTable("messages", "messages/year=2024", "data.parquet", oldMessagesCols, ` + (1::BIGINT, 1::BIGINT, 'msg1', 100::BIGINT, 'Stale Hello', 'snip1', TIMESTAMP '2024-01-15 10:00:00', 1000::BIGINT, false, NULL::TIMESTAMP, 2024, 1), + (2::BIGINT, 1::BIGINT, 'msg2', 101::BIGINT, 'Stale Goodbye', 'snip2', TIMESTAMP '2024-01-16 10:00:00', 2000::BIGINT, true, NULL::TIMESTAMP, 2024, 1) + `). + addTable("sources", "sources", "sources.parquet", sourcesCols, ` + (1::BIGINT, 'test@gmail.com') + `). + addTable("participants", "participants", "participants.parquet", oldParticipantsCols, ` + (1::BIGINT, 'alice@test.com', 'test.com', 'Alice') + `). + addTable("message_recipients", "message_recipients", "message_recipients.parquet", messageRecipientsCols, ` + (1::BIGINT, 1::BIGINT, 'from', 'Alice'), + (2::BIGINT, 1::BIGINT, 'from', 'Alice') + `). + addEmptyTable("labels", "labels", "labels.parquet", labelsCols, `(1::BIGINT, 'x')`). + addEmptyTable("message_labels", "message_labels", "message_labels.parquet", messageLabelsCols, `(1::BIGINT, 1::BIGINT)`). + addEmptyTable("attachments", "attachments", "attachments.parquet", attachmentsCols, `(1::BIGINT, 100::BIGINT, 'x')`). + addTable("conversations", "conversations", "conversations.parquet", oldConversationsCols, ` + (100::BIGINT, 'thread100'), + (101::BIGINT, 'thread101') + `)) + + ctx := context.Background() + + t.Run("ListMessages", func(t *testing.T) { + results, err := engine.ListMessages(ctx, MessageFilter{}) + if err != nil { + t.Fatalf("ListMessages with stale Parquet schema: %v", err) + } + if len(results) != 2 { + t.Fatalf("expected 2 messages, got %d", len(results)) + } + }) + + t.Run("SearchFast", func(t *testing.T) { + q := search.Parse("Stale Hello") + results, err := engine.SearchFast(ctx, q, MessageFilter{}, 100, 0) + if err != nil { + t.Fatalf("SearchFast with stale Parquet schema: %v", err) + } + if len(results) != 1 { + t.Fatalf("expected 1 result, got %d", len(results)) + } + if results[0].Subject != "Stale Hello" { + t.Fatalf("unexpected subject: %s", results[0].Subject) + } + }) + + t.Run("SearchFastCount", func(t *testing.T) { + q := search.Parse("Stale") + count, err := engine.SearchFastCount(ctx, q, MessageFilter{}) + if err != nil { + t.Fatalf("SearchFastCount with stale Parquet schema: %v", err) + } + if count != 2 { + t.Fatalf("expected count 2, got %d", count) + } + }) + + t.Run("Aggregate", func(t *testing.T) { + results, err := engine.Aggregate(ctx, ViewSenders, DefaultAggregateOptions()) + if err != nil { + t.Fatalf("Aggregate with stale Parquet schema: %v", err) + } + if len(results) != 1 { + t.Fatalf("expected 1 sender, got %d", len(results)) + } + }) + + t.Run("GetTotalStats", func(t *testing.T) { + stats, err := engine.GetTotalStats(ctx, StatsOptions{}) + if err != nil { + t.Fatalf("GetTotalStats with stale Parquet schema: %v", err) + } + if stats.MessageCount != 2 { + t.Fatalf("expected 2 messages, got %d", stats.MessageCount) + } + }) + + // Verify that optionalCols correctly detected the missing columns. + t.Run("ProbeDetectedMissing", func(t *testing.T) { + for _, col := range []struct{ table, col string }{ + {"participants", "phone_number"}, + {"messages", "attachment_count"}, + {"messages", "sender_id"}, + {"messages", "message_type"}, + {"conversations", "title"}, + } { + if engine.hasCol(col.table, col.col) { + t.Errorf("expected %s.%s to be detected as missing", col.table, col.col) + } + } + }) +} diff --git a/internal/query/models.go b/internal/query/models.go index 0ce4a813..2f12cdf4 100644 --- a/internal/query/models.go +++ b/internal/query/models.go @@ -28,12 +28,15 @@ type MessageSummary struct { Snippet string FromEmail string FromName string + FromPhone string // Phone number (for WhatsApp/chat sources) SentAt time.Time SizeEstimate int64 HasAttachments bool AttachmentCount int Labels []string DeletedAt *time.Time // When message was deleted from server (nil if not deleted) + MessageType string // e.g., "email", "whatsapp" — from messages.message_type + ConversationTitle string // Group/chat name from conversations.title } // MessageDetail represents a full message with body and attachments. diff --git a/internal/query/sqlite.go b/internal/query/sqlite.go index 4f482305..0de83f26 100644 --- a/internal/query/sqlite.go +++ b/internal/query/sqlite.go @@ -279,40 +279,56 @@ func buildFilterJoinsAndConditions(filter MessageFilter, tableAlias string) (str conditions = append(conditions, prefix+"deleted_from_source_at IS NULL") } - // Sender filter + // Sender filter - check both message_recipients (email) and direct sender_id (WhatsApp/chat) + // Also checks phone_number for phone-based lookups (e.g., from:+447...) if filter.Sender != "" { joins = append(joins, ` - JOIN message_recipients mr_filter_from ON mr_filter_from.message_id = m.id AND mr_filter_from.recipient_type = 'from' - JOIN participants p_filter_from ON p_filter_from.id = mr_filter_from.participant_id + LEFT JOIN message_recipients mr_filter_from ON mr_filter_from.message_id = m.id AND mr_filter_from.recipient_type = 'from' + LEFT JOIN participants p_filter_from ON p_filter_from.id = mr_filter_from.participant_id + LEFT JOIN participants p_direct_sender ON p_direct_sender.id = m.sender_id `) - conditions = append(conditions, "p_filter_from.email_address = ?") - args = append(args, filter.Sender) + conditions = append(conditions, "(p_filter_from.email_address = ? OR p_filter_from.phone_number = ? OR p_direct_sender.email_address = ? OR p_direct_sender.phone_number = ?)") + args = append(args, filter.Sender, filter.Sender, filter.Sender, filter.Sender) } else if filter.MatchesEmpty(ViewSenders) { + // A message has an "empty sender" only if it has no from-recipient AND no direct sender_id. joins = append(joins, ` LEFT JOIN message_recipients mr_filter_from ON mr_filter_from.message_id = m.id AND mr_filter_from.recipient_type = 'from' LEFT JOIN participants p_filter_from ON p_filter_from.id = mr_filter_from.participant_id + LEFT JOIN participants p_direct_sender ON p_direct_sender.id = m.sender_id `) - conditions = append(conditions, "(mr_filter_from.id IS NULL OR p_filter_from.email_address IS NULL OR p_filter_from.email_address = '')") + conditions = append(conditions, `((mr_filter_from.id IS NULL OR ( + (p_filter_from.email_address IS NULL OR p_filter_from.email_address = '') AND + (p_filter_from.phone_number IS NULL OR p_filter_from.phone_number = '') + )) AND m.sender_id IS NULL)`) } - // Sender name filter + // Sender name filter - check both message_recipients (email) and direct sender_id (WhatsApp/chat) if filter.SenderName != "" { if filter.Sender == "" && !filter.MatchesEmpty(ViewSenders) { joins = append(joins, ` - JOIN message_recipients mr_filter_from ON mr_filter_from.message_id = m.id AND mr_filter_from.recipient_type = 'from' - JOIN participants p_filter_from ON p_filter_from.id = mr_filter_from.participant_id + LEFT JOIN message_recipients mr_filter_from ON mr_filter_from.message_id = m.id AND mr_filter_from.recipient_type = 'from' + LEFT JOIN participants p_filter_from ON p_filter_from.id = mr_filter_from.participant_id + LEFT JOIN participants p_direct_sender ON p_direct_sender.id = m.sender_id `) } - conditions = append(conditions, "COALESCE(NULLIF(TRIM(p_filter_from.display_name), ''), p_filter_from.email_address) = ?") - args = append(args, filter.SenderName) + conditions = append(conditions, `( + COALESCE(NULLIF(TRIM(p_filter_from.display_name), ''), p_filter_from.email_address) = ? + OR COALESCE(NULLIF(TRIM(p_direct_sender.display_name), ''), p_direct_sender.email_address) = ? + )`) + args = append(args, filter.SenderName, filter.SenderName) } else if filter.MatchesEmpty(ViewSenderNames) { - conditions = append(conditions, `NOT EXISTS ( + // A message has an "empty sender name" only if it has no from-recipient name AND no direct sender_id with a name. + conditions = append(conditions, `(NOT EXISTS ( SELECT 1 FROM message_recipients mr_sn JOIN participants p_sn ON p_sn.id = mr_sn.participant_id WHERE mr_sn.message_id = m.id AND mr_sn.recipient_type = 'from' AND COALESCE(NULLIF(TRIM(p_sn.display_name), ''), p_sn.email_address) IS NOT NULL - )`) + ) AND NOT EXISTS ( + SELECT 1 FROM participants p_ds + WHERE p_ds.id = m.sender_id + AND COALESCE(NULLIF(TRIM(p_ds.display_name), ''), p_ds.email_address) IS NOT NULL + ))`) } // Recipient filter @@ -599,14 +615,17 @@ func (e *SQLiteEngine) ListMessages(ctx context.Context, filter MessageFilter) ( COALESCE(m.snippet, ''), COALESCE(p_sender.email_address, ''), COALESCE(p_sender.display_name, ''), + COALESCE(p_sender.phone_number, ''), m.sent_at, COALESCE(m.size_estimate, 0), m.has_attachments, m.attachment_count, - m.deleted_from_source_at + m.deleted_from_source_at, + COALESCE(m.message_type, ''), + COALESCE(conv.title, '') FROM messages m LEFT JOIN message_recipients mr_sender ON mr_sender.message_id = m.id AND mr_sender.recipient_type = 'from' - LEFT JOIN participants p_sender ON p_sender.id = mr_sender.participant_id + LEFT JOIN participants p_sender ON p_sender.id = COALESCE(mr_sender.participant_id, m.sender_id) LEFT JOIN conversations conv ON conv.id = m.conversation_id %s WHERE %s @@ -636,11 +655,14 @@ func (e *SQLiteEngine) ListMessages(ctx context.Context, filter MessageFilter) ( &msg.Snippet, &msg.FromEmail, &msg.FromName, + &msg.FromPhone, &sentAt, &msg.SizeEstimate, &msg.HasAttachments, &msg.AttachmentCount, &deletedAt, + &msg.MessageType, + &msg.ConversationTitle, ); err != nil { return nil, fmt.Errorf("scan message: %w", err) } @@ -878,24 +900,33 @@ func (e *SQLiteEngine) GetGmailIDsByFilter(ctx context.Context, filter MessageFi // Build JOIN clauses based on filter type var joins []string + // Scope to Gmail sources only — this function is used for Gmail-specific + // deletion/staging workflows and must not return WhatsApp or other source IDs. + joins = append(joins, `JOIN sources s_gmail ON s_gmail.id = m.source_id AND s_gmail.source_type = 'gmail'`) + if filter.Sender != "" { joins = append(joins, ` - JOIN message_recipients mr_from ON mr_from.message_id = m.id AND mr_from.recipient_type = 'from' - JOIN participants p_from ON p_from.id = mr_from.participant_id + LEFT JOIN message_recipients mr_from ON mr_from.message_id = m.id AND mr_from.recipient_type = 'from' + LEFT JOIN participants p_from ON p_from.id = mr_from.participant_id + LEFT JOIN participants p_ds ON p_ds.id = m.sender_id `) - conditions = append(conditions, "p_from.email_address = ?") - args = append(args, filter.Sender) + conditions = append(conditions, "(p_from.email_address = ? OR p_from.phone_number = ? OR p_ds.email_address = ? OR p_ds.phone_number = ?)") + args = append(args, filter.Sender, filter.Sender, filter.Sender, filter.Sender) } if filter.SenderName != "" { if filter.Sender == "" { joins = append(joins, ` - JOIN message_recipients mr_from ON mr_from.message_id = m.id AND mr_from.recipient_type = 'from' - JOIN participants p_from ON p_from.id = mr_from.participant_id + LEFT JOIN message_recipients mr_from ON mr_from.message_id = m.id AND mr_from.recipient_type = 'from' + LEFT JOIN participants p_from ON p_from.id = mr_from.participant_id + LEFT JOIN participants p_ds ON p_ds.id = m.sender_id `) } - conditions = append(conditions, "COALESCE(NULLIF(TRIM(p_from.display_name), ''), p_from.email_address) = ?") - args = append(args, filter.SenderName) + conditions = append(conditions, `( + COALESCE(NULLIF(TRIM(p_from.display_name), ''), p_from.email_address) = ? + OR COALESCE(NULLIF(TRIM(p_ds.display_name), ''), p_ds.email_address) = ? + )`) + args = append(args, filter.SenderName, filter.SenderName) } if filter.Recipient != "" { @@ -1193,14 +1224,17 @@ func (e *SQLiteEngine) executeSearchQuery(ctx context.Context, conditions []stri COALESCE(m.snippet, ''), COALESCE(p_sender.email_address, ''), COALESCE(p_sender.display_name, ''), + COALESCE(p_sender.phone_number, ''), m.sent_at, COALESCE(m.size_estimate, 0), m.has_attachments, m.attachment_count, - m.deleted_from_source_at + m.deleted_from_source_at, + COALESCE(m.message_type, ''), + COALESCE(conv.title, '') FROM messages m LEFT JOIN message_recipients mr_sender ON mr_sender.message_id = m.id AND mr_sender.recipient_type = 'from' - LEFT JOIN participants p_sender ON p_sender.id = mr_sender.participant_id + LEFT JOIN participants p_sender ON p_sender.id = COALESCE(mr_sender.participant_id, m.sender_id) LEFT JOIN conversations conv ON conv.id = m.conversation_id %s %s @@ -1231,11 +1265,14 @@ func (e *SQLiteEngine) executeSearchQuery(ctx context.Context, conditions []stri &msg.Snippet, &msg.FromEmail, &msg.FromName, + &msg.FromPhone, &sentAt, &msg.SizeEstimate, &msg.HasAttachments, &msg.AttachmentCount, &deletedAt, + &msg.MessageType, + &msg.ConversationTitle, ); err != nil { return nil, fmt.Errorf("scan message: %w", err) } diff --git a/internal/query/testfixtures_test.go b/internal/query/testfixtures_test.go index b501af13..256ea6be 100644 --- a/internal/query/testfixtures_test.go +++ b/internal/query/testfixtures_test.go @@ -27,7 +27,10 @@ type MessageFixture struct { SentAt time.Time SizeEstimate int64 HasAttachments bool + AttachmentCount int DeletedAt *time.Time // nil = NULL + SenderID *int64 // nil = NULL (direct sender for WhatsApp/chat messages) + MessageType string // e.g. "email", "whatsapp" Year int Month int } @@ -44,6 +47,7 @@ type ParticipantFixture struct { Email string Domain string DisplayName string + PhoneNumber string // E.164 phone number (for WhatsApp/chat participants) } // RecipientFixture defines a message_recipients row for Parquet test data. @@ -77,6 +81,7 @@ type AttachmentFixture struct { type ConversationFixture struct { ID int64 SourceConversationID string + Title string // Group/chat name (for WhatsApp/chat conversations) } // --------------------------------------------------------------------------- @@ -292,11 +297,19 @@ func (m MessageFixture) toSQL() string { if m.DeletedAt != nil { deletedAt = fmt.Sprintf("TIMESTAMP '%s'", m.DeletedAt.Format("2006-01-02 15:04:05")) } - return fmt.Sprintf("(%d::BIGINT, %d::BIGINT, %s, %d::BIGINT, %s, %s, TIMESTAMP '%s', %d::BIGINT, %v, %s, %d, %d)", + senderID := "NULL::BIGINT" + if m.SenderID != nil { + senderID = fmt.Sprintf("%d::BIGINT", *m.SenderID) + } + msgType := m.MessageType + if msgType == "" { + msgType = "email" + } + return fmt.Sprintf("(%d::BIGINT, %d::BIGINT, %s, %d::BIGINT, %s, %s, TIMESTAMP '%s', %d::BIGINT, %v, %d, %s, %s, %s, %d, %d)", m.ID, m.SourceID, sqlStr(m.SourceMessageID), m.ConversationID, sqlStr(m.Subject), sqlStr(m.Snippet), m.SentAt.Format("2006-01-02 15:04:05"), m.SizeEstimate, - m.HasAttachments, deletedAt, m.Year, m.Month, + m.HasAttachments, m.AttachmentCount, deletedAt, senderID, sqlStr(msgType), m.Year, m.Month, ) } @@ -308,8 +321,8 @@ func (b *TestDataBuilder) sourcesSQL() string { func (b *TestDataBuilder) participantsSQL() string { return joinRows(b.participants, func(p ParticipantFixture) string { - return fmt.Sprintf("(%d::BIGINT, %s, %s, %s)", - p.ID, sqlStr(p.Email), sqlStr(p.Domain), sqlStr(p.DisplayName)) + return fmt.Sprintf("(%d::BIGINT, %s, %s, %s, %s)", + p.ID, sqlStr(p.Email), sqlStr(p.Domain), sqlStr(p.DisplayName), sqlStr(p.PhoneNumber)) }) } @@ -341,8 +354,8 @@ func (b *TestDataBuilder) attachmentsSQL() string { func (b *TestDataBuilder) conversationsSQL() string { return joinRows(b.conversations, func(c ConversationFixture) string { - return fmt.Sprintf("(%d::BIGINT, %s)", - c.ID, sqlStr(c.SourceConversationID)) + return fmt.Sprintf("(%d::BIGINT, %s, %s)", + c.ID, sqlStr(c.SourceConversationID), sqlStr(c.Title)) }) } @@ -352,14 +365,14 @@ func (b *TestDataBuilder) conversationsSQL() string { // column definitions (coupled to SQL generation methods above) const ( - messagesCols = "id, source_id, source_message_id, conversation_id, subject, snippet, sent_at, size_estimate, has_attachments, deleted_from_source_at, year, month" + messagesCols = "id, source_id, source_message_id, conversation_id, subject, snippet, sent_at, size_estimate, has_attachments, attachment_count, deleted_from_source_at, sender_id, message_type, year, month" sourcesCols = "id, account_email" - participantsCols = "id, email_address, domain, display_name" + participantsCols = "id, email_address, domain, display_name, phone_number" messageRecipientsCols = "message_id, participant_id, recipient_type, display_name" labelsCols = "id, name" messageLabelsCols = "message_id, label_id" attachmentsCols = "message_id, size, filename" - conversationsCols = "id, source_conversation_id" + conversationsCols = "id, source_conversation_id, title" ) // Build generates Parquet files from the accumulated data and returns the @@ -396,11 +409,11 @@ func (b *TestDataBuilder) addAuxiliaryTables(pb *parquetBuilder) { empty bool }{ {"sources", "sources", "sources.parquet", sourcesCols, "(0::BIGINT, '')", b.sourcesSQL(), len(b.sources) == 0}, - {"participants", "participants", "participants.parquet", participantsCols, "(0::BIGINT, '', '', '')", b.participantsSQL(), len(b.participants) == 0}, + {"participants", "participants", "participants.parquet", participantsCols, "(0::BIGINT, '', '', '', '')", b.participantsSQL(), len(b.participants) == 0}, {"message_recipients", "message_recipients", "message_recipients.parquet", messageRecipientsCols, "(0::BIGINT, 0::BIGINT, '', '')", b.recipientsSQL(), len(b.recipients) == 0}, {"labels", "labels", "labels.parquet", labelsCols, "(0::BIGINT, '')", b.labelsSQL(), len(b.labels) == 0}, {"message_labels", "message_labels", "message_labels.parquet", messageLabelsCols, "(0::BIGINT, 0::BIGINT)", b.messageLabelsSQL(), len(b.msgLabels) == 0}, - {"conversations", "conversations", "conversations.parquet", conversationsCols, "(0::BIGINT, '')", b.conversationsSQL(), len(b.conversations) == 0}, + {"conversations", "conversations", "conversations.parquet", conversationsCols, "(0::BIGINT, '', '')", b.conversationsSQL(), len(b.conversations) == 0}, } for _, a := range auxTables { if a.empty { diff --git a/internal/store/messages.go b/internal/store/messages.go index 8e836af8..53ce059a 100644 --- a/internal/store/messages.go +++ b/internal/store/messages.go @@ -818,6 +818,174 @@ func (s *Store) backfillFTSBatch(fromID, toID int64) (int64, error) { return result.RowsAffected() } +// EnsureConversationWithType gets or creates a conversation with an explicit conversation_type. +// Unlike EnsureConversation (which hardcodes 'email_thread'), this accepts the type as a parameter, +// making it suitable for WhatsApp and other messaging platforms. +func (s *Store) EnsureConversationWithType(sourceID int64, sourceConversationID, conversationType, title string) (int64, error) { + // Try to get existing + var id int64 + err := s.db.QueryRow(` + SELECT id FROM conversations + WHERE source_id = ? AND source_conversation_id = ? + `, sourceID, sourceConversationID).Scan(&id) + + if err == nil { + // Update conversation_type and title if they've changed. + // Only update title when the new value is non-empty (don't blank out existing titles). + if title != "" { + _, _ = s.db.Exec(` + UPDATE conversations SET conversation_type = ?, title = ?, updated_at = datetime('now') + WHERE id = ? AND (conversation_type != ? OR title != ? OR title IS NULL) + `, conversationType, title, id, conversationType, title) + } else { + _, _ = s.db.Exec(` + UPDATE conversations SET conversation_type = ?, updated_at = datetime('now') + WHERE id = ? AND conversation_type != ? + `, conversationType, id, conversationType) + } + return id, nil + } + if err != sql.ErrNoRows { + return 0, err + } + + // Create new + result, err := s.db.Exec(` + INSERT INTO conversations (source_id, source_conversation_id, conversation_type, title, created_at, updated_at) + VALUES (?, ?, ?, ?, datetime('now'), datetime('now')) + `, sourceID, sourceConversationID, conversationType, title) + if err != nil { + return 0, err + } + + return result.LastInsertId() +} + +// EnsureParticipantByPhone gets or creates a participant by phone number. +// Phone must start with "+" (E.164 format). Returns an error for empty or +// invalid phone numbers to prevent database pollution. +// Also creates a participant_identifiers row with identifier_type='whatsapp'. +func (s *Store) EnsureParticipantByPhone(phone, displayName string) (int64, error) { + if phone == "" { + return 0, fmt.Errorf("phone number is required") + } + if !strings.HasPrefix(phone, "+") { + return 0, fmt.Errorf("phone number must be in E.164 format (starting with +), got %q", phone) + } + + // Try to get existing by phone + var id int64 + err := s.db.QueryRow(` + SELECT id FROM participants WHERE phone_number = ? + `, phone).Scan(&id) + + if err == nil { + // Update display name if provided and currently empty + if displayName != "" { + s.db.Exec(` + UPDATE participants SET display_name = ? + WHERE id = ? AND (display_name IS NULL OR display_name = '') + `, displayName, id) //nolint:errcheck // best-effort display name update + } + return id, nil + } + if err != sql.ErrNoRows { + return 0, err + } + + // Create new participant + result, err := s.db.Exec(` + INSERT INTO participants (phone_number, display_name, created_at, updated_at) + VALUES (?, ?, datetime('now'), datetime('now')) + `, phone, displayName) + if err != nil { + return 0, fmt.Errorf("insert participant: %w", err) + } + + id, err = result.LastInsertId() + if err != nil { + return 0, err + } + + // Also create a participant_identifiers row + _, err = s.db.Exec(` + INSERT OR IGNORE INTO participant_identifiers (participant_id, identifier_type, identifier_value, is_primary) + VALUES (?, 'whatsapp', ?, TRUE) + `, id, phone) + if err != nil { + return 0, fmt.Errorf("insert participant identifier: %w", err) + } + + return id, nil +} + +// UpdateParticipantDisplayNameByPhone updates the display_name for an existing +// participant identified by phone number. Only updates if display_name is currently +// empty. Returns true if a participant was found and updated, false if not found +// or name was already set. Does NOT create new participants. +func (s *Store) UpdateParticipantDisplayNameByPhone(phone, displayName string) (bool, error) { + if phone == "" || displayName == "" { + return false, nil + } + + result, err := s.db.Exec(` + UPDATE participants SET display_name = ?, updated_at = datetime('now') + WHERE phone_number = ? AND (display_name IS NULL OR display_name = '') + `, displayName, phone) + if err != nil { + return false, err + } + + rows, err := result.RowsAffected() + if err != nil { + return false, err + } + return rows > 0, nil +} + +// EnsureConversationParticipant adds a participant to a conversation. +// Uses INSERT OR IGNORE to be idempotent. +func (s *Store) EnsureConversationParticipant(conversationID, participantID int64, role string) error { + _, err := s.db.Exec(` + INSERT OR IGNORE INTO conversation_participants (conversation_id, participant_id, role, joined_at) + VALUES (?, ?, ?, datetime('now')) + `, conversationID, participantID, role) + return err +} + +// UpsertReaction inserts or ignores a reaction. +func (s *Store) UpsertReaction(messageID, participantID int64, reactionType, reactionValue string, createdAt time.Time) error { + _, err := s.db.Exec(` + INSERT OR IGNORE INTO reactions (message_id, participant_id, reaction_type, reaction_value, created_at) + VALUES (?, ?, ?, ?, ?) + `, messageID, participantID, reactionType, reactionValue, createdAt) + return err +} + +// UpsertMessageRawWithFormat stores compressed raw data with an explicit format. +// Unlike UpsertMessageRaw (which hardcodes 'mime'), this accepts the format as a parameter. +func (s *Store) UpsertMessageRawWithFormat(messageID int64, rawData []byte, format string) error { + // Compress with zlib + var compressed bytes.Buffer + w := zlib.NewWriter(&compressed) + if _, err := w.Write(rawData); err != nil { + return fmt.Errorf("compress: %w", err) + } + if err := w.Close(); err != nil { + return fmt.Errorf("close compressor: %w", err) + } + + _, err := s.db.Exec(` + INSERT INTO message_raw (message_id, raw_data, raw_format, compression) + VALUES (?, ?, ?, 'zlib') + ON CONFLICT(message_id) DO UPDATE SET + raw_data = excluded.raw_data, + raw_format = excluded.raw_format, + compression = excluded.compression + `, messageID, compressed.Bytes(), format) + return err +} + // UpsertAttachment stores an attachment record. func (s *Store) UpsertAttachment(messageID int64, filename, mimeType, storagePath, contentHash string, size int) error { // Check if attachment already exists (by message_id and content_hash) diff --git a/internal/store/store.go b/internal/store/store.go index 5cadd866..e7bf31c1 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -226,6 +226,29 @@ func (s *Store) InitSchema() error { return fmt.Errorf("execute schema.sql: %w", err) } + // Migrate existing databases that were created before newer columns + // were added to schema.sql. CREATE TABLE IF NOT EXISTS is a no-op for + // existing tables, so we must ALTER TABLE to add missing columns. + // SQLite returns "duplicate column name" for columns that already exist; + // we silently ignore those errors. + migrations := []string{ + "ALTER TABLE participants ADD COLUMN phone_number TEXT", + "ALTER TABLE participants ADD COLUMN canonical_id TEXT", + "ALTER TABLE messages ADD COLUMN sender_id INTEGER REFERENCES participants(id)", + "ALTER TABLE messages ADD COLUMN message_type TEXT NOT NULL DEFAULT 'email'", + "ALTER TABLE messages ADD COLUMN attachment_count INTEGER DEFAULT 0", + "ALTER TABLE messages ADD COLUMN deleted_from_source_at DATETIME", + "ALTER TABLE messages ADD COLUMN delete_batch_id TEXT", + "ALTER TABLE conversations ADD COLUMN title TEXT", + } + for _, m := range migrations { + if _, err := s.db.Exec(m); err != nil { + if !isSQLiteError(err, "duplicate column name") { + return fmt.Errorf("migration %q: %w", m, err) + } + } + } + // Try to load and execute SQLite-specific schema (FTS5) // This is optional - FTS5 may not be available in all builds sqliteSchema, err := schemaFS.ReadFile("schema_sqlite.sql") diff --git a/internal/textutil/encoding.go b/internal/textutil/encoding.go index 55ebe55d..34c74d5b 100644 --- a/internal/textutil/encoding.go +++ b/internal/textutil/encoding.go @@ -146,3 +146,76 @@ func FirstLine(s string) string { } return s } + +// SanitizeTerminal strips ANSI escape sequences and C0/C1 control characters +// from a string, preventing terminal injection via untrusted data (e.g., +// WhatsApp chat names, message snippets). Preserves printable characters, +// tabs, and newlines. +// +// C1 control characters (U+0080–U+009F) are checked on the decoded rune, not +// the raw leading byte, so that UTF-8 encoded C1 chars (e.g., U+009B CSI +// encoded as 0xC2 0x9B) are correctly stripped. +func SanitizeTerminal(s string) string { + var b strings.Builder + b.Grow(len(s)) + i := 0 + for i < len(s) { + c := s[i] + // Strip ESC-initiated sequences (CSI, OSC, etc.). + if c == 0x1b && i+1 < len(s) { + next := s[i+1] + switch { + case next == '[': // CSI sequence: ESC [ ... + i += 2 + for i < len(s) && (s[i] < 0x40 || s[i] > 0x7E) { + i++ + } + if i < len(s) { + i++ // skip final byte + } + continue + case next == ']': // OSC sequence: ESC ] ... (ST or BEL) + i += 2 + for i < len(s) { + if s[i] == 0x07 { // BEL terminates OSC + i++ + break + } + if s[i] == 0x1b && i+1 < len(s) && s[i+1] == '\\' { // ST terminates OSC + i += 2 + break + } + i++ + } + continue + default: // Other ESC sequences (2-byte): skip both + i += 2 + continue + } + } + + // Decode the full rune so we can check C1 control characters that + // span multiple bytes in UTF-8 (e.g., U+009B is 0xC2 0x9B). + r, size := utf8.DecodeRuneInString(s[i:]) + if r == utf8.RuneError && size == 1 { + // Invalid UTF-8 byte — skip it. + i++ + continue + } + + // Allow tab, newline, carriage return; strip other C0/C1 control chars. + if r == '\t' || r == '\n' || r == '\r' { + b.WriteRune(r) + i += size + continue + } + if r < 0x20 || (r >= 0x7f && r <= 0x9f) { + i += size + continue + } + + b.WriteString(s[i : i+size]) + i += size + } + return b.String() +} diff --git a/internal/textutil/encoding_test.go b/internal/textutil/encoding_test.go index 96af17ed..6fc36e2f 100644 --- a/internal/textutil/encoding_test.go +++ b/internal/textutil/encoding_test.go @@ -496,3 +496,36 @@ func TestFirstLine(t *testing.T) { }) } } + +func TestSanitizeTerminal(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {"plain text", "Hello World", "Hello World"}, + {"preserves tabs", "col1\tcol2", "col1\tcol2"}, + {"preserves newlines", "line1\nline2", "line1\nline2"}, + {"strips CSI color", "\x1b[31mred\x1b[0m", "red"}, + {"strips CSI cursor move", "\x1b[2Ahello", "hello"}, + {"strips OSC title (BEL)", "\x1b]0;evil title\x07safe", "safe"}, + {"strips OSC title (ST)", "\x1b]0;evil\x1b\\safe", "safe"}, + {"strips BEL", "\x07beep", "beep"}, + {"strips null bytes", "a\x00b", "ab"}, + {"strips C1 control byte", "a\x8fb", "ab"}, + {"strips UTF-8 encoded C1 CSI (U+009B)", "a\xc2\x9bb", "ab"}, + {"strips UTF-8 encoded C1 0x80-0x9F range", "a\xc2\x80z\xc2\x9fb", "azb"}, + {"preserves unicode", "café ☕ 日本語", "café ☕ 日本語"}, + {"strips embedded ESC seq", "before\x1b[1;32mgreen\x1b[0mafter", "beforegreenafter"}, + {"empty string", "", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := SanitizeTerminal(tt.input) + if got != tt.want { + t.Errorf("SanitizeTerminal(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} diff --git a/internal/tui/view.go b/internal/tui/view.go index 48df6208..bcf74d5a 100644 --- a/internal/tui/view.go +++ b/internal/tui/view.go @@ -6,6 +6,7 @@ import ( "github.com/charmbracelet/lipgloss" "github.com/wesm/msgvault/internal/query" + "github.com/wesm/msgvault/internal/textutil" ) // Monochrome theme - adaptive for light and dark terminals @@ -527,16 +528,34 @@ func (m Model) messageListView() string { date := msg.SentAt.Format("2006-01-02 15:04") // Format from (rune-aware for international names) - from := msg.FromEmail + // Sanitize untrusted metadata to prevent terminal control-sequence injection. + from := textutil.SanitizeTerminal(msg.FromEmail) if msg.FromName != "" { - from = msg.FromName + from = textutil.SanitizeTerminal(msg.FromName) + } + // For chat messages: fall back to phone number, then group title + if from == "" && msg.FromPhone != "" { + from = textutil.SanitizeTerminal(msg.FromPhone) + } + if from == "" && msg.ConversationTitle != "" { + from = textutil.SanitizeTerminal(msg.ConversationTitle) } from = truncateRunes(from, fromWidth) from = fmt.Sprintf("%-*s", fromWidth, from) from = highlightTerms(from, m.searchQuery) // Format subject with indicators (rune-aware) - subject := msg.Subject + // For chat messages without a subject, show snippet or group title + subject := textutil.SanitizeTerminal(msg.Subject) + if subject == "" && msg.MessageType == "whatsapp" { + title := textutil.SanitizeTerminal(msg.ConversationTitle) + snippet := textutil.SanitizeTerminal(msg.Snippet) + if title != "" { + subject = title + ": " + snippet + } else { + subject = snippet + } + } if msg.DeletedAt != nil { subject = "🗑 " + subject // Deleted from server indicator } @@ -878,12 +897,19 @@ func (m Model) threadView() string { dateStr := msg.SentAt.Format("2006-01-02 15:04") // Format from/subject with deleted indicator - fromSubject := msg.FromEmail + // Sanitize untrusted metadata to prevent terminal control-sequence injection. + fromSubject := textutil.SanitizeTerminal(msg.FromEmail) if msg.FromName != "" { - fromSubject = msg.FromName + fromSubject = textutil.SanitizeTerminal(msg.FromName) + } + // For chat messages: fall back to phone number + if fromSubject == "" && msg.FromPhone != "" { + fromSubject = textutil.SanitizeTerminal(msg.FromPhone) } if msg.Subject != "" { - fromSubject = truncateRunes(fromSubject, 18) + ": " + msg.Subject + fromSubject = truncateRunes(fromSubject, 18) + ": " + textutil.SanitizeTerminal(msg.Subject) + } else if msg.MessageType == "whatsapp" && msg.Snippet != "" { + fromSubject = truncateRunes(fromSubject, 18) + ": " + textutil.SanitizeTerminal(msg.Snippet) } if msg.DeletedAt != nil { fromSubject = "🗑 " + fromSubject // Deleted from server indicator diff --git a/internal/whatsapp/contacts.go b/internal/whatsapp/contacts.go new file mode 100644 index 00000000..d030260c --- /dev/null +++ b/internal/whatsapp/contacts.go @@ -0,0 +1,243 @@ +package whatsapp + +import ( + "bufio" + "fmt" + "os" + "regexp" + "strings" + + "github.com/wesm/msgvault/internal/store" +) + +// vcardContact represents a parsed contact from a vCard file. +type vcardContact struct { + FullName string + Phones []string // normalized to E.164 +} + +// ImportContacts reads a .vcf file and updates participant display names +// for any phone numbers that match existing participants in the store. +// Only updates existing participants — does not create new ones. +// Returns the number of existing participants whose names were updated. +func ImportContacts(s *store.Store, vcfPath string) (matched, total int, err error) { + contacts, err := parseVCardFile(vcfPath) + if err != nil { + return 0, 0, fmt.Errorf("parse vcard: %w", err) + } + + total = len(contacts) + var errCount int + for _, c := range contacts { + if c.FullName == "" { + continue + } + for _, phone := range c.Phones { + if phone == "" { + continue + } + // Only update display_name for participants that already exist. + // Does not create new participants — those are created during message import. + updated, updateErr := s.UpdateParticipantDisplayNameByPhone(phone, c.FullName) + if updateErr != nil { + errCount++ + continue + } + if updated { + matched++ + } + } + } + + if errCount > 0 { + return matched, total, fmt.Errorf("contact import completed with %d database errors", errCount) + } + + return matched, total, nil +} + +// parseVCardFile reads a .vcf file and returns parsed contacts. +// Handles vCard 2.1 and 3.0 formats, including RFC 2425 line folding +// and QUOTED-PRINTABLE encoded values. +func parseVCardFile(path string) ([]vcardContact, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + // Read all lines and unfold continuation lines (RFC 2425: lines starting + // with a space or tab are continuations of the previous line). + var rawLines []string + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) + + for scanner.Scan() { + line := scanner.Text() + if len(line) > 0 && (line[0] == ' ' || line[0] == '\t') { + // Continuation line — append to previous. + if len(rawLines) > 0 { + rawLines[len(rawLines)-1] += strings.TrimLeft(line, " \t") + continue + } + } + rawLines = append(rawLines, line) + } + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("scan vcard: %w", err) + } + + // Handle QUOTED-PRINTABLE soft line breaks: a trailing '=' means the + // value continues on the next line (vCard 2.1 convention). The scanner + // already consumed the newline, so we rejoin here. + var qpJoined []string + for i := 0; i < len(rawLines); i++ { + line := rawLines[i] + for strings.HasSuffix(line, "=") && i+1 < len(rawLines) { + line = line[:len(line)-1] + rawLines[i+1] + i++ + } + qpJoined = append(qpJoined, line) + } + rawLines = qpJoined + + var contacts []vcardContact + var current *vcardContact + + for _, line := range rawLines { + line = strings.TrimSpace(line) + // vCard field names are case-insensitive (RFC 2426). + // Uppercase the key portion for matching, but preserve original value bytes. + upper := strings.ToUpper(line) + + switch { + case upper == "BEGIN:VCARD": + current = &vcardContact{} + + case upper == "END:VCARD": + if current != nil && (current.FullName != "" || len(current.Phones) > 0) { + contacts = append(contacts, *current) + } + current = nil + + case current == nil: + continue + + case strings.HasPrefix(upper, "FN:") || strings.HasPrefix(upper, "FN;"): + // FN (formatted name) — preferred over N because it's the display name. + name := extractVCardValue(line) + if isQuotedPrintable(line) { + name = decodeQuotedPrintable(name) + } + if name != "" { + current.FullName = name + } + + case strings.HasPrefix(upper, "TEL"): + // TEL;CELL:+447... or TEL;TYPE=CELL:+447... or TEL:+447... + raw := extractVCardValue(line) + phone := normalizeVCardPhone(raw) + if phone != "" { + current.Phones = append(current.Phones, phone) + } + } + } + + return contacts, nil +} + +// extractVCardValue extracts the value part from a vCard line. +// Handles both "KEY:value" and "KEY;params:value" formats. +func extractVCardValue(line string) string { + // Find the first colon that separates key from value. + idx := strings.Index(line, ":") + if idx < 0 { + return "" + } + return strings.TrimSpace(line[idx+1:]) +} + +// isQuotedPrintable returns true if a vCard line indicates QUOTED-PRINTABLE encoding. +func isQuotedPrintable(line string) bool { + upper := strings.ToUpper(line) + return strings.Contains(upper, "ENCODING=QUOTED-PRINTABLE") || + strings.Contains(upper, ";QUOTED-PRINTABLE") +} + +// decodeQuotedPrintable decodes a QUOTED-PRINTABLE encoded string. +// Handles =XX hex sequences (e.g., =C3=A9 → é). +func decodeQuotedPrintable(s string) string { + var b strings.Builder + b.Grow(len(s)) + for i := 0; i < len(s); i++ { + if s[i] == '=' && i+2 < len(s) { + hi := unhex(s[i+1]) + lo := unhex(s[i+2]) + if hi >= 0 && lo >= 0 { + b.WriteByte(byte(hi<<4 | lo)) + i += 2 + continue + } + } + b.WriteByte(s[i]) + } + return b.String() +} + +// unhex returns the numeric value of a hex digit, or -1 if invalid. +func unhex(c byte) int { + switch { + case c >= '0' && c <= '9': + return int(c - '0') + case c >= 'A' && c <= 'F': + return int(c - 'A' + 10) + case c >= 'a' && c <= 'f': + return int(c - 'a' + 10) + default: + return -1 + } +} + +// nonDigitRe matches any non-digit character. +var nonDigitRe = regexp.MustCompile(`[^\d]`) + +// normalizeVCardPhone normalizes a phone number from a vCard to E.164 format. +// Handles various formats: +447..., 003-362-..., 077-380-06043, etc. +func normalizeVCardPhone(raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return "" + } + + // Check if it starts with + (already has country code). + hasPlus := strings.HasPrefix(raw, "+") + + // Strip trunk prefix "(0)" before digit extraction. + // Common in UK/European numbers: +44 (0)7700 means +447700, not +4407700. + raw = strings.ReplaceAll(raw, "(0)", "") + + // Strip everything except digits. + digits := nonDigitRe.ReplaceAllString(raw, "") + if digits == "" { + return "" + } + + // If originally had +, it's already E.164-ish. + if hasPlus { + return "+" + digits + } + + // Handle 00-prefixed international format (e.g., 003-362-4921221 → +33624921221). + if strings.HasPrefix(digits, "00") && len(digits) > 4 { + return "+" + digits[2:] + } + + // Local numbers starting with 0 (e.g., 07738006043) are country-specific + // and cannot be reliably normalized without knowing the country code. + // Skip these rather than hardcoding a country assumption. + + // Without an explicit country code indicator (+ or 00), we cannot + // reliably determine the country code. Skip ambiguous numbers rather + // than guessing — a wrong prefix would match the wrong participant. + return "" +} diff --git a/internal/whatsapp/contacts_test.go b/internal/whatsapp/contacts_test.go new file mode 100644 index 00000000..584e60e6 --- /dev/null +++ b/internal/whatsapp/contacts_test.go @@ -0,0 +1,241 @@ +package whatsapp + +import ( + "os" + "path/filepath" + "testing" +) + +func TestNormalizeVCardPhone(t *testing.T) { + tests := []struct { + raw string + want string + }{ + // Already E.164 + {"+447700900000", "+447700900000"}, + {"+12025551234", "+12025551234"}, + {"+33624921221", "+33624921221"}, + + // With dashes/spaces + {"+44 7700 900000", "+447700900000"}, + {"+1-202-555-1234", "+12025551234"}, + + // Trunk prefix (0) — common in UK/European numbers + {"+44 (0)7700 900000", "+447700900000"}, + {"+44(0)20 7123 4567", "+442071234567"}, + + // 00 prefix (international) + {"003-362-4921221", "+33624921221"}, + {"0033624921221", "+33624921221"}, + {"004-479-35975580", "+447935975580"}, + + // 0 prefix (local) — skipped, country-ambiguous + {"011-585-73843", ""}, + {"07738006043", ""}, + {"077-380-06043", ""}, + + // No explicit country code indicator — ambiguous, skip + {"447700900000", ""}, + {"2025551234", ""}, + + // Empty/invalid + {"", ""}, + {" ", ""}, + {"abc", ""}, + {"12", ""}, // too short + } + + for _, tt := range tests { + got := normalizeVCardPhone(tt.raw) + if got != tt.want { + t.Errorf("normalizeVCardPhone(%q) = %q, want %q", tt.raw, got, tt.want) + } + } +} + +func TestParseVCardFile(t *testing.T) { + // Write a test vCard file. + vcf := `BEGIN:VCARD +VERSION:2.1 +N:McGregor;Alastair;;; +FN:Alastair McGregor +TEL;CELL:+447984959428 +END:VCARD +BEGIN:VCARD +VERSION:2.1 +N:France;Geoff;;; +FN:Geoff France +TEL;X-Mobile:+33562645735 +END:VCARD +BEGIN:VCARD +VERSION:2.1 +N:Studios;Claire Mohacek -;Amazon;; +FN:Claire Mohacek - Amazon Studios +TEL;CELL:077-380-06043 +END:VCARD +BEGIN:VCARD +VERSION:2.1 +TEL;CELL: +END:VCARD +BEGIN:VCARD +VERSION:3.0 +FN:Multi Phone Person +TEL;TYPE=CELL:+447700900001 +TEL;TYPE=WORK:+442071234567 +END:VCARD +` + dir := t.TempDir() + path := filepath.Join(dir, "test.vcf") + if err := os.WriteFile(path, []byte(vcf), 0644); err != nil { + t.Fatal(err) + } + + contacts, err := parseVCardFile(path) + if err != nil { + t.Fatalf("parseVCardFile() error: %v", err) + } + + if len(contacts) != 4 { // 4 with names/phones, 1 empty entry skipped + t.Fatalf("got %d contacts, want 4", len(contacts)) + } + + // First contact + if contacts[0].FullName != "Alastair McGregor" { + t.Errorf("contact 0 name = %q, want %q", contacts[0].FullName, "Alastair McGregor") + } + if len(contacts[0].Phones) != 1 || contacts[0].Phones[0] != "+447984959428" { + t.Errorf("contact 0 phones = %v, want [+447984959428]", contacts[0].Phones) + } + + // Third contact — local number (0-prefix) is now skipped (country-ambiguous) + if contacts[2].FullName != "Claire Mohacek - Amazon Studios" { + t.Errorf("contact 2 name = %q", contacts[2].FullName) + } + if len(contacts[2].Phones) != 0 { + t.Errorf("contact 2 phones = %v, want [] (local numbers skipped)", contacts[2].Phones) + } + + // Multi phone contact + if contacts[3].FullName != "Multi Phone Person" { + t.Errorf("contact 3 name = %q", contacts[3].FullName) + } + if len(contacts[3].Phones) != 2 { + t.Errorf("contact 3 phone count = %d, want 2", len(contacts[3].Phones)) + } +} + +func TestParseVCardFile_FoldedAndEncoded(t *testing.T) { + // Test RFC 2425 line folding and QUOTED-PRINTABLE encoding. + vcf := "BEGIN:VCARD\r\n" + + "VERSION:2.1\r\n" + + "FN:José\r\n" + + " García\r\n" + // folded continuation line + "TEL;CELL:+34\r\n" + + " 612345678\r\n" + // folded phone + "END:VCARD\r\n" + + "BEGIN:VCARD\r\n" + + "VERSION:2.1\r\n" + + "FN;ENCODING=QUOTED-PRINTABLE:Ren=C3=A9 Dupont\r\n" + + "TEL;CELL:+33612345678\r\n" + + "END:VCARD\r\n" + + dir := t.TempDir() + path := filepath.Join(dir, "folded.vcf") + if err := os.WriteFile(path, []byte(vcf), 0644); err != nil { + t.Fatal(err) + } + + contacts, err := parseVCardFile(path) + if err != nil { + t.Fatalf("parseVCardFile() error: %v", err) + } + + if len(contacts) != 2 { + t.Fatalf("got %d contacts, want 2", len(contacts)) + } + + // Folded name (RFC 2425: leading whitespace is stripped, content concatenated) + if contacts[0].FullName != "JoséGarcía" { + t.Errorf("folded name = %q, want %q", contacts[0].FullName, "JoséGarcía") + } + if len(contacts[0].Phones) != 1 || contacts[0].Phones[0] != "+34612345678" { + t.Errorf("folded phone = %v, want [+34612345678]", contacts[0].Phones) + } + + // QUOTED-PRINTABLE encoded name + if contacts[1].FullName != "René Dupont" { + t.Errorf("QP name = %q, want %q", contacts[1].FullName, "René Dupont") + } +} + +func TestParseVCardFile_QPSoftBreaks(t *testing.T) { + // Test QUOTED-PRINTABLE soft line breaks (= at end of line). + // vCard 2.1 uses = at EOL to wrap long QP values across lines. + vcf := "BEGIN:VCARD\r\n" + + "VERSION:2.1\r\n" + + "FN;ENCODING=QUOTED-PRINTABLE:Jo=C3=A3o da =\r\n" + + "Silva\r\n" + + "TEL;CELL:+5511999887766\r\n" + + "END:VCARD\r\n" + + dir := t.TempDir() + path := filepath.Join(dir, "qp-soft.vcf") + if err := os.WriteFile(path, []byte(vcf), 0644); err != nil { + t.Fatal(err) + } + + contacts, err := parseVCardFile(path) + if err != nil { + t.Fatalf("parseVCardFile() error: %v", err) + } + + if len(contacts) != 1 { + t.Fatalf("got %d contacts, want 1", len(contacts)) + } + + // Soft break should be stripped, continuation joined, then QP decoded. + want := "João da Silva" + if contacts[0].FullName != want { + t.Errorf("QP soft break name = %q, want %q", contacts[0].FullName, want) + } +} + +func TestDecodeQuotedPrintable(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"hello", "hello"}, + {"Ren=C3=A9", "René"}, + {"=C3=A9=C3=A8", "éè"}, + {"no=encoding", "no=encoding"}, // invalid hex after = — kept as-is + {"end=", "end="}, // trailing = — kept as-is + } + for _, tt := range tests { + got := decodeQuotedPrintable(tt.input) + if got != tt.want { + t.Errorf("decodeQuotedPrintable(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestExtractVCardValue(t *testing.T) { + tests := []struct { + line string + want string + }{ + {"FN:John Doe", "John Doe"}, + {"FN;CHARSET=UTF-8:John Doe", "John Doe"}, + {"TEL;CELL:+447700900000", "+447700900000"}, + {"TEL;TYPE=CELL:+447700900000", "+447700900000"}, + {"TEL:+447700900000", "+447700900000"}, + {"NO_COLON", ""}, + } + + for _, tt := range tests { + got := extractVCardValue(tt.line) + if got != tt.want { + t.Errorf("extractVCardValue(%q) = %q, want %q", tt.line, got, tt.want) + } + } +} diff --git a/internal/whatsapp/importer.go b/internal/whatsapp/importer.go new file mode 100644 index 00000000..d12795c8 --- /dev/null +++ b/internal/whatsapp/importer.go @@ -0,0 +1,720 @@ +package whatsapp + +import ( + "context" + "crypto/sha256" + "database/sql" + "encoding/json" + "fmt" + "io" + "net/url" + "os" + "path/filepath" + "strings" + "time" + + "github.com/wesm/msgvault/internal/store" +) + +// Importer handles importing WhatsApp messages from a decrypted msgstore.db +// into the msgvault store. +type Importer struct { + store *store.Store + progress ImportProgress +} + +// NewImporter creates a new WhatsApp importer. +func NewImporter(s *store.Store, progress ImportProgress) *Importer { + if progress == nil { + progress = NullProgress{} + } + return &Importer{ + store: s, + progress: progress, + } +} + +// Import performs the full WhatsApp import from a decrypted msgstore.db. +func (imp *Importer) Import(ctx context.Context, waDBPath string, opts ImportOptions) (*ImportSummary, error) { + startTime := time.Now() + summary := &ImportSummary{} + + // Open WhatsApp DB read-only. + // Use file: URI to safely handle paths containing '?' or other special characters. + dsn := (&url.URL{ + Scheme: "file", + OmitHost: true, + Path: waDBPath, + RawQuery: "mode=ro&_journal_mode=WAL&_busy_timeout=5000", + }).String() + waDB, err := sql.Open("sqlite3", dsn) + if err != nil { + return nil, fmt.Errorf("open whatsapp db: %w", err) + } + defer waDB.Close() + + // Verify it's a valid WhatsApp DB. + if err := verifyWhatsAppDB(waDB); err != nil { + return nil, err + } + + // Create or get the WhatsApp source. + source, err := imp.store.GetOrCreateSource("whatsapp", opts.Phone) + if err != nil { + return nil, fmt.Errorf("get or create source: %w", err) + } + + if opts.DisplayName != "" { + _ = imp.store.UpdateSourceDisplayName(source.ID, opts.DisplayName) + } + + // Start a sync run for tracking. + syncID, err := imp.store.StartSync(source.ID, "whatsapp_import") + if err != nil { + return nil, fmt.Errorf("start sync: %w", err) + } + + // Ensure we complete/fail the sync run on exit. + var syncErr error + defer func() { + if syncErr != nil { + _ = imp.store.FailSync(syncID, syncErr.Error()) + } else { + _ = imp.store.CompleteSync(syncID, "") + } + }() + + imp.progress.OnStart() + + // Create participant for the phone owner (self). + selfParticipantID, err := imp.store.EnsureParticipantByPhone(opts.Phone, opts.DisplayName) + if err != nil { + syncErr = err + return nil, fmt.Errorf("ensure self participant: %w", err) + } + summary.Participants++ + + // Fetch all chats from WhatsApp DB. + chats, err := fetchChats(waDB) + if err != nil { + syncErr = err + return nil, fmt.Errorf("fetch chats: %w", err) + } + + // Load lid → phone mapping for resolving "lid" senders. + lidMap, err := fetchLidMap(waDB) + if err != nil { + syncErr = err + return nil, fmt.Errorf("fetch lid map: %w", err) + } + + batchSize := opts.BatchSize + if batchSize <= 0 { + batchSize = 1000 + } + + // Track key_id → message_id for reply threading within each chat. + // Scoped per chat to bound memory; cross-chat quotes won't thread + // but that's rare and the quoted text is still in the message body. + keyIDToMsgID := make(map[string]int64) + + totalLimit := opts.Limit + totalAdded := int64(0) + + for _, chat := range chats { + // Clear reply map per chat to prevent unbounded growth. + clear(keyIDToMsgID) + if ctx.Err() != nil { + syncErr = ctx.Err() + return nil, ctx.Err() + } + + // Check global limit. + if totalLimit > 0 && totalAdded >= int64(totalLimit) { + break + } + + summary.ChatsProcessed++ + + // Map chat to conversation. + sourceConvID, convType, title := mapConversation(chat) + conversationID, err := imp.store.EnsureConversationWithType(source.ID, sourceConvID, convType, title) + if err != nil { + summary.Errors++ + imp.progress.OnError(fmt.Errorf("ensure conversation %s: %w", sourceConvID, err)) + continue + } + + imp.progress.OnChatStart(chat.RawString, chatTitle(chat), 0) + + // For direct chats: add the remote participant. + if !isGroupChat(chat) && chat.User != "" { + phone := normalizePhone(chat.User, chat.Server) + if phone == "" { + // Non-phone JID (e.g., lid:..., broadcast) — skip. + } else if participantID, err := imp.store.EnsureParticipantByPhone(phone, ""); err != nil { + summary.Errors++ + imp.progress.OnError(fmt.Errorf("ensure participant %s: %w", phone, err)) + } else { + summary.Participants++ + _ = imp.store.EnsureConversationParticipant(conversationID, participantID, "member") + _ = imp.store.EnsureConversationParticipant(conversationID, selfParticipantID, "member") + } + } + + // For group chats: add all group participants. + if isGroupChat(chat) { + members, err := fetchGroupParticipants(waDB, chat.RawString) + if err != nil { + summary.Errors++ + imp.progress.OnError(fmt.Errorf("fetch group participants for %s: %w", sourceConvID, err)) + } else { + for _, member := range members { + phone := normalizePhone(member.MemberUser, member.MemberServer) + if phone == "" { + continue // Non-phone JID — skip. + } + participantID, err := imp.store.EnsureParticipantByPhone(phone, "") + if err != nil { + summary.Errors++ + continue + } + summary.Participants++ + role := mapGroupRole(member.Admin) + _ = imp.store.EnsureConversationParticipant(conversationID, participantID, role) + } + } + } + + // Track resolved sender participant IDs for participant fallback. + // After the message loop, we ensure each sender is a conversation + // participant — covers groups where group_participants is empty. + chatSenderIDs := make(map[int64]struct{}) + + // Process messages in batches. + chatAdded := int64(0) + afterID := int64(0) + + for { + if ctx.Err() != nil { + syncErr = ctx.Err() + return nil, ctx.Err() + } + + // Check global limit for this batch. + remaining := batchSize + if totalLimit > 0 { + left := int64(totalLimit) - totalAdded + if left <= 0 { + break + } + if left < int64(remaining) { + remaining = int(left) + } + } + + messages, err := fetchMessages(waDB, chat.RowID, afterID, remaining) + if err != nil { + summary.Errors++ + imp.progress.OnError(fmt.Errorf("fetch messages for chat %s after %d: %w", sourceConvID, afterID, err)) + break + } + + if len(messages) == 0 { + break + } + + // Collect message row IDs for batch media/reaction/quote lookups. + msgRowIDs := make([]int64, len(messages)) + for i, m := range messages { + msgRowIDs[i] = m.RowID + } + + // Batch-fetch media, reactions, and quotes. + mediaMap, err := fetchMedia(waDB, msgRowIDs) + if err != nil { + summary.Errors++ + imp.progress.OnError(fmt.Errorf("fetch media: %w", err)) + mediaMap = make(map[int64]waMedia) + } + + reactionMap, err := fetchReactions(waDB, msgRowIDs) + if err != nil { + summary.Errors++ + imp.progress.OnError(fmt.Errorf("fetch reactions: %w", err)) + reactionMap = make(map[int64][]waReaction) + } + + quotedMap, err := fetchQuotedMessages(waDB, msgRowIDs) + if err != nil { + summary.Errors++ + imp.progress.OnError(fmt.Errorf("fetch quoted messages: %w", err)) + quotedMap = make(map[int64]waQuoted) + } + + for _, waMsg := range messages { + summary.MessagesProcessed++ + afterID = waMsg.RowID + + // Skip system messages and calls. + if isSkippedType(waMsg.MessageType) { + summary.MessagesSkipped++ + continue + } + + // Skip messages with empty key_id — they can't be uniquely + // identified for upsert and would collide with each other. + if waMsg.KeyID == "" { + summary.MessagesSkipped++ + continue + } + + // Resolve sender. + var senderID sql.NullInt64 + if waMsg.FromMe == 1 { + senderID = sql.NullInt64{Int64: selfParticipantID, Valid: true} + } else if waMsg.SenderServer.Valid && waMsg.SenderServer.String == "lid" { + // Lid JID — resolve via jid_map before trying normalizePhone, + // because lid user strings can be 15 digits and pass E.164 + // validation despite not being real phone numbers. + phone := resolveLidSender(waMsg.SenderJIDRowID, waMsg.SenderServer.String, lidMap) + if phone != "" { + pid, err := imp.store.EnsureParticipantByPhone(phone, "") + if err != nil { + summary.Errors++ + imp.progress.OnError(fmt.Errorf("ensure sender participant %s: %w", phone, err)) + } else { + senderID = sql.NullInt64{Int64: pid, Valid: true} + } + } + } else if waMsg.SenderUser.Valid && waMsg.SenderUser.String != "" { + phone := normalizePhone(waMsg.SenderUser.String, waMsg.SenderServer.String) + if phone != "" { + pid, err := imp.store.EnsureParticipantByPhone(phone, "") + if err != nil { + summary.Errors++ + imp.progress.OnError(fmt.Errorf("ensure sender participant %s: %w", phone, err)) + } else { + senderID = sql.NullInt64{Int64: pid, Valid: true} + } + } + } else if !isGroupChat(chat) && waMsg.FromMe == 0 { + // In a direct chat, the other person is the sender. + phone := normalizePhone(chat.User, chat.Server) + if phone != "" { + pid, err := imp.store.EnsureParticipantByPhone(phone, "") + if err == nil { + senderID = sql.NullInt64{Int64: pid, Valid: true} + } + } + } + + // Track sender for participant fallback. + if senderID.Valid { + chatSenderIDs[senderID.Int64] = struct{}{} + } + + // Build and upsert the message. + msg := mapMessage(waMsg, conversationID, source.ID, senderID) + messageID, err := imp.store.UpsertMessage(&msg) + if err != nil { + summary.Errors++ + imp.progress.OnError(fmt.Errorf("upsert message %s: %w", waMsg.KeyID, err)) + continue + } + + // Track for reply threading. + keyIDToMsgID[waMsg.KeyID] = messageID + + summary.MessagesAdded++ + chatAdded++ + totalAdded++ + + // Store message body. + bodyText := sql.NullString{} + if waMsg.TextData.Valid && waMsg.TextData.String != "" { + bodyText = waMsg.TextData + } + // Check for media caption as additional body text. + if media, ok := mediaMap[waMsg.RowID]; ok { + if media.MediaCaption.Valid && media.MediaCaption.String != "" { + if bodyText.Valid && bodyText.String != "" { + // Append caption to body. + bodyText.String += "\n\n" + media.MediaCaption.String + } else { + bodyText = media.MediaCaption + } + } + } + if bodyText.Valid { + _ = imp.store.UpsertMessageBody(messageID, bodyText, sql.NullString{}) + } + + // Store raw JSON for re-parsing. + rawJSON, err := json.Marshal(waMsg) + if err == nil { + _ = imp.store.UpsertMessageRawWithFormat(messageID, rawJSON, "whatsapp_json") + } + + // Handle media/attachments. + if media, ok := mediaMap[waMsg.RowID]; ok { + summary.AttachmentsFound++ + mediaType := mapMediaType(waMsg.MessageType) + + storagePath, contentHash := imp.handleMediaFile(media, opts) + if storagePath != "" { + summary.MediaCopied++ + } + + mimeType := "" + if media.MimeType.Valid { + mimeType = media.MimeType.String + } + + filename := "" + if media.FilePath.Valid { + filename = filepath.Base(media.FilePath.String) + } + + size := 0 + if media.FileSize.Valid { + size = int(media.FileSize.Int64) + } + + // Use UpsertAttachment — it handles dedup by content_hash. + err := imp.store.UpsertAttachment(messageID, filename, mimeType, storagePath, contentHash, size) + if err != nil { + summary.Errors++ + imp.progress.OnError(fmt.Errorf("upsert attachment for message %s: %w", waMsg.KeyID, err)) + } + + // Store media metadata in the attachments table is done above. + // For extra metadata (width, height, duration, media_type), + // update via a direct SQL call since UpsertAttachment doesn't have those fields. + if mediaType != "" || (media.Width.Valid && media.Width.Int64 > 0) { + imp.updateAttachmentMetadata(messageID, contentHash, mediaType, media) + } + } + + // Handle quoted/reply messages. + if quoted, ok := quotedMap[waMsg.RowID]; ok { + if replyToMsgID, found := keyIDToMsgID[quoted.QuotedKeyID]; found { + imp.setReplyTo(messageID, replyToMsgID) + } else if dbMsgID, lookupErr := imp.lookupMessageByKeyID(source.ID, quoted.QuotedKeyID); lookupErr == nil && dbMsgID > 0 { + // Found in DB from a previous import run or another chat. + imp.setReplyTo(messageID, dbMsgID) + } + } + + // Handle reactions. + if reactions, ok := reactionMap[waMsg.RowID]; ok { + for _, r := range reactions { + reactionType, reactionValue := mapReaction(r) + if reactionValue == "" { + continue + } + + var reactorID int64 + if r.SenderServer.Valid && r.SenderServer.String == "lid" { + // Lid JID — resolve via jid_map first. + phone := resolveLidSender(r.SenderJIDRowID, r.SenderServer.String, lidMap) + if phone == "" { + continue + } + pid, err := imp.store.EnsureParticipantByPhone(phone, "") + if err != nil { + summary.Errors++ + continue + } + reactorID = pid + } else if r.SenderUser.Valid && r.SenderUser.String != "" { + phone := normalizePhone(r.SenderUser.String, r.SenderServer.String) + if phone == "" { + continue // Non-phone JID — skip reaction. + } + pid, err := imp.store.EnsureParticipantByPhone(phone, "") + if err != nil { + summary.Errors++ + continue + } + reactorID = pid + } else { + // Self reaction. + reactorID = selfParticipantID + } + + createdAt := time.Unix(r.Timestamp/1000, 0) + if err := imp.store.UpsertReaction(messageID, reactorID, reactionType, reactionValue, createdAt); err != nil { + summary.Errors++ + imp.progress.OnError(fmt.Errorf("upsert reaction: %w", err)) + } else { + summary.ReactionsAdded++ + } + } + } + + // FTS indexing. + if bodyText.Valid { + senderAddr := "" + if waMsg.FromMe == 1 { + senderAddr = opts.Phone + } else if waMsg.SenderServer.Valid && waMsg.SenderServer.String == "lid" { + senderAddr = resolveLidSender(waMsg.SenderJIDRowID, waMsg.SenderServer.String, lidMap) + } else if waMsg.SenderUser.Valid { + senderAddr = normalizePhone(waMsg.SenderUser.String, waMsg.SenderServer.String) + } + _ = imp.store.UpsertFTS(messageID, "", bodyText.String, senderAddr, "", "") + } + } + + // Update sync run progress counters (for monitoring, not resume). + // Resume is not implemented yet — re-running is safe due to upsert dedup. + _ = imp.store.UpdateSyncCheckpoint(syncID, &store.Checkpoint{ + MessagesProcessed: summary.MessagesProcessed, + MessagesAdded: summary.MessagesAdded, + }) + + imp.progress.OnProgress(summary.MessagesProcessed, summary.MessagesAdded, summary.MessagesSkipped) + + // If we got fewer than requested, we've finished this chat. + if len(messages) < remaining { + break + } + } + + // Participant fallback: ensure every resolved sender is a conversation + // participant. Covers groups where group_participants is empty (newer + // WhatsApp versions) and any senders discovered via lid resolution. + for pid := range chatSenderIDs { + _ = imp.store.EnsureConversationParticipant(conversationID, pid, "member") + } + // Always include self as participant. + _ = imp.store.EnsureConversationParticipant(conversationID, selfParticipantID, "member") + + imp.progress.OnChatComplete(chat.RawString, chatAdded) + } + + // Update denormalised conversation counts for the WhatsApp source. + _, _ = imp.store.DB().Exec(` + UPDATE conversations SET + message_count = ( + SELECT COUNT(*) FROM messages + WHERE conversation_id = conversations.id + ), + participant_count = ( + SELECT COUNT(*) FROM conversation_participants + WHERE conversation_id = conversations.id + ), + last_message_at = ( + SELECT MAX(COALESCE(sent_at, received_at, internal_date)) + FROM messages + WHERE conversation_id = conversations.id + ) + WHERE source_id = ? + `, source.ID) + + summary.Duration = time.Since(startTime) + imp.progress.OnComplete(summary) + + return summary, nil +} + +// handleMediaFile attempts to find and copy a media file to content-addressed storage. +// Returns (storagePath, contentHash). Both empty if file not found. +func (imp *Importer) handleMediaFile(media waMedia, opts ImportOptions) (string, string) { + if opts.MediaDir == "" || opts.AttachmentsDir == "" || !media.FilePath.Valid || media.FilePath.String == "" { + return "", "" + } + + mediaDir := opts.MediaDir + + // Sanitize the path from the WhatsApp DB (untrusted data). + relPath := filepath.Clean(media.FilePath.String) + + // Reject absolute paths — the DB should only contain relative paths. + if filepath.IsAbs(relPath) { + relPath = filepath.Base(relPath) + } + + // Reject directory traversal. + if relPath == ".." || strings.HasPrefix(relPath, ".."+string(filepath.Separator)) { + relPath = filepath.Base(relPath) + } + + // Build candidate path and verify it stays within mediaDir. + fullPath := filepath.Join(mediaDir, relPath) + absMediaDir, err := filepath.Abs(mediaDir) + if err != nil { + return "", "" + } + absFullPath, err := filepath.Abs(fullPath) + if err != nil { + return "", "" + } + if !strings.HasPrefix(absFullPath, absMediaDir+string(filepath.Separator)) && absFullPath != absMediaDir { + // Path escapes mediaDir — fall back to base filename only. + fullPath = filepath.Join(mediaDir, filepath.Base(relPath)) + absFullPath, _ = filepath.Abs(fullPath) + if !strings.HasPrefix(absFullPath, absMediaDir+string(filepath.Separator)) { + return "", "" + } + } + + // Check file exists. + info, err := os.Stat(fullPath) + if err != nil { + // Try just the filename as fallback. + fullPath = filepath.Join(mediaDir, filepath.Base(relPath)) + info, err = os.Stat(fullPath) + if err != nil { + return "", "" + } + } + + // Enforce max file size to prevent OOM. + maxSize := opts.MaxMediaFileSize + if maxSize <= 0 { + maxSize = 100 * 1024 * 1024 // 100MB default + } + if info.Size() > maxSize { + return "", "" + } + + // Open file and compute hash by streaming (no full-file read into memory). + f, err := os.Open(fullPath) + if err != nil { + return "", "" + } + defer f.Close() + + h := sha256.New() + if _, err := io.Copy(h, io.LimitReader(f, maxSize+1)); err != nil { + return "", "" + } + contentHash := fmt.Sprintf("%x", h.Sum(nil)) + + // Content-addressed storage: // + // The storage_path stored in DB is the relative portion: / + relStoragePath := filepath.Join(contentHash[:2], contentHash) + absStoragePath := filepath.Join(opts.AttachmentsDir, relStoragePath) + + // Check for dedup — file already stored. + if _, err := os.Stat(absStoragePath); err == nil { + return relStoragePath, contentHash + } + + // Create directory and stream-copy the file. + absStorageDir := filepath.Dir(absStoragePath) + if err := os.MkdirAll(absStorageDir, 0750); err != nil { + return "", contentHash + } + + // Seek back to beginning of source file for the copy. + if _, err := f.Seek(0, io.SeekStart); err != nil { + return "", contentHash + } + + dst, err := os.OpenFile(absStoragePath, os.O_CREATE|os.O_WRONLY|os.O_EXCL, 0600) + if err != nil { + if os.IsExist(err) { + // Race: another goroutine already wrote it. + return relStoragePath, contentHash + } + return "", contentHash + } + + if _, err := io.Copy(dst, io.LimitReader(f, maxSize)); err != nil { + dst.Close() + os.Remove(absStoragePath) + return "", contentHash + } + if err := dst.Close(); err != nil { + os.Remove(absStoragePath) + return "", contentHash + } + + return relStoragePath, contentHash +} + +// updateAttachmentMetadata updates media-specific metadata on an attachment record. +func (imp *Importer) updateAttachmentMetadata(messageID int64, contentHash, mediaType string, media waMedia) { + var width, height, durationMS sql.NullInt64 + if media.Width.Valid && media.Width.Int64 > 0 { + width = media.Width + } + if media.Height.Valid && media.Height.Int64 > 0 { + height = media.Height + } + if media.MediaDuration.Valid && media.MediaDuration.Int64 > 0 { + // WhatsApp stores duration in seconds; msgvault uses milliseconds. + durationMS = sql.NullInt64{Int64: media.MediaDuration.Int64 * 1000, Valid: true} + } + + _, _ = imp.store.DB().Exec(` + UPDATE attachments SET media_type = ?, width = ?, height = ?, duration_ms = ? + WHERE message_id = ? AND (content_hash = ? OR content_hash IS NULL) + `, mediaType, width, height, durationMS, messageID, contentHash) +} + +// lookupMessageByKeyID looks up a previously imported message by its WhatsApp key_id. +// Returns 0 if not found. +func (imp *Importer) lookupMessageByKeyID(sourceID int64, keyID string) (int64, error) { + var msgID int64 + err := imp.store.DB().QueryRow( + `SELECT id FROM messages WHERE source_id = ? AND source_message_id = ?`, + sourceID, keyID, + ).Scan(&msgID) + if err == sql.ErrNoRows { + return 0, nil + } + return msgID, err +} + +// setReplyTo sets the reply_to_message_id on a message. +func (imp *Importer) setReplyTo(messageID, replyToID int64) { + _, _ = imp.store.DB().Exec(` + UPDATE messages SET reply_to_message_id = ? WHERE id = ? + `, replyToID, messageID) +} + +// verifyWhatsAppDB checks that the database looks like a WhatsApp msgstore.db. +func verifyWhatsAppDB(db *sql.DB) error { + // Check for the 'message' table with expected columns. + var count int + err := db.QueryRow(` + SELECT COUNT(*) FROM sqlite_master + WHERE type = 'table' AND name = 'message' + `).Scan(&count) + if err != nil { + return fmt.Errorf("check whatsapp db: %w", err) + } + if count == 0 { + return fmt.Errorf("not a valid WhatsApp database: 'message' table not found") + } + + // Check for the 'jid' table. + err = db.QueryRow(` + SELECT COUNT(*) FROM sqlite_master + WHERE type = 'table' AND name = 'jid' + `).Scan(&count) + if err != nil { + return fmt.Errorf("check whatsapp db: %w", err) + } + if count == 0 { + return fmt.Errorf("not a valid WhatsApp database: 'jid' table not found") + } + + // Check for the 'chat' table. + err = db.QueryRow(` + SELECT COUNT(*) FROM sqlite_master + WHERE type = 'table' AND name = 'chat' + `).Scan(&count) + if err != nil { + return fmt.Errorf("check whatsapp db: %w", err) + } + if count == 0 { + return fmt.Errorf("not a valid WhatsApp database: 'chat' table not found") + } + + return nil +} diff --git a/internal/whatsapp/mapping.go b/internal/whatsapp/mapping.go new file mode 100644 index 00000000..aa4e195b --- /dev/null +++ b/internal/whatsapp/mapping.go @@ -0,0 +1,215 @@ +package whatsapp + +import ( + "database/sql" + "strings" + "time" + "unicode/utf8" + + "github.com/wesm/msgvault/internal/store" +) + +// isGroupChat returns true if the chat represents a group conversation. +// A chat is a group if group_type > 0 OR if the JID server is "g.us". +// Some groups (e.g. WhatsApp Communities and their sub-groups) have +// group_type = 0 despite being groups; the JID server is the +// definitive signal. +func isGroupChat(chat waChat) bool { + return chat.GroupType > 0 || chat.Server == "g.us" +} + +// mapConversation maps a WhatsApp chat to a msgvault conversation. +// Returns the source_conversation_id, conversation_type, and title. +func mapConversation(chat waChat) (sourceConvID, convType, title string) { + sourceConvID = chat.RawString + + if isGroupChat(chat) { + convType = "group_chat" + if chat.Subject.Valid { + title = chat.Subject.String + } + } else { + convType = "direct_chat" + // No title for direct chats (resolved via participant lookup) + } + + return sourceConvID, convType, title +} + +// mapMessage maps a WhatsApp message to a msgvault Message struct. +// The conversationID and sourceID must be resolved before calling. +func mapMessage(msg waMessage, conversationID, sourceID int64, senderID sql.NullInt64) store.Message { + sentAt := sql.NullTime{} + if msg.Timestamp > 0 { + // WhatsApp timestamps are in milliseconds since epoch. + sentAt = sql.NullTime{ + Time: time.Unix(msg.Timestamp/1000, (msg.Timestamp%1000)*1e6), + Valid: true, + } + } + + snippet := sql.NullString{} + if msg.TextData.Valid && msg.TextData.String != "" { + s := msg.TextData.String + if utf8.RuneCountInString(s) > 100 { + // Truncate to 100 runes, preserving multi-byte characters. + runes := []rune(s) + s = string(runes[:100]) + } + snippet = sql.NullString{String: s, Valid: true} + } + + return store.Message{ + ConversationID: conversationID, + SourceID: sourceID, + SourceMessageID: msg.KeyID, + MessageType: "whatsapp", + SentAt: sentAt, + SenderID: senderID, + IsFromMe: msg.FromMe == 1, + Snippet: snippet, + HasAttachments: isMediaType(msg.MessageType), + AttachmentCount: boolToInt(isMediaType(msg.MessageType)), + ArchivedAt: time.Now(), + } +} + +// mapMediaType maps a WhatsApp message_type integer to a media type string. +// Returns empty string for non-media types. +func mapMediaType(waMessageType int) string { + switch waMessageType { + case 1: + return "image" + case 2: + return "video" + case 3: + return "audio" + case 4: + return "gif" + case 5: + return "voice_note" + case 13: + return "document" + case 90: + return "sticker" + default: + return "" + } +} + +// isMediaType returns true if the WhatsApp message_type represents media. +func isMediaType(waMessageType int) bool { + return mapMediaType(waMessageType) != "" +} + +// isSkippedType returns true if the message type should be skipped during import. +// System messages, calls, locations, contacts, and polls are not imported. +func isSkippedType(waMessageType int) bool { + switch waMessageType { + case 7: // system message + return true + case 9: // location share + return true + case 10: // contact card + return true + case 15: // voice/video call + return true + case 64: // call (missed) + return true + case 66: // call (group) + return true + case 99: // poll + return true + case 11: // status/story + return true + default: + return false + } +} + +// normalizePhone normalizes a WhatsApp JID user+server to an E.164 phone number. +// Input: user="447700900000", server="s.whatsapp.net" +// Output: "+447700900000" +// Returns empty string for non-phone JIDs (e.g., lid:..., status@broadcast). +func normalizePhone(user, server string) string { + if user == "" { + return "" + } + + // Strip the @server suffix if present in user. + user = strings.TrimSuffix(user, "@"+server) + + // Already in E.164 format? + if strings.HasPrefix(user, "+") { + return user + } + + // Reject non-numeric JID users (e.g., "lid:123", "status", broadcast addresses). + // Valid phone numbers contain only digits. + for _, c := range user { + if c < '0' || c > '9' { + return "" + } + } + + // Must be at least a few digits to be a plausible phone number, + // and no more than 15 (E.164 max) to prevent data pollution. + if len(user) < 4 || len(user) > 15 { + return "" + } + + // Prepend + for E.164. + return "+" + user +} + +// resolveLidSender resolves a "lid" JID sender to a phone number via the +// jid_map lookup table. Only activates when the sender's server is "lid". +// Returns a normalised E.164 phone number, or empty string if unresolvable. +func resolveLidSender(jidRowID sql.NullInt64, server string, lidMap map[int64]waLidMapping) string { + if server != "lid" || !jidRowID.Valid { + return "" + } + mapping, ok := lidMap[jidRowID.Int64] + if !ok { + return "" + } + return normalizePhone(mapping.PhoneUser, mapping.PhoneServer) +} + +// mapReaction maps a WhatsApp reaction to reaction_type and reaction_value. +func mapReaction(r waReaction) (reactionType, reactionValue string) { + if r.ReactionValue.Valid && r.ReactionValue.String != "" { + return "emoji", r.ReactionValue.String + } + return "emoji", "" +} + +// mapGroupRole maps a WhatsApp admin level to a conversation participant role. +func mapGroupRole(admin int) string { + switch admin { + case 1: + return "admin" + case 2: + return "admin" // superadmin → admin (msgvault doesn't distinguish) + default: + return "member" + } +} + +// chatTitle returns a display title for a chat for progress reporting. +func chatTitle(chat waChat) string { + if chat.Subject.Valid && chat.Subject.String != "" { + return chat.Subject.String + } + if chat.User != "" { + return normalizePhone(chat.User, chat.Server) + } + return chat.RawString +} + +func boolToInt(b bool) int { + if b { + return 1 + } + return 0 +} diff --git a/internal/whatsapp/mapping_test.go b/internal/whatsapp/mapping_test.go new file mode 100644 index 00000000..9b0b82d6 --- /dev/null +++ b/internal/whatsapp/mapping_test.go @@ -0,0 +1,347 @@ +package whatsapp + +import ( + "database/sql" + "testing" +) + +func TestNormalizePhone(t *testing.T) { + tests := []struct { + user, server string + want string + }{ + {"447700900000", "s.whatsapp.net", "+447700900000"}, + {"12025551234", "s.whatsapp.net", "+12025551234"}, + {"+447700900000", "s.whatsapp.net", "+447700900000"}, + {"", "s.whatsapp.net", ""}, + {"447700900000", "g.us", "+447700900000"}, + } + + for _, tt := range tests { + got := normalizePhone(tt.user, tt.server) + if got != tt.want { + t.Errorf("normalizePhone(%q, %q) = %q, want %q", tt.user, tt.server, got, tt.want) + } + } +} + +func TestMapMediaType(t *testing.T) { + tests := []struct { + waType int + want string + }{ + {0, ""}, // text + {1, "image"}, + {2, "video"}, + {3, "audio"}, + {4, "gif"}, + {5, "voice_note"}, + {13, "document"}, + {90, "sticker"}, + {7, ""}, // system (no media type) + {15, ""}, // call + {99, ""}, // poll + } + + for _, tt := range tests { + got := mapMediaType(tt.waType) + if got != tt.want { + t.Errorf("mapMediaType(%d) = %q, want %q", tt.waType, got, tt.want) + } + } +} + +func TestIsMediaType(t *testing.T) { + if !isMediaType(1) { + t.Error("isMediaType(1) should be true (image)") + } + if isMediaType(0) { + t.Error("isMediaType(0) should be false (text)") + } + if isMediaType(7) { + t.Error("isMediaType(7) should be false (system)") + } +} + +func TestIsSkippedType(t *testing.T) { + skipped := []int{7, 9, 10, 15, 64, 66, 99, 11} + for _, typ := range skipped { + if !isSkippedType(typ) { + t.Errorf("isSkippedType(%d) should be true", typ) + } + } + + notSkipped := []int{0, 1, 2, 3, 4, 5, 13, 90} + for _, typ := range notSkipped { + if isSkippedType(typ) { + t.Errorf("isSkippedType(%d) should be false", typ) + } + } +} + +func TestIsGroupChat(t *testing.T) { + tests := []struct { + name string + chat waChat + want bool + }{ + { + name: "direct chat", + chat: waChat{Server: "s.whatsapp.net", GroupType: 0}, + want: false, + }, + { + name: "standard group", + chat: waChat{Server: "g.us", GroupType: 1}, + want: true, + }, + { + name: "community sub-group (g.us + type=0)", + chat: waChat{Server: "g.us", GroupType: 0}, + want: true, + }, + { + name: "broadcast", + chat: waChat{Server: "broadcast", GroupType: 0}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isGroupChat(tt.chat) + if got != tt.want { + t.Errorf("isGroupChat() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMapConversation(t *testing.T) { + // Direct chat. + direct := waChat{ + RawString: "447700900000@s.whatsapp.net", + GroupType: 0, + } + id, typ, title := mapConversation(direct) + if id != "447700900000@s.whatsapp.net" { + t.Errorf("direct chat sourceConvID = %q, want %q", id, "447700900000@s.whatsapp.net") + } + if typ != "direct_chat" { + t.Errorf("direct chat convType = %q, want %q", typ, "direct_chat") + } + if title != "" { + t.Errorf("direct chat title = %q, want empty", title) + } + + // Group chat. + group := waChat{ + RawString: "120363001234567890@g.us", + Server: "g.us", + GroupType: 1, + Subject: sql.NullString{String: "Family Group", Valid: true}, + } + id, typ, title = mapConversation(group) + if id != "120363001234567890@g.us" { + t.Errorf("group chat sourceConvID = %q", id) + } + if typ != "group_chat" { + t.Errorf("group chat convType = %q, want %q", typ, "group_chat") + } + if title != "Family Group" { + t.Errorf("group chat title = %q, want %q", title, "Family Group") + } + + // Group with group_type=0 but g.us server (e.g. WhatsApp Community sub-groups). + community := waChat{ + RawString: "120363377259312783@g.us", + Server: "g.us", + GroupType: 0, + Subject: sql.NullString{String: "AI Impact", Valid: true}, + } + id, typ, title = mapConversation(community) + if typ != "group_chat" { + t.Errorf("g.us with group_type=0: convType = %q, want %q", typ, "group_chat") + } + if title != "AI Impact" { + t.Errorf("g.us with group_type=0: title = %q, want %q", title, "AI Impact") + } +} + +func TestMapMessage(t *testing.T) { + msg := waMessage{ + RowID: 42, + ChatRowID: 1, + FromMe: 1, + KeyID: "ABC123", + Timestamp: 1700000000000, // ms + MessageType: 0, + TextData: sql.NullString{String: "Hello world", Valid: true}, + } + + senderID := sql.NullInt64{Int64: 99, Valid: true} + result := mapMessage(msg, 10, 20, senderID) + + if result.ConversationID != 10 { + t.Errorf("ConversationID = %d, want 10", result.ConversationID) + } + if result.SourceID != 20 { + t.Errorf("SourceID = %d, want 20", result.SourceID) + } + if result.SourceMessageID != "ABC123" { + t.Errorf("SourceMessageID = %q, want %q", result.SourceMessageID, "ABC123") + } + if result.MessageType != "whatsapp" { + t.Errorf("MessageType = %q, want %q", result.MessageType, "whatsapp") + } + if !result.IsFromMe { + t.Error("IsFromMe should be true") + } + if !result.SentAt.Valid { + t.Error("SentAt should be valid") + } + if result.SentAt.Time.Unix() != 1700000000 { + t.Errorf("SentAt Unix = %d, want 1700000000", result.SentAt.Time.Unix()) + } + if !result.Snippet.Valid || result.Snippet.String != "Hello world" { + t.Errorf("Snippet = %v, want 'Hello world'", result.Snippet) + } + if result.HasAttachments { + t.Error("HasAttachments should be false for text message") + } +} + +func TestMapMessageSnippetTruncation(t *testing.T) { + // Create a message with text longer than 100 characters. + longText := "" + for i := 0; i < 150; i++ { + longText += "x" + } + + msg := waMessage{ + KeyID: "LONG1", + Timestamp: 1700000000000, + MessageType: 0, + TextData: sql.NullString{String: longText, Valid: true}, + } + + result := mapMessage(msg, 1, 1, sql.NullInt64{}) + if !result.Snippet.Valid { + t.Fatal("Snippet should be valid") + } + if len([]rune(result.Snippet.String)) != 100 { + t.Errorf("Snippet rune count = %d, want 100", len([]rune(result.Snippet.String))) + } +} + +func TestMapGroupRole(t *testing.T) { + tests := []struct { + admin int + want string + }{ + {0, "member"}, + {1, "admin"}, + {2, "admin"}, // superadmin + {3, "member"}, + } + + for _, tt := range tests { + got := mapGroupRole(tt.admin) + if got != tt.want { + t.Errorf("mapGroupRole(%d) = %q, want %q", tt.admin, got, tt.want) + } + } +} + +func TestMapReaction(t *testing.T) { + r := waReaction{ + ReactionValue: sql.NullString{String: "❤️", Valid: true}, + } + typ, val := mapReaction(r) + if typ != "emoji" { + t.Errorf("reaction type = %q, want %q", typ, "emoji") + } + if val != "❤️" { + t.Errorf("reaction value = %q, want %q", val, "❤️") + } + + // Empty reaction. + empty := waReaction{ + ReactionValue: sql.NullString{}, + } + _, val = mapReaction(empty) + if val != "" { + t.Errorf("empty reaction value = %q, want empty", val) + } +} + +func TestResolveLidSender(t *testing.T) { + lidMap := map[int64]waLidMapping{ + 100: {LidRowID: 100, PhoneUser: "447957366403", PhoneServer: "s.whatsapp.net"}, + 200: {LidRowID: 200, PhoneUser: "12025551234", PhoneServer: "s.whatsapp.net"}, + } + + tests := []struct { + name string + jidRowID sql.NullInt64 + server string + want string + }{ + { + name: "lid sender found in map", + jidRowID: sql.NullInt64{Int64: 100, Valid: true}, + server: "lid", + want: "+447957366403", + }, + { + name: "lid sender not in map", + jidRowID: sql.NullInt64{Int64: 999, Valid: true}, + server: "lid", + want: "", + }, + { + name: "non-lid server ignored", + jidRowID: sql.NullInt64{Int64: 100, Valid: true}, + server: "s.whatsapp.net", + want: "", + }, + { + name: "null jid row id", + jidRowID: sql.NullInt64{Valid: false}, + server: "lid", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := resolveLidSender(tt.jidRowID, tt.server, lidMap) + if got != tt.want { + t.Errorf("resolveLidSender() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestChatTitle(t *testing.T) { + // Group with subject. + group := waChat{ + Subject: sql.NullString{String: "Work Chat", Valid: true}, + User: "120363001234567890", + Server: "g.us", + RawString: "120363001234567890@g.us", + } + if chatTitle(group) != "Work Chat" { + t.Errorf("chatTitle(group) = %q, want %q", chatTitle(group), "Work Chat") + } + + // Direct chat. + direct := waChat{ + User: "447700900000", + Server: "s.whatsapp.net", + RawString: "447700900000@s.whatsapp.net", + } + if chatTitle(direct) != "+447700900000" { + t.Errorf("chatTitle(direct) = %q, want %q", chatTitle(direct), "+447700900000") + } +} diff --git a/internal/whatsapp/queries.go b/internal/whatsapp/queries.go new file mode 100644 index 00000000..cb8f6e5c --- /dev/null +++ b/internal/whatsapp/queries.go @@ -0,0 +1,365 @@ +package whatsapp + +import ( + "database/sql" + "fmt" + "strings" +) + +// fetchChats returns all non-hidden chats from the WhatsApp database. +// Joins with the jid table to get JID details for each chat. +func fetchChats(db *sql.DB) ([]waChat, error) { + rows, err := db.Query(` + SELECT + c._id, + c.jid_row_id, + j.raw_string, + COALESCE(j.user, ''), + COALESCE(j.server, ''), + c.subject, + COALESCE(c.group_type, 0), + COALESCE(c.hidden, 0), + COALESCE(c.sort_timestamp, 0) + FROM chat c + JOIN jid j ON c.jid_row_id = j._id + WHERE COALESCE(c.hidden, 0) = 0 + ORDER BY c.sort_timestamp DESC + `) + if err != nil { + return nil, fmt.Errorf("fetch chats: %w", err) + } + defer rows.Close() + + var chats []waChat + for rows.Next() { + var c waChat + if err := rows.Scan( + &c.RowID, &c.JIDRowID, &c.RawString, &c.User, &c.Server, + &c.Subject, &c.GroupType, &c.Hidden, + &c.LastMessageTimestamp, + ); err != nil { + return nil, fmt.Errorf("scan chat: %w", err) + } + chats = append(chats, c) + } + return chats, rows.Err() +} + +// fetchMessages returns messages for a chat, batched after a given _id. +// Messages are ordered by _id ascending for deterministic resumability. +// Joins with jid to resolve sender information. +func fetchMessages(db *sql.DB, chatRowID int64, afterID int64, limit int) ([]waMessage, error) { + rows, err := db.Query(` + SELECT + m._id, + m.chat_row_id, + COALESCE(m.from_me, 0), + COALESCE(m.key_id, ''), + m.sender_jid_row_id, + sj.raw_string, + sj.user, + sj.server, + COALESCE(m.timestamp, 0), + COALESCE(m.message_type, 0), + m.text_data, + COALESCE(m.status, 0), + COALESCE(m.starred, 0) + FROM message m + LEFT JOIN jid sj ON m.sender_jid_row_id = sj._id + WHERE m.chat_row_id = ? + AND m._id > ? + ORDER BY m._id ASC + LIMIT ? + `, chatRowID, afterID, limit) + if err != nil { + return nil, fmt.Errorf("fetch messages: %w", err) + } + defer rows.Close() + + var messages []waMessage + for rows.Next() { + var m waMessage + if err := rows.Scan( + &m.RowID, &m.ChatRowID, &m.FromMe, &m.KeyID, + &m.SenderJIDRowID, &m.SenderRawString, &m.SenderUser, &m.SenderServer, + &m.Timestamp, &m.MessageType, &m.TextData, + &m.Status, &m.Starred, + ); err != nil { + return nil, fmt.Errorf("scan message: %w", err) + } + messages = append(messages, m) + } + return messages, rows.Err() +} + +// fetchMedia returns media metadata for a batch of message row IDs. +// Returns a map of message_row_id → waMedia. +func fetchMedia(db *sql.DB, messageRowIDs []int64) (map[int64]waMedia, error) { + if len(messageRowIDs) == 0 { + return make(map[int64]waMedia), nil + } + + result := make(map[int64]waMedia) + + // Process in chunks to stay within SQLite's parameter limit. + const chunkSize = 500 + for i := 0; i < len(messageRowIDs); i += chunkSize { + end := i + chunkSize + if end > len(messageRowIDs) { + end = len(messageRowIDs) + } + chunk := messageRowIDs[i:end] + + placeholders := make([]string, len(chunk)) + args := make([]interface{}, len(chunk)) + for j, id := range chunk { + placeholders[j] = "?" + args[j] = id + } + + query := fmt.Sprintf(` + SELECT + mm.message_row_id, + mm.mime_type, + mm.media_caption, + mm.file_size, + mm.file_path, + mm.width, + mm.height, + mm.media_duration + FROM message_media mm + WHERE mm.message_row_id IN (%s) + `, strings.Join(placeholders, ",")) + + rows, err := db.Query(query, args...) + if err != nil { + return nil, fmt.Errorf("fetch media: %w", err) + } + + for rows.Next() { + var m waMedia + if err := rows.Scan( + &m.MessageRowID, &m.MimeType, &m.MediaCaption, + &m.FileSize, &m.FilePath, &m.Width, &m.Height, + &m.MediaDuration, + ); err != nil { + rows.Close() + return nil, fmt.Errorf("scan media: %w", err) + } + result[m.MessageRowID] = m + } + rows.Close() + if err := rows.Err(); err != nil { + return nil, err + } + } + + return result, nil +} + +// fetchReactions returns reactions for a batch of message row IDs. +// Returns a map of message_row_id → []waReaction. +func fetchReactions(db *sql.DB, messageRowIDs []int64) (map[int64][]waReaction, error) { + if len(messageRowIDs) == 0 { + return make(map[int64][]waReaction), nil + } + + result := make(map[int64][]waReaction) + + const chunkSize = 500 + for i := 0; i < len(messageRowIDs); i += chunkSize { + end := i + chunkSize + if end > len(messageRowIDs) { + end = len(messageRowIDs) + } + chunk := messageRowIDs[i:end] + + placeholders := make([]string, len(chunk)) + args := make([]interface{}, len(chunk)) + for j, id := range chunk { + placeholders[j] = "?" + args[j] = id + } + + // WhatsApp stores reactions in message_add_on (metadata) joined with + // message_add_on_reaction (the actual emoji). The link to the original + // message is via parent_message_row_id. + query := fmt.Sprintf(` + SELECT + ao.parent_message_row_id, + ao.sender_jid_row_id, + sj.raw_string, + sj.user, + sj.server, + ar.reaction, + COALESCE(ar.sender_timestamp, 0) + FROM message_add_on ao + JOIN message_add_on_reaction ar ON ar.message_add_on_row_id = ao._id + LEFT JOIN jid sj ON ao.sender_jid_row_id = sj._id + WHERE ao.parent_message_row_id IN (%s) + AND ar.reaction IS NOT NULL + AND ar.reaction != '' + `, strings.Join(placeholders, ",")) + + rows, err := db.Query(query, args...) + if err != nil { + // Table might not exist in older DB versions + if isTableNotFound(err) { + return result, nil + } + return nil, fmt.Errorf("fetch reactions: %w", err) + } + + for rows.Next() { + var r waReaction + if err := rows.Scan( + &r.MessageRowID, &r.SenderJIDRowID, + &r.SenderRawString, &r.SenderUser, &r.SenderServer, + &r.ReactionValue, &r.Timestamp, + ); err != nil { + rows.Close() + return nil, fmt.Errorf("scan reaction: %w", err) + } + result[r.MessageRowID] = append(result[r.MessageRowID], r) + } + rows.Close() + if err := rows.Err(); err != nil { + return nil, err + } + } + + return result, nil +} + +// fetchGroupParticipants returns all participants of a group chat. +// In the WhatsApp schema, group_participants.gjid and .jid are TEXT fields +// containing raw JID strings (e.g., "447700900000@s.whatsapp.net"), +// not integer row IDs. +func fetchGroupParticipants(db *sql.DB, groupJIDRawString string) ([]waGroupMember, error) { + rows, err := db.Query(` + SELECT + gp.gjid, + gp.jid, + COALESCE(j.user, ''), + COALESCE(j.server, ''), + COALESCE(gp.admin, 0) + FROM group_participants gp + LEFT JOIN jid j ON gp.jid = j.raw_string + WHERE gp.gjid = ? + `, groupJIDRawString) + if err != nil { + return nil, fmt.Errorf("fetch group participants: %w", err) + } + defer rows.Close() + + var members []waGroupMember + for rows.Next() { + var m waGroupMember + if err := rows.Scan( + &m.GroupJID, &m.MemberJID, + &m.MemberUser, &m.MemberServer, &m.Admin, + ); err != nil { + return nil, fmt.Errorf("scan group participant: %w", err) + } + members = append(members, m) + } + return members, rows.Err() +} + +// fetchQuotedMessages returns quoted message references for a batch of message row IDs. +// Returns a map of message_row_id → waQuoted (the message that contains the quote). +func fetchQuotedMessages(db *sql.DB, messageRowIDs []int64) (map[int64]waQuoted, error) { + if len(messageRowIDs) == 0 { + return make(map[int64]waQuoted), nil + } + + result := make(map[int64]waQuoted) + + const chunkSize = 500 + for i := 0; i < len(messageRowIDs); i += chunkSize { + end := i + chunkSize + if end > len(messageRowIDs) { + end = len(messageRowIDs) + } + chunk := messageRowIDs[i:end] + + placeholders := make([]string, len(chunk)) + args := make([]interface{}, len(chunk)) + for j, id := range chunk { + placeholders[j] = "?" + args[j] = id + } + + query := fmt.Sprintf(` + SELECT + mq.message_row_id, + mq.key_id + FROM message_quoted mq + WHERE mq.message_row_id IN (%s) + AND mq.key_id IS NOT NULL + AND mq.key_id != '' + `, strings.Join(placeholders, ",")) + + rows, err := db.Query(query, args...) + if err != nil { + if isTableNotFound(err) { + return result, nil + } + return nil, fmt.Errorf("fetch quoted messages: %w", err) + } + + for rows.Next() { + var q waQuoted + if err := rows.Scan(&q.MessageRowID, &q.QuotedKeyID); err != nil { + rows.Close() + return nil, fmt.Errorf("scan quoted message: %w", err) + } + result[q.MessageRowID] = q + } + rows.Close() + if err := rows.Err(); err != nil { + return nil, err + } + } + + return result, nil +} + +// fetchLidMap reads the WhatsApp jid_map table to build a mapping from +// lid JID row IDs to their corresponding phone JIDs. The jid_map table +// links two jid rows: one with server="lid" and one with the real phone +// (server="s.whatsapp.net"). Returns an empty map gracefully if the +// jid_map table doesn't exist (older WhatsApp DB versions). +func fetchLidMap(db *sql.DB) (map[int64]waLidMapping, error) { + result := make(map[int64]waLidMapping) + + rows, err := db.Query(` + SELECT + jm.lid_row_id, + COALESCE(phone_jid.user, ''), + COALESCE(phone_jid.server, '') + FROM jid_map jm + JOIN jid phone_jid ON jm.jid_row_id = phone_jid._id + `) + if err != nil { + if isTableNotFound(err) { + return result, nil + } + return nil, fmt.Errorf("fetch lid map: %w", err) + } + defer rows.Close() + + for rows.Next() { + var m waLidMapping + if err := rows.Scan(&m.LidRowID, &m.PhoneUser, &m.PhoneServer); err != nil { + return nil, fmt.Errorf("scan lid mapping: %w", err) + } + result[m.LidRowID] = m + } + return result, rows.Err() +} + +// isTableNotFound returns true if the error indicates a missing table. +func isTableNotFound(err error) bool { + return err != nil && strings.Contains(err.Error(), "no such table") +} diff --git a/internal/whatsapp/queries_test.go b/internal/whatsapp/queries_test.go new file mode 100644 index 00000000..2f02a4d7 --- /dev/null +++ b/internal/whatsapp/queries_test.go @@ -0,0 +1,89 @@ +package whatsapp + +import ( + "database/sql" + "testing" + + _ "github.com/mattn/go-sqlite3" +) + +func TestFetchLidMap(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + // Create the jid and jid_map tables matching WhatsApp's actual schema. + // In WhatsApp: jid_map.lid_row_id is PK (= jid._id for the lid entry), + // jid_map.jid_row_id points to the phone jid._id. + _, err = db.Exec(` + CREATE TABLE jid ( + _id INTEGER PRIMARY KEY, + user TEXT, + server TEXT, + raw_string TEXT + ); + CREATE TABLE jid_map ( + lid_row_id INTEGER PRIMARY KEY NOT NULL, + jid_row_id INTEGER NOT NULL + ); + + -- lid JID entries (these are the lid_row_id values) + INSERT INTO jid (_id, user, server, raw_string) VALUES (10, '12345abcde', 'lid', '12345abcde@lid'); + INSERT INTO jid (_id, user, server, raw_string) VALUES (20, '67890fghij', 'lid', '67890fghij@lid'); + + -- phone JID entries (these are the jid_row_id values) + INSERT INTO jid (_id, user, server, raw_string) VALUES (11, '447957366403', 's.whatsapp.net', '447957366403@s.whatsapp.net'); + INSERT INTO jid (_id, user, server, raw_string) VALUES (21, '12025551234', 's.whatsapp.net', '12025551234@s.whatsapp.net'); + + -- Map lid → phone + INSERT INTO jid_map (lid_row_id, jid_row_id) VALUES (10, 11); + INSERT INTO jid_map (lid_row_id, jid_row_id) VALUES (20, 21); + `) + if err != nil { + t.Fatal(err) + } + + lidMap, err := fetchLidMap(db) + if err != nil { + t.Fatal(err) + } + + if len(lidMap) != 2 { + t.Fatalf("expected 2 lid mappings, got %d", len(lidMap)) + } + + m1, ok := lidMap[10] + if !ok { + t.Fatal("expected lid row 10 in map") + } + if m1.PhoneUser != "447957366403" || m1.PhoneServer != "s.whatsapp.net" { + t.Errorf("lid 10: got user=%q server=%q, want 447957366403@s.whatsapp.net", m1.PhoneUser, m1.PhoneServer) + } + + m2, ok := lidMap[20] + if !ok { + t.Fatal("expected lid row 20 in map") + } + if m2.PhoneUser != "12025551234" { + t.Errorf("lid 20: got user=%q, want 12025551234", m2.PhoneUser) + } +} + +func TestFetchLidMapMissingTable(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + // No jid_map table — should return empty map, not error. + lidMap, err := fetchLidMap(db) + if err != nil { + t.Fatalf("expected no error for missing table, got: %v", err) + } + if len(lidMap) != 0 { + t.Errorf("expected empty map, got %d entries", len(lidMap)) + } +} diff --git a/internal/whatsapp/types.go b/internal/whatsapp/types.go new file mode 100644 index 00000000..1fbc74b8 --- /dev/null +++ b/internal/whatsapp/types.go @@ -0,0 +1,155 @@ +// Package whatsapp provides import functionality for WhatsApp message backups. +// It reads from a decrypted WhatsApp msgstore.db (SQLite) and maps messages +// into the msgvault unified schema. +package whatsapp + +import ( + "database/sql" + "time" +) + +// waChat represents a chat/conversation from the WhatsApp jid + chat tables. +type waChat struct { + RowID int64 // chat._id + JIDRowID int64 // chat.jid_row_id → jid._id + RawString string // jid.raw_string (e.g., "447700900000@s.whatsapp.net") + User string // jid.user (phone number part) + Server string // jid.server (s.whatsapp.net or g.us) + Subject sql.NullString // chat.subject (group name) + GroupType int // chat.group_type: 0=individual (but see Server), >0=group + Hidden int // chat.hidden + LastMessageTimestamp int64 // chat.sort_timestamp +} + +// waMessage represents a message from the WhatsApp message table. +type waMessage struct { + RowID int64 // message._id + ChatRowID int64 // message.chat_row_id + FromMe int // message.from_me (0=received, 1=sent) + KeyID string // message.key_id (unique message ID) + SenderJIDRowID sql.NullInt64 // message.sender_jid_row_id → jid._id + SenderRawString sql.NullString // jid.raw_string of sender + SenderUser sql.NullString // jid.user of sender + SenderServer sql.NullString // jid.server of sender + Timestamp int64 // message.timestamp (ms since epoch) + MessageType int // message.message_type + TextData sql.NullString // message.text_data + Status int // message.status + Starred int // message.starred +} + +// waMedia represents media metadata from the message_media table. +type waMedia struct { + MessageRowID int64 // message_media.message_row_id + MimeType sql.NullString // message_media.mime_type + MediaCaption sql.NullString // message_media.media_caption + FileSize sql.NullInt64 // message_media.file_size + FilePath sql.NullString // message_media.file_path + Width sql.NullInt64 // message_media.width + Height sql.NullInt64 // message_media.height + MediaDuration sql.NullInt64 // message_media.media_duration (seconds) +} + +// waReaction represents a reaction from the message_add_on table. +type waReaction struct { + MessageRowID int64 // FK to message._id + SenderJIDRowID sql.NullInt64 // jid of reactor + SenderRawString sql.NullString // jid.raw_string + SenderUser sql.NullString // jid.user + SenderServer sql.NullString // jid.server + ReactionValue sql.NullString // emoji character + Timestamp int64 // timestamp (ms) +} + +// waGroupMember represents a member of a group chat. +type waGroupMember struct { + GroupJID string // group_participants.gjid (text, raw JID string) + MemberJID string // group_participants.jid (text, raw JID string) + MemberUser string // jid.user (parsed from MemberJID) + MemberServer string // jid.server (parsed from MemberJID) + Admin int // group_participants.admin (0=member, 1=admin, 2=superadmin) +} + +// waQuoted represents a quoted/replied-to message reference. +type waQuoted struct { + MessageRowID int64 // the message that quotes + QuotedKeyID string // message_quoted.key_id of the quoted message +} + +// waLidMapping maps a "lid" JID row to its corresponding phone JID, +// populated from the WhatsApp jid_map table. +type waLidMapping struct { + LidRowID int64 // jid._id for the lid entry + PhoneUser string // jid.user for the phone entry (e.g., "447700900000") + PhoneServer string // jid.server for the phone entry (e.g., "s.whatsapp.net") +} + +// ImportOptions configures the WhatsApp import process. +type ImportOptions struct { + // Phone is the user's own phone number in E.164 format (e.g., "+447700900000"). + Phone string + + // DisplayName is an optional display name for the user. + DisplayName string + + // MediaDir is an optional path to the decrypted Media folder. + // If set, media files will be copied to content-addressed storage. + MediaDir string + + // AttachmentsDir is the root directory for content-addressed attachment storage. + // This should be cfg.AttachmentsDir() (e.g., ~/.msgvault/attachments/). + // Required when MediaDir is set. + AttachmentsDir string + + // MaxMediaFileSize is the maximum size of a single media file to copy (in bytes). + // Files larger than this are skipped. Default: 100MB. + MaxMediaFileSize int64 + + // Limit limits the number of messages imported (0 = no limit, for testing). + Limit int + + // BatchSize is the number of messages to process per batch (default: 1000). + BatchSize int +} + +// DefaultOptions returns ImportOptions with sensible defaults. +func DefaultOptions() ImportOptions { + return ImportOptions{ + BatchSize: 1000, + MaxMediaFileSize: 100 * 1024 * 1024, // 100MB + } +} + +// ImportSummary holds statistics from a completed import. +type ImportSummary struct { + Duration time.Duration + ChatsProcessed int64 + MessagesProcessed int64 + MessagesAdded int64 + MessagesSkipped int64 + ReactionsAdded int64 + AttachmentsFound int64 + MediaCopied int64 + Participants int64 + Errors int64 +} + +// ImportProgress provides callbacks for import progress reporting. +type ImportProgress interface { + OnStart() + OnChatStart(chatJID string, chatTitle string, messageCount int) + OnProgress(processed, added, skipped int64) + OnChatComplete(chatJID string, messagesAdded int64) + OnComplete(summary *ImportSummary) + OnError(err error) +} + +// NullProgress is a no-op implementation of ImportProgress. +type NullProgress struct{} + +func (NullProgress) OnStart() {} +func (NullProgress) OnChatStart(string, string, int) {} +func (NullProgress) OnProgress(int64, int64, int64) {} +func (NullProgress) OnChatComplete(string, int64) {} +func (NullProgress) OnComplete(*ImportSummary) {} +func (NullProgress) OnError(error) {}