From 75c6119617b5fcef2b02e9af9ebadffd413797cf Mon Sep 17 00:00:00 2001 From: Julien Tant Date: Fri, 3 Apr 2026 14:58:52 -0700 Subject: [PATCH 1/7] Add SSRF-safe HTTP client with --unsafe-client escape hatch - Use DataDog go-secure-sdk for SSRF protection (blocks private IPs, loopback, link-local, cloud metadata endpoints) - Refactor rss.Fetcher and scraper.Scraper as structs holding *http.Client - Scanner accepts Fetcher + Scraper as dependencies (no global state) - Add --unsafe-client / BLOGWATCHER_UNSAFE_CLIENT flag for local dev - ScanBlog now returns error (hard fail, not soft error string) - E2e test verifies safe client blocks loopback requests - Tests use plain http.Client (no SSRF check needed for httptest) - Update AGENTS.md: ground claims in evidence, not guesses Co-Authored-By: Claude Opus 4.6 (1M context) --- AGENTS.md | 4 ++ e2e/e2e_test.go | 28 ++++++++-- go.mod | 2 +- go.sum | 12 +++-- internal/cli/commands.go | 24 ++++++--- internal/cli/root.go | 1 + internal/rss/rss.go | 27 ++++++---- internal/rss/rss_test.go | 8 ++- internal/scanner/scanner.go | 89 ++++++++++++++++++++------------ internal/scanner/scanner_test.go | 53 ++++++++++++++----- internal/scraper/scraper.go | 15 ++++-- internal/scraper/scraper_test.go | 6 ++- 12 files changed, 193 insertions(+), 76 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index d6c29cf..855e854 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,3 +1,7 @@ +### Before speaking + +Never guess. Ground every claim in actual files, docs, or command output. If you're unsure about a dependency's size, run `go get` and check. If you're unsure about an API, read the source or docs. If you're unsure about behavior, write a test. State what you verified and how — don't speculate. + ### While developing #### Tests, tests, tests! diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index 29caf9b..0dbd930 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -57,11 +57,11 @@ func (c *cliOpts) run(t *testing.T, args []string, opts map[string]string) (stdo var cmdArgs []string var extraEnv []string - // DB path goes before the subcommand (persistent flag). + // Persistent flags / env vars. if c.mode == "flags" { - cmdArgs = append(cmdArgs, "--db", c.dbPath) + cmdArgs = append(cmdArgs, "--db", c.dbPath, "--unsafe-client") } else { - extraEnv = append(extraEnv, "BLOGWATCHER_DB="+c.dbPath) + extraEnv = append(extraEnv, "BLOGWATCHER_DB="+c.dbPath, "BLOGWATCHER_UNSAFE_CLIENT=true") } // Positional args first (includes the subcommand name). @@ -402,6 +402,28 @@ func TestE2E(t *testing.T) { } } +func TestSSRFProtection(t *testing.T) { + baseURL := startTestServer(t) + dbPath := filepath.Join(t.TempDir(), "test.db") + + // Add a blog pointing to the loopback test server WITHOUT --unsafe-client. + // The add command doesn't fetch, so it should succeed. + cmd := exec.CommandContext(context.Background(), binaryPath, + "--db", dbPath, "add", "test-blog", baseURL+"/go/", + "--feed-url", baseURL+"/go/feed.atom") + cmd.Env = append(os.Environ(), "NO_COLOR=1") + out, err := cmd.CombinedOutput() + require.NoError(t, err, "add should succeed: %s", string(out)) + + // Scan WITHOUT --unsafe-client — the safe client should block loopback and fail. + cmd = exec.CommandContext(context.Background(), binaryPath, + "--db", dbPath, "scan") + cmd.Env = append(os.Environ(), "NO_COLOR=1") + out, err = cmd.CombinedOutput() + require.Error(t, err, "scan should fail when SSRF protection blocks loopback") + require.Contains(t, string(out), "is not authorized", "expected SSRF error message") +} + func extractFirstID(t *testing.T, output string) string { t.Helper() re := regexp.MustCompile(`\[(\d+)\]`) diff --git a/go.mod b/go.mod index c2e1ea3..8b8cdd8 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/JulienTant/blogwatcher-cli go 1.26.1 require ( + github.com/DataDog/go-secure-sdk v0.0.7 github.com/Masterminds/squirrel v1.5.4 github.com/PuerkitoBio/goquery v1.12.0 github.com/fatih/color v1.19.0 @@ -44,7 +45,6 @@ require ( golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.35.0 // indirect golang.org/x/tools v0.43.0 // indirect - gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect modernc.org/libc v1.70.0 // indirect modernc.org/mathutil v1.7.1 // indirect diff --git a/go.sum b/go.sum index d244cad..83d1bb7 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/DataDog/go-secure-sdk v0.0.7 h1:FJYJPXXZmxOC0RHWzhsZ3PKtUaF+IgC6fDc/PbEoH6c= +github.com/DataDog/go-secure-sdk v0.0.7/go.mod h1:/fIdMM7LMp7KGouxN9Cnv3UP0NdfP+XZ5R8qyU3D8qk= github.com/Masterminds/squirrel v1.5.4 h1:uUcX/aBc8O7Fg9kaISIUsHXdKuqehiXAMQTYX8afzqM= github.com/Masterminds/squirrel v1.5.4/go.mod h1:NNaOrjSoIDfDA40n7sr2tPNZRfjzjA400rg+riTZj10= github.com/PuerkitoBio/goquery v1.12.0 h1:pAcL4g3WRXekcB9AU/y1mbKez2dbY2AajVhtkO8RIBo= @@ -21,6 +23,8 @@ github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPE github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA= github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -66,8 +70,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4= github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI= @@ -169,8 +173,8 @@ golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s= golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= diff --git a/internal/cli/commands.go b/internal/cli/commands.go index 3e4a504..e356cb6 100644 --- a/internal/cli/commands.go +++ b/internal/cli/commands.go @@ -3,21 +3,34 @@ package cli import ( "bufio" "fmt" + "net/http" "os" "strconv" "strings" "time" + "github.com/DataDog/go-secure-sdk/net/httpclient" "github.com/fatih/color" "github.com/spf13/cobra" "github.com/spf13/viper" "github.com/JulienTant/blogwatcher-cli/internal/controller" "github.com/JulienTant/blogwatcher-cli/internal/model" + "github.com/JulienTant/blogwatcher-cli/internal/rss" "github.com/JulienTant/blogwatcher-cli/internal/scanner" + "github.com/JulienTant/blogwatcher-cli/internal/scraper" "github.com/JulienTant/blogwatcher-cli/internal/storage" ) +const httpTimeout = 30 * time.Second + +func newHTTPClient() *http.Client { + if viper.GetBool("unsafe-client") { + return httpclient.UnSafe(httpclient.WithTimeout(httpTimeout)) + } + return httpclient.Safe(httpclient.WithTimeout(httpTimeout)) +} + func newAddCommand() *cobra.Command { cmd := &cobra.Command{ Use: "add ", @@ -148,8 +161,11 @@ func newScanCommand() *cobra.Command { } }() + client := newHTTPClient() + sc := scanner.NewScanner(rss.NewFetcher(client), scraper.NewScraper(client)) + if len(args) == 1 { - result, err := scanner.ScanBlogByName(cmd.Context(), db, args[0]) + result, err := sc.ScanBlogByName(cmd.Context(), db, args[0]) if err != nil { return err } @@ -173,7 +189,7 @@ func newScanCommand() *cobra.Command { if !silent { cprintf([]color.Attribute{color.FgCyan}, "Scanning %d blog(s)...\n\n", len(blogs)) } - results, err := scanner.ScanAllBlogs(cmd.Context(), db, workers) + results, err := sc.ScanAllBlogs(cmd.Context(), db, workers) if err != nil { return err } @@ -385,10 +401,6 @@ func printScanResult(result scanner.ScanResult) { statusColor = []color.Attribute{color.FgGreen} } cprintf([]color.Attribute{color.FgWhite, color.Bold}, " %s\n", result.BlogName) - if result.Error != "" { - cprintfErr(color.FgRed, " Error: %s\n", result.Error) - return - } if result.Source == "none" { cprintln([]color.Attribute{color.FgYellow}, " No feed or scraper configured") return diff --git a/internal/cli/root.go b/internal/cli/root.go index f822f7e..161840a 100644 --- a/internal/cli/root.go +++ b/internal/cli/root.go @@ -24,6 +24,7 @@ func NewRootCommand() *cobra.Command { rootCmd.SetVersionTemplate("{{.Version}}\n") rootCmd.PersistentFlags().String("db", "", "Path to the SQLite database file (default: ~/.blogwatcher-cli/blogwatcher-cli.db)") + rootCmd.PersistentFlags().Bool("unsafe-client", false, "Disable SSRF protection (allow requests to private/loopback IPs)") rootCmd.AddCommand(newAddCommand()) rootCmd.AddCommand(newRemoveCommand()) diff --git a/internal/rss/rss.go b/internal/rss/rss.go index 1107eb9..87021e8 100644 --- a/internal/rss/rss.go +++ b/internal/rss/rss.go @@ -28,13 +28,22 @@ func (e FeedParseError) Error() string { return e.Message } -func ParseFeed(ctx context.Context, feedURL string, timeout time.Duration) ([]FeedArticle, error) { - client := &http.Client{Timeout: timeout} +// Fetcher fetches and parses RSS/Atom feeds. +type Fetcher struct { + client *http.Client +} + +// NewFetcher creates a Fetcher with the given HTTP client. +func NewFetcher(client *http.Client) *Fetcher { + return &Fetcher{client: client} +} + +func (f *Fetcher) ParseFeed(ctx context.Context, feedURL string) ([]FeedArticle, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, feedURL, nil) if err != nil { return nil, FeedParseError{Message: fmt.Sprintf("failed to create request: %v", err)} } - response, err := client.Do(req) + response, err := f.client.Do(req) if err != nil { return nil, FeedParseError{Message: fmt.Sprintf("failed to fetch feed: %v", err)} } @@ -70,13 +79,12 @@ func ParseFeed(ctx context.Context, feedURL string, timeout time.Duration) ([]Fe return articles, nil } -func DiscoverFeedURL(ctx context.Context, blogURL string, timeout time.Duration) (string, error) { - client := &http.Client{Timeout: timeout} +func (f *Fetcher) DiscoverFeedURL(ctx context.Context, blogURL string) (string, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, blogURL, nil) if err != nil { return "", nil } - response, err := client.Do(req) + response, err := f.client.Do(req) if err != nil { return "", nil } @@ -138,7 +146,7 @@ func DiscoverFeedURL(ctx context.Context, blogURL string, timeout time.Duration) if resolved == "" { continue } - ok, err := isValidFeed(ctx, resolved, timeout) + ok, err := f.isValidFeed(ctx, resolved) if err == nil && ok { return resolved, nil } @@ -147,13 +155,12 @@ func DiscoverFeedURL(ctx context.Context, blogURL string, timeout time.Duration) return "", nil } -func isValidFeed(ctx context.Context, feedURL string, timeout time.Duration) (bool, error) { - client := &http.Client{Timeout: timeout} +func (f *Fetcher) isValidFeed(ctx context.Context, feedURL string) (bool, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, feedURL, nil) if err != nil { return false, err } - response, err := client.Do(req) + response, err := f.client.Do(req) if err != nil { return false, err } diff --git a/internal/rss/rss_test.go b/internal/rss/rss_test.go index 7fefca9..6dcd987 100644 --- a/internal/rss/rss_test.go +++ b/internal/rss/rss_test.go @@ -26,6 +26,10 @@ const sampleFeed = ` ` +func newTestFetcher() *Fetcher { + return NewFetcher(&http.Client{Timeout: 2 * time.Second}) +} + func TestParseFeed(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if _, writeErr := w.Write([]byte(sampleFeed)); writeErr != nil { @@ -35,7 +39,7 @@ func TestParseFeed(t *testing.T) { })) defer server.Close() - articles, err := ParseFeed(context.Background(), server.URL, 2*time.Second) + articles, err := newTestFetcher().ParseFeed(context.Background(), server.URL) require.NoError(t, err, "parse feed") require.Len(t, articles, 2) require.NotNil(t, articles[0].PublishedDate) @@ -58,7 +62,7 @@ func TestDiscoverFeedURL(t *testing.T) { server := httptest.NewServer(mux) defer server.Close() - feedURL, err := DiscoverFeedURL(context.Background(), server.URL, 2*time.Second) + feedURL, err := newTestFetcher().DiscoverFeedURL(context.Background(), server.URL) require.NoError(t, err, "discover feed") require.NotEmpty(t, feedURL, "expected feed url") } diff --git a/internal/scanner/scanner.go b/internal/scanner/scanner.go index 7997588..5271dc5 100644 --- a/internal/scanner/scanner.go +++ b/internal/scanner/scanner.go @@ -17,50 +17,65 @@ type ScanResult struct { NewArticles int TotalFound int Source string - Error string } -func ScanBlog(ctx context.Context, db *storage.Database, blog model.Blog) ScanResult { +// Scanner orchestrates blog scanning using a Fetcher and Scraper. +type Scanner struct { + fetcher *rss.Fetcher + scraper *scraper.Scraper +} + +// NewScanner creates a Scanner with the given fetcher and scraper. +func NewScanner(fetcher *rss.Fetcher, scraper *scraper.Scraper) *Scanner { + return &Scanner{fetcher: fetcher, scraper: scraper} +} + +func (s *Scanner) ScanBlog(ctx context.Context, db *storage.Database, blog model.Blog) (ScanResult, error) { var ( articles []model.Article source = "none" - errText string ) feedURL := blog.FeedURL if feedURL == "" { - if discovered, err := rss.DiscoverFeedURL(ctx, blog.URL, 30*time.Second); err == nil && discovered != "" { + discovered, err := s.fetcher.DiscoverFeedURL(ctx, blog.URL) + if err != nil { + return ScanResult{BlogName: blog.Name}, err + } + if discovered != "" { feedURL = discovered blog.FeedURL = discovered if err := db.UpdateBlog(ctx, blog); err != nil { - fmt.Fprintf(os.Stderr, "update blog: %v\n", err) + return ScanResult{BlogName: blog.Name}, err } } } if feedURL != "" { - feedArticles, err := rss.ParseFeed(ctx, feedURL, 30*time.Second) + feedArticles, err := s.fetcher.ParseFeed(ctx, feedURL) if err != nil { - errText = err.Error() + // If there's a scraper fallback, try it before giving up. + if blog.ScrapeSelector == "" { + return ScanResult{BlogName: blog.Name}, err + } + // Try scraper as fallback. + scrapedArticles, scrapeErr := s.scraper.ScrapeBlog(ctx, blog.URL, blog.ScrapeSelector) + if scrapeErr != nil { + return ScanResult{BlogName: blog.Name}, fmt.Errorf("RSS: %w; Scraper: %w", err, scrapeErr) + } + articles = convertScrapedArticles(blog.ID, scrapedArticles) + source = "scraper" } else { articles = convertFeedArticles(blog.ID, feedArticles) source = "rss" } - } - - if len(articles) == 0 && blog.ScrapeSelector != "" { - scrapedArticles, err := scraper.ScrapeBlog(ctx, blog.URL, blog.ScrapeSelector, 30*time.Second) + } else if blog.ScrapeSelector != "" { + scrapedArticles, err := s.scraper.ScrapeBlog(ctx, blog.URL, blog.ScrapeSelector) if err != nil { - if errText != "" { - errText = fmt.Sprintf("RSS: %s; Scraper: %s", errText, err.Error()) - } else { - errText = err.Error() - } - } else { - articles = convertScrapedArticles(blog.ID, scrapedArticles) - source = "scraper" - errText = "" + return ScanResult{BlogName: blog.Name}, err } + articles = convertScrapedArticles(blog.ID, scrapedArticles) + source = "scraper" } seenURLs := make(map[string]struct{}) @@ -80,7 +95,7 @@ func ScanBlog(ctx context.Context, db *storage.Database, blog model.Blog) ScanRe existing, err := db.GetExistingArticleURLs(ctx, urlList) if err != nil { - errText = err.Error() + return ScanResult{BlogName: blog.Name}, err } discoveredAt := time.Now() @@ -97,14 +112,13 @@ func ScanBlog(ctx context.Context, db *storage.Database, blog model.Blog) ScanRe if len(newArticles) > 0 { count, err := db.AddArticlesBulk(ctx, newArticles) if err != nil { - errText = err.Error() - } else { - newCount = count + return ScanResult{BlogName: blog.Name}, err } + newCount = count } if err := db.UpdateBlogLastScanned(ctx, blog.ID, time.Now()); err != nil { - fmt.Fprintf(os.Stderr, "update last scanned: %v\n", err) + return ScanResult{BlogName: blog.Name}, err } return ScanResult{ @@ -112,11 +126,10 @@ func ScanBlog(ctx context.Context, db *storage.Database, blog model.Blog) ScanRe NewArticles: newCount, TotalFound: len(seenURLs), Source: source, - Error: errText, - } + }, nil } -func ScanAllBlogs(ctx context.Context, db *storage.Database, workers int) ([]ScanResult, error) { +func (s *Scanner) ScanAllBlogs(ctx context.Context, db *storage.Database, workers int) ([]ScanResult, error) { blogs, err := db.ListBlogs(ctx) if err != nil { return nil, err @@ -124,7 +137,11 @@ func ScanAllBlogs(ctx context.Context, db *storage.Database, workers int) ([]Sca if workers <= 1 { results := make([]ScanResult, 0, len(blogs)) for _, blog := range blogs { - results = append(results, ScanBlog(ctx, db, blog)) + result, err := s.ScanBlog(ctx, db, blog) + if err != nil { + return nil, fmt.Errorf("scan %s: %w", blog.Name, err) + } + results = append(results, result) } return results, nil } @@ -150,7 +167,12 @@ func ScanAllBlogs(ctx context.Context, db *storage.Database, workers int) ([]Sca } }() for item := range jobs { - results[item.Index] = ScanBlog(ctx, workerDB, item.Blog) + result, err := s.ScanBlog(ctx, workerDB, item.Blog) + if err != nil { + errs <- fmt.Errorf("scan %s: %w", item.Blog.Name, err) + return + } + results[item.Index] = result } errs <- nil }() @@ -170,7 +192,7 @@ func ScanAllBlogs(ctx context.Context, db *storage.Database, workers int) ([]Sca return results, nil } -func ScanBlogByName(ctx context.Context, db *storage.Database, name string) (*ScanResult, error) { +func (s *Scanner) ScanBlogByName(ctx context.Context, db *storage.Database, name string) (*ScanResult, error) { blog, err := db.GetBlogByName(ctx, name) if err != nil { return nil, err @@ -178,7 +200,10 @@ func ScanBlogByName(ctx context.Context, db *storage.Database, name string) (*Sc if blog == nil { return nil, nil } - result := ScanBlog(ctx, db, *blog) + result, err := s.ScanBlog(ctx, db, *blog) + if err != nil { + return nil, err + } return &result, nil } diff --git a/internal/scanner/scanner_test.go b/internal/scanner/scanner_test.go index 014bba7..dcf0f56 100644 --- a/internal/scanner/scanner_test.go +++ b/internal/scanner/scanner_test.go @@ -2,6 +2,7 @@ package scanner import ( "context" + "fmt" "net/http" "net/http/httptest" "path/filepath" @@ -9,6 +10,8 @@ import ( "time" "github.com/JulienTant/blogwatcher-cli/internal/model" + "github.com/JulienTant/blogwatcher-cli/internal/rss" + "github.com/JulienTant/blogwatcher-cli/internal/scraper" "github.com/JulienTant/blogwatcher-cli/internal/storage" "github.com/stretchr/testify/require" ) @@ -28,6 +31,11 @@ const sampleFeed = ` ` +func newTestScanner() *Scanner { + client := &http.Client{Timeout: 2 * time.Second} + return NewScanner(rss.NewFetcher(client), scraper.NewScraper(client)) +} + func TestScanBlogRSS(t *testing.T) { ctx := context.Background() server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -44,7 +52,8 @@ func TestScanBlogRSS(t *testing.T) { blog, err := db.AddBlog(ctx, model.Blog{Name: "Test", URL: "https://example.com", FeedURL: server.URL}) require.NoError(t, err, "add blog") - result := ScanBlog(ctx, db, blog) + result, scanErr := newTestScanner().ScanBlog(ctx, db, blog) + require.NoError(t, scanErr) require.Equal(t, 2, result.NewArticles) require.Equal(t, "rss", result.Source) @@ -81,31 +90,46 @@ func TestScanBlogScraperFallback(t *testing.T) { blog, err := db.AddBlog(ctx, model.Blog{Name: "Test", URL: server.URL, FeedURL: server.URL + "/feed.xml", ScrapeSelector: "article h2 a"}) require.NoError(t, err, "add blog") - result := ScanBlog(ctx, db, blog) + result, scanErr := newTestScanner().ScanBlog(ctx, db, blog) + require.NoError(t, scanErr) require.Equal(t, "scraper", result.Source) require.Equal(t, 1, result.NewArticles) - require.Empty(t, result.Error) } func TestScanAllBlogsConcurrent(t *testing.T) { ctx := context.Background() - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if _, writeErr := w.Write([]byte(sampleFeed)); writeErr != nil { - http.Error(w, writeErr.Error(), http.StatusInternalServerError) - return - } - })) + + feedTemplate := ` +%s +Post 1https://%s.example.com/1 +Post 2https://%s.example.com/2 +` + + mux := http.NewServeMux() + for _, name := range []string{"a", "b"} { + feed := fmt.Sprintf(feedTemplate, name, name, name) + mux.HandleFunc("/"+name+"/feed", func(w http.ResponseWriter, r *http.Request) { + if _, writeErr := w.Write([]byte(feed)); writeErr != nil { + http.Error(w, writeErr.Error(), http.StatusInternalServerError) + } + }) + } + server := httptest.NewServer(mux) defer server.Close() db := openTestDB(t) defer func() { require.NoError(t, db.Close()) }() - for i, name := range []string{"TestA", "TestB"} { - _, err := db.AddBlog(ctx, model.Blog{Name: name, URL: "https://example.com/" + name, FeedURL: server.URL}) - require.NoError(t, err, "add blog %d", i) + for _, name := range []string{"a", "b"} { + _, err := db.AddBlog(ctx, model.Blog{ + Name: "Test-" + name, + URL: "https://" + name + ".example.com", + FeedURL: server.URL + "/" + name + "/feed", + }) + require.NoError(t, err, "add blog %s", name) } - results, err := ScanAllBlogs(ctx, db, 2) + results, err := newTestScanner().ScanAllBlogs(ctx, db, 2) require.NoError(t, err, "scan all blogs") require.Len(t, results, 2) } @@ -137,7 +161,8 @@ func TestScanBlogRespectsExistingArticles(t *testing.T) { _, err = db.AddArticle(ctx, model.Article{BlogID: blog.ID, Title: "First", URL: "https://example.com/1", DiscoveredDate: ptrTime(time.Now())}) require.NoError(t, err, "add article") - result := ScanBlog(ctx, db, blog) + result, scanErr := newTestScanner().ScanBlog(ctx, db, blog) + require.NoError(t, scanErr) require.Equal(t, 1, result.NewArticles) } diff --git a/internal/scraper/scraper.go b/internal/scraper/scraper.go index b56c1cb..8507bd6 100644 --- a/internal/scraper/scraper.go +++ b/internal/scraper/scraper.go @@ -27,13 +27,22 @@ func (e ScrapeError) Error() string { return e.Message } -func ScrapeBlog(ctx context.Context, blogURL string, selector string, timeout time.Duration) ([]ScrapedArticle, error) { - client := &http.Client{Timeout: timeout} +// Scraper scrapes HTML pages for article links. +type Scraper struct { + client *http.Client +} + +// NewScraper creates a Scraper with the given HTTP client. +func NewScraper(client *http.Client) *Scraper { + return &Scraper{client: client} +} + +func (s *Scraper) ScrapeBlog(ctx context.Context, blogURL string, selector string) ([]ScrapedArticle, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, blogURL, nil) if err != nil { return nil, ScrapeError{Message: fmt.Sprintf("failed to create request: %v", err)} } - response, err := client.Do(req) + response, err := s.client.Do(req) if err != nil { return nil, ScrapeError{Message: fmt.Sprintf("failed to fetch page: %v", err)} } diff --git a/internal/scraper/scraper_test.go b/internal/scraper/scraper_test.go index 6f288b4..52fe7e9 100644 --- a/internal/scraper/scraper_test.go +++ b/internal/scraper/scraper_test.go @@ -10,6 +10,10 @@ import ( "github.com/stretchr/testify/require" ) +func newTestScraper() *Scraper { + return NewScraper(&http.Client{Timeout: 2 * time.Second}) +} + func TestScrapeBlog(t *testing.T) { html := ` @@ -28,7 +32,7 @@ func TestScrapeBlog(t *testing.T) { })) defer server.Close() - articles, err := ScrapeBlog(context.Background(), server.URL, "article h2 a, .post", 2*time.Second) + articles, err := newTestScraper().ScrapeBlog(context.Background(), server.URL, "article h2 a, .post") require.NoError(t, err, "scrape blog") require.Len(t, articles, 2) require.NotEmpty(t, articles[0].URL) From 61dd7506f1438dbecebb334f7f9742b8f444c493 Mon Sep 17 00:00:00 2001 From: Julien Tant Date: Fri, 3 Apr 2026 15:10:33 -0700 Subject: [PATCH 2/7] Add category support for RSS/Atom feed articles Parse and store categories/tags from RSS and Atom feeds, with a new --category flag on the articles command to filter by category. Inspired by Hyaxia/blogwatcher#11 (by @weronikakombat). Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/cli/commands.go | 8 +- internal/controller/controller.go | 11 +- internal/controller/controller_test.go | 29 +++- internal/model/model.go | 1 + internal/rss/rss.go | 2 + internal/rss/rss_test.go | 34 ++++ internal/scanner/scanner.go | 1 + internal/scanner/scanner_test.go | 60 ++++++- internal/storage/database.go | 48 ++++-- internal/storage/database_test.go | 150 +++++++++++++++++- .../migrations/000002_add_categories.down.sql | 18 +++ .../migrations/000002_add_categories.up.sql | 1 + 12 files changed, 337 insertions(+), 26 deletions(-) create mode 100644 internal/storage/migrations/000002_add_categories.down.sql create mode 100644 internal/storage/migrations/000002_add_categories.up.sql diff --git a/internal/cli/commands.go b/internal/cli/commands.go index e356cb6..4c2fac3 100644 --- a/internal/cli/commands.go +++ b/internal/cli/commands.go @@ -237,7 +237,7 @@ func newArticlesCommand() *cobra.Command { fmt.Fprintf(os.Stderr, "close db: %v\n", err) } }() - articles, blogNames, err := controller.GetArticles(cmd.Context(), db, showAll, viper.GetString("blog")) + articles, blogNames, err := controller.GetArticles(cmd.Context(), db, showAll, viper.GetString("blog"), viper.GetString("category")) if err != nil { printError(err) return markError(err) @@ -265,6 +265,7 @@ func newArticlesCommand() *cobra.Command { cmd.Flags().BoolP("all", "a", false, "Show all articles (including read)") cmd.Flags().StringP("blog", "b", "", "Filter by blog name") + cmd.Flags().StringP("category", "c", "", "Filter by category") return cmd } @@ -320,7 +321,7 @@ func newReadAllCommand() *cobra.Command { } }() - articles, _, err := controller.GetArticles(cmd.Context(), db, false, blogName) + articles, _, err := controller.GetArticles(cmd.Context(), db, false, blogName, "") if err != nil { printError(err) return markError(err) @@ -425,6 +426,9 @@ func printArticle(article model.Article, blogName string) { if article.PublishedDate != nil { fmt.Printf(" Published: %s\n", article.PublishedDate.Format("2006-01-02")) } + if len(article.Categories) > 0 { + fmt.Printf(" Categories: %s\n", strings.Join(article.Categories, ", ")) + } fmt.Println() } diff --git a/internal/controller/controller.go b/internal/controller/controller.go index cea763a..62b4c53 100644 --- a/internal/controller/controller.go +++ b/internal/controller/controller.go @@ -66,7 +66,7 @@ func RemoveBlog(ctx context.Context, db *storage.Database, name string) error { return err } -func GetArticles(ctx context.Context, db *storage.Database, showAll bool, blogName string) ([]model.Article, map[int64]string, error) { +func GetArticles(ctx context.Context, db *storage.Database, showAll bool, blogName string, category string) ([]model.Article, map[int64]string, error) { var blogID *int64 if blogName != "" { blog, err := db.GetBlogByName(ctx, blogName) @@ -79,7 +79,12 @@ func GetArticles(ctx context.Context, db *storage.Database, showAll bool, blogNa blogID = &blog.ID } - articles, err := db.ListArticles(ctx, !showAll, blogID) + var categoryPtr *string + if category != "" { + categoryPtr = &category + } + + articles, err := db.ListArticles(ctx, !showAll, blogID, categoryPtr) if err != nil { return nil, nil, err } @@ -125,7 +130,7 @@ func MarkAllArticlesRead(ctx context.Context, db *storage.Database, blogName str blogID = &blog.ID } - articles, err := db.ListArticles(ctx, true, blogID) + articles, err := db.ListArticles(ctx, true, blogID, nil) if err != nil { return nil, err } diff --git a/internal/controller/controller_test.go b/internal/controller/controller_test.go index 7e6b5c0..5cea5a6 100644 --- a/internal/controller/controller_test.go +++ b/internal/controller/controller_test.go @@ -57,15 +57,40 @@ func TestGetArticlesFilters(t *testing.T) { _, err = db.AddArticle(ctx, model.Article{BlogID: blog.ID, Title: "Title", URL: "https://example.com/1"}) require.NoError(t, err, "add article") - articles, blogNames, err := GetArticles(ctx, db, false, "") + articles, blogNames, err := GetArticles(ctx, db, false, "", "") require.NoError(t, err, "get articles") require.Len(t, articles, 1) require.Equal(t, blog.Name, blogNames[blog.ID]) - _, _, err = GetArticles(ctx, db, false, "Missing") + _, _, err = GetArticles(ctx, db, false, "Missing", "") require.Error(t, err, "expected blog not found error") } +func TestGetArticlesFilterByCategory(t *testing.T) { + ctx := context.Background() + db := openTestDB(t) + defer func() { require.NoError(t, db.Close()) }() + + blog, err := AddBlog(ctx, db, "Test", "https://example.com", "", "") + require.NoError(t, err, "add blog") + + _, err = db.AddArticle(ctx, model.Article{BlogID: blog.ID, Title: "Go Post", URL: "https://example.com/1", Categories: []string{"Go", "Programming"}}) + require.NoError(t, err, "add article") + _, err = db.AddArticle(ctx, model.Article{BlogID: blog.ID, Title: "Rust Post", URL: "https://example.com/2", Categories: []string{"Rust"}}) + require.NoError(t, err, "add article") + + // Filter by Go + articles, _, err := GetArticles(ctx, db, false, "", "Go") + require.NoError(t, err, "get articles by category") + require.Len(t, articles, 1) + require.Equal(t, "Go Post", articles[0].Title) + + // No filter returns all + all, _, err := GetArticles(ctx, db, false, "", "") + require.NoError(t, err, "get all articles") + require.Len(t, all, 2) +} + func openTestDB(t *testing.T) *storage.Database { t.Helper() path := filepath.Join(t.TempDir(), "blogwatcher-cli.db") diff --git a/internal/model/model.go b/internal/model/model.go index dd0a2d8..07d0fee 100644 --- a/internal/model/model.go +++ b/internal/model/model.go @@ -19,4 +19,5 @@ type Article struct { PublishedDate *time.Time DiscoveredDate *time.Time IsRead bool + Categories []string } diff --git a/internal/rss/rss.go b/internal/rss/rss.go index 87021e8..95dd593 100644 --- a/internal/rss/rss.go +++ b/internal/rss/rss.go @@ -18,6 +18,7 @@ type FeedArticle struct { Title string URL string PublishedDate *time.Time + Categories []string } type FeedParseError struct { @@ -73,6 +74,7 @@ func (f *Fetcher) ParseFeed(ctx context.Context, feedURL string) ([]FeedArticle, Title: title, URL: link, PublishedDate: pickPublishedDate(item), + Categories: item.Categories, }) } diff --git a/internal/rss/rss_test.go b/internal/rss/rss_test.go index 6dcd987..a752b96 100644 --- a/internal/rss/rss_test.go +++ b/internal/rss/rss_test.go @@ -45,6 +45,40 @@ func TestParseFeed(t *testing.T) { require.NotNil(t, articles[0].PublishedDate) } +func TestParseFeedWithCategories(t *testing.T) { + feedWithCategories := ` + + +Example Feed + +Tagged Post +https://example.com/tagged +AI +Machine Learning + + +Plain Post +https://example.com/plain + + +` + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if _, writeErr := w.Write([]byte(feedWithCategories)); writeErr != nil { + http.Error(w, writeErr.Error(), http.StatusInternalServerError) + return + } + })) + defer server.Close() + + articles, err := newTestFetcher().ParseFeed(context.Background(), server.URL) + require.NoError(t, err, "parse feed") + require.Len(t, articles, 2) + + require.Equal(t, []string{"AI", "Machine Learning"}, articles[0].Categories) + require.Nil(t, articles[1].Categories) +} + func TestDiscoverFeedURL(t *testing.T) { mux := http.NewServeMux() mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { diff --git a/internal/scanner/scanner.go b/internal/scanner/scanner.go index 5271dc5..1cae9ac 100644 --- a/internal/scanner/scanner.go +++ b/internal/scanner/scanner.go @@ -216,6 +216,7 @@ func convertFeedArticles(blogID int64, articles []rss.FeedArticle) []model.Artic URL: article.URL, PublishedDate: article.PublishedDate, IsRead: false, + Categories: article.Categories, }) } return result diff --git a/internal/scanner/scanner_test.go b/internal/scanner/scanner_test.go index dcf0f56..374a4cf 100644 --- a/internal/scanner/scanner_test.go +++ b/internal/scanner/scanner_test.go @@ -57,7 +57,7 @@ func TestScanBlogRSS(t *testing.T) { require.Equal(t, 2, result.NewArticles) require.Equal(t, "rss", result.Source) - articles, err := db.ListArticles(ctx, false, nil) + articles, err := db.ListArticles(ctx, false, nil, nil) require.NoError(t, err, "list articles") require.Len(t, articles, 2) } @@ -166,6 +166,64 @@ func TestScanBlogRespectsExistingArticles(t *testing.T) { require.Equal(t, 1, result.NewArticles) } +func TestScanBlogRSSWithCategories(t *testing.T) { + ctx := context.Background() + feedWithCategories := ` + + +Example Feed + +First +https://example.com/1 +Go +Programming + + +Second +https://example.com/2 + + +` + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if _, writeErr := w.Write([]byte(feedWithCategories)); writeErr != nil { + http.Error(w, writeErr.Error(), http.StatusInternalServerError) + return + } + })) + defer server.Close() + + db := openTestDB(t) + defer func() { require.NoError(t, db.Close()) }() + + blog, err := db.AddBlog(ctx, model.Blog{Name: "Test", URL: "https://example.com", FeedURL: server.URL}) + require.NoError(t, err, "add blog") + + result, scanErr := newTestScanner().ScanBlog(ctx, db, blog) + require.NoError(t, scanErr) + require.Equal(t, 2, result.NewArticles) + + articles, err := db.ListArticles(ctx, false, nil, nil) + require.NoError(t, err, "list articles") + require.Len(t, articles, 2) + + // Find the article with categories + var withCat *model.Article + var withoutCat *model.Article + for i := range articles { + if articles[i].Title == "First" { + withCat = &articles[i] + } else { + withoutCat = &articles[i] + } + } + require.NotNil(t, withCat) + require.Equal(t, []string{"Go", "Programming"}, withCat.Categories) + + require.NotNil(t, withoutCat) + require.Nil(t, withoutCat.Categories) +} + func ptrTime(value time.Time) *time.Time { return &value } diff --git a/internal/storage/database.go b/internal/storage/database.go index a2c2ec0..5a01325 100644 --- a/internal/storage/database.go +++ b/internal/storage/database.go @@ -7,6 +7,7 @@ import ( "fmt" "os" "path/filepath" + "strings" "time" sq "github.com/Masterminds/squirrel" @@ -219,8 +220,8 @@ func (db *Database) RemoveBlog(ctx context.Context, id int64) (bool, error) { func (db *Database) AddArticle(ctx context.Context, article model.Article) (model.Article, error) { result, err := sq.Insert("articles"). - Columns("blog_id", "title", "url", "published_date", "discovered_date", "is_read"). - Values(article.BlogID, article.Title, article.URL, formatTimePtr(article.PublishedDate), formatTimePtr(article.DiscoveredDate), article.IsRead). + Columns("blog_id", "title", "url", "published_date", "discovered_date", "is_read", "categories"). + Values(article.BlogID, article.Title, article.URL, formatTimePtr(article.PublishedDate), formatTimePtr(article.DiscoveredDate), article.IsRead, categoriesToString(article.Categories)). RunWith(db.conn). ExecContext(ctx) if err != nil { @@ -244,7 +245,7 @@ func (db *Database) AddArticlesBulk(ctx context.Context, articles []model.Articl } insert := sq.Insert("articles"). - Columns("blog_id", "title", "url", "published_date", "discovered_date", "is_read") + Columns("blog_id", "title", "url", "published_date", "discovered_date", "is_read", "categories") for _, article := range articles { insert = insert.Values( article.BlogID, @@ -253,6 +254,7 @@ func (db *Database) AddArticlesBulk(ctx context.Context, articles []model.Articl formatTimePtr(article.PublishedDate), formatTimePtr(article.DiscoveredDate), article.IsRead, + categoriesToString(article.Categories), ) } @@ -271,7 +273,7 @@ func (db *Database) AddArticlesBulk(ctx context.Context, articles []model.Articl } func (db *Database) GetArticle(ctx context.Context, id int64) (*model.Article, error) { - row := sq.Select("id", "blog_id", "title", "url", "published_date", "discovered_date", "is_read"). + row := sq.Select("id", "blog_id", "title", "url", "published_date", "discovered_date", "is_read", "categories"). From("articles"). Where(sq.Eq{"id": id}). RunWith(db.conn). @@ -280,7 +282,7 @@ func (db *Database) GetArticle(ctx context.Context, id int64) (*model.Article, e } func (db *Database) GetArticleByURL(ctx context.Context, url string) (*model.Article, error) { - row := sq.Select("id", "blog_id", "title", "url", "published_date", "discovered_date", "is_read"). + row := sq.Select("id", "blog_id", "title", "url", "published_date", "discovered_date", "is_read", "categories"). From("articles"). Where(sq.Eq{"url": url}). RunWith(db.conn). @@ -347,8 +349,8 @@ func (db *Database) GetExistingArticleURLs(ctx context.Context, urls []string) ( return result, nil } -func (db *Database) ListArticles(ctx context.Context, unreadOnly bool, blogID *int64) ([]model.Article, error) { - query := sq.Select("id", "blog_id", "title", "url", "published_date", "discovered_date", "is_read"). +func (db *Database) ListArticles(ctx context.Context, unreadOnly bool, blogID *int64, category *string) ([]model.Article, error) { + query := sq.Select("id", "blog_id", "title", "url", "published_date", "discovered_date", "is_read", "categories"). From("articles"). OrderBy("discovered_date DESC") @@ -358,6 +360,9 @@ func (db *Database) ListArticles(ctx context.Context, unreadOnly bool, blogID *i if blogID != nil { query = query.Where(sq.Eq{"blog_id": *blogID}) } + if category != nil && *category != "" { + query = query.Where(sq.Like{"categories": "%" + *category + "%"}) + } rows, err := query.RunWith(db.conn).QueryContext(ctx) if err != nil { @@ -456,8 +461,9 @@ func scanArticle(scanner interface{ Scan(dest ...any) error }) (*model.Article, publishedDate sql.NullString discovered sql.NullString isRead bool + categories sql.NullString ) - if err := scanner.Scan(&id, &blogID, &title, &url, &publishedDate, &discovered, &isRead); err != nil { + if err := scanner.Scan(&id, &blogID, &title, &url, &publishedDate, &discovered, &isRead, &categories); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil } @@ -465,11 +471,12 @@ func scanArticle(scanner interface{ Scan(dest ...any) error }) (*model.Article, } article := &model.Article{ - ID: id, - BlogID: blogID, - Title: title, - URL: url, - IsRead: isRead, + ID: id, + BlogID: blogID, + Title: title, + URL: url, + IsRead: isRead, + Categories: categoriesFromString(categories), } if publishedDate.Valid { if parsed, err := parseTime(publishedDate.String); err == nil { @@ -512,3 +519,18 @@ func nullIfEmpty(value string) *string { } return &value } + +func categoriesToString(categories []string) *string { + if len(categories) == 0 { + return nil + } + s := strings.Join(categories, ",") + return &s +} + +func categoriesFromString(s sql.NullString) []string { + if !s.Valid || s.String == "" { + return nil + } + return strings.Split(s.String, ",") +} diff --git a/internal/storage/database_test.go b/internal/storage/database_test.go index 29f03e4..7aec4d1 100644 --- a/internal/storage/database_test.go +++ b/internal/storage/database_test.go @@ -39,7 +39,7 @@ func TestDatabaseCreatesFileAndCRUD(t *testing.T) { require.NoError(t, err, "add articles bulk") require.Equal(t, 2, count) - list, err := db.ListArticles(ctx, false, nil) + list, err := db.ListArticles(ctx, false, nil, nil) require.NoError(t, err, "list articles") require.Len(t, list, 2) @@ -190,17 +190,17 @@ func TestListArticlesFiltersAndOrdering(t *testing.T) { _, err = db.MarkArticleRead(ctx, first.ID) require.NoError(t, err, "mark read") - all, err := db.ListArticles(ctx, false, nil) + all, err := db.ListArticles(ctx, false, nil, nil) require.NoError(t, err, "list articles") require.Len(t, all, 3) require.Equal(t, second.ID, all[0].ID, "expected newest article first") - unread, err := db.ListArticles(ctx, true, nil) + unread, err := db.ListArticles(ctx, true, nil, nil) require.NoError(t, err, "list unread") require.Len(t, unread, 2) blogID := blogB.ID - filtered, err := db.ListArticles(ctx, false, &blogID) + filtered, err := db.ListArticles(ctx, false, &blogID, nil) require.NoError(t, err, "list by blog") require.Len(t, filtered, 1) require.Equal(t, blogB.ID, filtered[0].BlogID) @@ -231,7 +231,7 @@ func TestBulkInsertDuplicateRollbackAndEmpty(t *testing.T) { _, err = db.AddArticlesBulk(ctx, dupArticles) require.Error(t, err, "expected bulk insert to fail on duplicate url") - articles, err := db.ListArticles(ctx, false, nil) + articles, err := db.ListArticles(ctx, false, nil, nil) require.NoError(t, err, "list articles") require.Len(t, articles, 1, "expected rollback on duplicate") } @@ -269,3 +269,143 @@ func TestLookupHelpers(t *testing.T) { require.NoError(t, err) require.False(t, exists) } + +func TestArticleCategoriesRoundTrip(t *testing.T) { + ctx := context.Background() + db := openTestDB(t) + defer func() { require.NoError(t, db.Close()) }() + + blog, err := db.AddBlog(ctx, model.Blog{Name: "Test", URL: "https://example.com"}) + require.NoError(t, err, "add blog") + + // Article with categories + article, err := db.AddArticle(ctx, model.Article{ + BlogID: blog.ID, + Title: "Categorized", + URL: "https://example.com/cat", + Categories: []string{"Go", "Programming"}, + }) + require.NoError(t, err, "add article with categories") + + fetched, err := db.GetArticle(ctx, article.ID) + require.NoError(t, err, "get article") + require.NotNil(t, fetched) + require.Equal(t, []string{"Go", "Programming"}, fetched.Categories) + + // Article without categories + articleNoCat, err := db.AddArticle(ctx, model.Article{ + BlogID: blog.ID, + Title: "No Category", + URL: "https://example.com/nocat", + }) + require.NoError(t, err, "add article without categories") + + fetchedNoCat, err := db.GetArticle(ctx, articleNoCat.ID) + require.NoError(t, err, "get article") + require.NotNil(t, fetchedNoCat) + require.Nil(t, fetchedNoCat.Categories) +} + +func TestListArticlesFilterByCategory(t *testing.T) { + ctx := context.Background() + db := openTestDB(t) + defer func() { require.NoError(t, db.Close()) }() + + blog, err := db.AddBlog(ctx, model.Blog{Name: "Test", URL: "https://example.com"}) + require.NoError(t, err, "add blog") + + t1 := time.Date(2024, 1, 1, 10, 0, 0, 0, time.UTC) + t2 := time.Date(2024, 1, 1, 11, 0, 0, 0, time.UTC) + t3 := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC) + + _, err = db.AddArticle(ctx, model.Article{ + BlogID: blog.ID, + Title: "Go Article", + URL: "https://example.com/go", + DiscoveredDate: &t1, + Categories: []string{"Go", "Programming"}, + }) + require.NoError(t, err, "add go article") + + _, err = db.AddArticle(ctx, model.Article{ + BlogID: blog.ID, + Title: "Rust Article", + URL: "https://example.com/rust", + DiscoveredDate: &t2, + Categories: []string{"Rust", "Programming"}, + }) + require.NoError(t, err, "add rust article") + + _, err = db.AddArticle(ctx, model.Article{ + BlogID: blog.ID, + Title: "No Category", + URL: "https://example.com/nocat", + DiscoveredDate: &t3, + }) + require.NoError(t, err, "add no-cat article") + + // Filter by "Go" - should return only the Go article + cat := "Go" + goArticles, err := db.ListArticles(ctx, false, nil, &cat) + require.NoError(t, err, "list by category Go") + require.Len(t, goArticles, 1) + require.Equal(t, "Go Article", goArticles[0].Title) + + // Filter by "Programming" - should return both categorized articles + cat = "Programming" + progArticles, err := db.ListArticles(ctx, false, nil, &cat) + require.NoError(t, err, "list by category Programming") + require.Len(t, progArticles, 2) + + // No filter - should return all 3 + all, err := db.ListArticles(ctx, false, nil, nil) + require.NoError(t, err, "list all") + require.Len(t, all, 3) + + // Empty string category should return all + empty := "" + allEmpty, err := db.ListArticles(ctx, false, nil, &empty) + require.NoError(t, err, "list with empty category") + require.Len(t, allEmpty, 3) +} + +func TestBulkInsertWithCategories(t *testing.T) { + ctx := context.Background() + db := openTestDB(t) + defer func() { require.NoError(t, db.Close()) }() + + blog, err := db.AddBlog(ctx, model.Blog{Name: "Test", URL: "https://example.com"}) + require.NoError(t, err, "add blog") + + articles := []model.Article{ + {BlogID: blog.ID, Title: "One", URL: "https://example.com/1", Categories: []string{"AI", "ML"}}, + {BlogID: blog.ID, Title: "Two", URL: "https://example.com/2"}, + } + count, err := db.AddArticlesBulk(ctx, articles) + require.NoError(t, err, "bulk insert") + require.Equal(t, 2, count) + + list, err := db.ListArticles(ctx, false, nil, nil) + require.NoError(t, err, "list articles") + require.Len(t, list, 2) + + // Find the one with categories (order is by discovered_date DESC, both nil so order may vary) + var withCat *model.Article + for i := range list { + if list[i].Title == "One" { + withCat = &list[i] + break + } + } + require.NotNil(t, withCat, "expected article with categories") + require.Equal(t, []string{"AI", "ML"}, withCat.Categories) +} + +func openTestDB(t *testing.T) *Database { + t.Helper() + ctx := context.Background() + path := filepath.Join(t.TempDir(), "blogwatcher-cli.db") + db, err := OpenDatabase(ctx, path) + require.NoError(t, err, "open database") + return db +} diff --git a/internal/storage/migrations/000002_add_categories.down.sql b/internal/storage/migrations/000002_add_categories.down.sql new file mode 100644 index 0000000..0316e9b --- /dev/null +++ b/internal/storage/migrations/000002_add_categories.down.sql @@ -0,0 +1,18 @@ +-- SQLite does not support DROP COLUMN prior to 3.35.0. +-- Recreate the table without the categories column. +CREATE TABLE articles_backup ( + id INTEGER PRIMARY KEY, + blog_id INTEGER NOT NULL, + title TEXT NOT NULL, + url TEXT NOT NULL UNIQUE, + published_date TIMESTAMP, + discovered_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + is_read BOOLEAN DEFAULT FALSE, + FOREIGN KEY (blog_id) REFERENCES blogs(id) +); + +INSERT INTO articles_backup SELECT id, blog_id, title, url, published_date, discovered_date, is_read FROM articles; + +DROP TABLE articles; + +ALTER TABLE articles_backup RENAME TO articles; diff --git a/internal/storage/migrations/000002_add_categories.up.sql b/internal/storage/migrations/000002_add_categories.up.sql new file mode 100644 index 0000000..d3f8dbc --- /dev/null +++ b/internal/storage/migrations/000002_add_categories.up.sql @@ -0,0 +1 @@ +ALTER TABLE articles ADD COLUMN categories TEXT; From 5795b21b0e800959c5328ee50728966624ec9e8c Mon Sep 17 00:00:00 2001 From: Julien Tant Date: Fri, 3 Apr 2026 17:24:12 -0700 Subject: [PATCH 3/7] Fix review findings and add category e2e tests - Only persist discovered feed URL after ParseFeed succeeds - Fix category LIKE query to use delimited matching (prevents "AI" matching "FAIR") - Replace scanner worker pool with errgroup to prevent deadlock - Add categories to GitHub RSS fixture and e2e test for --category filter Co-Authored-By: Claude Opus 4.6 (1M context) --- e2e/e2e_test.go | 4 ++ e2e/expected/11_articles_unread.txt | 3 ++ e2e/expected/12b_articles_filter_category.txt | 13 +++++ e2e/expected/20_articles_all.txt | 3 ++ e2e/testdata/github_blog.rss | 5 ++ go.mod | 1 + internal/scanner/scanner.go | 53 +++++++++++-------- internal/storage/database.go | 9 +++- 8 files changed, 68 insertions(+), 23 deletions(-) create mode 100644 e2e/expected/12b_articles_filter_category.txt diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index 0dbd930..455bbc9 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -355,6 +355,10 @@ func TestE2E(t *testing.T) { out = c.ok(t, []string{"articles"}, map[string]string{"blog": "go-blog"}) checkOutput(t, "12_articles_filter_blog", out, baseURL) + // ── Articles filtered by category ── + out = c.ok(t, []string{"articles"}, map[string]string{"category": "Engineering"}) + checkOutput(t, "12b_articles_filter_category", out, baseURL) + // ── Read / unread cycle ── articlesOut := c.ok(t, []string{"articles"}, nil) id := extractFirstID(t, articlesOut) diff --git a/e2e/expected/11_articles_unread.txt b/e2e/expected/11_articles_unread.txt index e8ec68a..af246cb 100644 --- a/e2e/expected/11_articles_unread.txt +++ b/e2e/expected/11_articles_unread.txt @@ -12,11 +12,13 @@ Unread articles (11): Blog: github-blog URL: https://github.blog/news-insights/github-copilot-the-agent-awakens/ Published: 2026-04-01 + Categories: AI, Copilot [ID] [new] How we built the new GitHub Issues Blog: github-blog URL: https://github.blog/engineering/how-we-built-the-new-github-issues/ Published: 2026-03-28 + Categories: Engineering [ID] [new] More powerful Go execution traces Blog: go-blog @@ -36,6 +38,7 @@ Unread articles (11): Blog: github-blog URL: https://github.blog/engineering/the-uphill-climb-of-making-diff-lines-performant/ Published: 2026-04-03 + Categories: Engineering, Performance [ID] [new] Type Construction and Cycle Detection Blog: go-blog diff --git a/e2e/expected/12b_articles_filter_category.txt b/e2e/expected/12b_articles_filter_category.txt new file mode 100644 index 0000000..85d32f9 --- /dev/null +++ b/e2e/expected/12b_articles_filter_category.txt @@ -0,0 +1,13 @@ +Unread articles (2): + + [ID] [new] How we built the new GitHub Issues + Blog: github-blog + URL: https://github.blog/engineering/how-we-built-the-new-github-issues/ + Published: 2026-03-28 + Categories: Engineering + + [ID] [new] The uphill climb of making diff lines performant + Blog: github-blog + URL: https://github.blog/engineering/the-uphill-climb-of-making-diff-lines-performant/ + Published: 2026-04-03 + Categories: Engineering, Performance diff --git a/e2e/expected/20_articles_all.txt b/e2e/expected/20_articles_all.txt index b0b9f6c..1507d43 100644 --- a/e2e/expected/20_articles_all.txt +++ b/e2e/expected/20_articles_all.txt @@ -12,11 +12,13 @@ All articles (11): Blog: github-blog URL: https://github.blog/news-insights/github-copilot-the-agent-awakens/ Published: 2026-04-01 + Categories: AI, Copilot [ID] [read] How we built the new GitHub Issues Blog: github-blog URL: https://github.blog/engineering/how-we-built-the-new-github-issues/ Published: 2026-03-28 + Categories: Engineering [ID] [read] More powerful Go execution traces Blog: go-blog @@ -36,6 +38,7 @@ All articles (11): Blog: github-blog URL: https://github.blog/engineering/the-uphill-climb-of-making-diff-lines-performant/ Published: 2026-04-03 + Categories: Engineering, Performance [ID] [read] Type Construction and Cycle Detection Blog: go-blog diff --git a/e2e/testdata/github_blog.rss b/e2e/testdata/github_blog.rss index 3977029..02a4ab4 100644 --- a/e2e/testdata/github_blog.rss +++ b/e2e/testdata/github_blog.rss @@ -15,6 +15,8 @@ https://github.blog/engineering/the-uphill-climb-of-making-diff-lines-performant/ Fri, 03 Apr 2026 16:00:00 +0000 + Engineering + Performance @@ -22,6 +24,8 @@ https://github.blog/news-insights/github-copilot-the-agent-awakens/ Wed, 01 Apr 2026 12:00:00 +0000 + AI + Copilot @@ -29,6 +33,7 @@ https://github.blog/engineering/how-we-built-the-new-github-issues/ Mon, 28 Mar 2026 10:00:00 +0000 + Engineering diff --git a/go.mod b/go.mod index 8b8cdd8..6738bbf 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/spf13/cobra v1.10.2 github.com/spf13/viper v1.21.0 github.com/stretchr/testify v1.11.1 + golang.org/x/sync v0.20.0 modernc.org/sqlite v1.48.1 ) diff --git a/internal/scanner/scanner.go b/internal/scanner/scanner.go index 1cae9ac..accdd99 100644 --- a/internal/scanner/scanner.go +++ b/internal/scanner/scanner.go @@ -6,6 +6,8 @@ import ( "os" "time" + "golang.org/x/sync/errgroup" + "github.com/JulienTant/blogwatcher-cli/internal/model" "github.com/JulienTant/blogwatcher-cli/internal/rss" "github.com/JulienTant/blogwatcher-cli/internal/scraper" @@ -44,10 +46,6 @@ func (s *Scanner) ScanBlog(ctx context.Context, db *storage.Database, blog model } if discovered != "" { feedURL = discovered - blog.FeedURL = discovered - if err := db.UpdateBlog(ctx, blog); err != nil { - return ScanResult{BlogName: blog.Name}, err - } } } @@ -68,6 +66,13 @@ func (s *Scanner) ScanBlog(ctx context.Context, db *storage.Database, blog model } else { articles = convertFeedArticles(blog.ID, feedArticles) source = "rss" + // Persist discovered feed URL only after successful parse. + if blog.FeedURL != feedURL { + blog.FeedURL = feedURL + if err := db.UpdateBlog(ctx, blog); err != nil { + return ScanResult{BlogName: blog.Name}, err + } + } } } else if blog.ScrapeSelector != "" { scrapedArticles, err := s.scraper.ScrapeBlog(ctx, blog.URL, blog.ScrapeSelector) @@ -152,14 +157,14 @@ func (s *Scanner) ScanAllBlogs(ctx context.Context, db *storage.Database, worker } jobs := make(chan job) results := make([]ScanResult, len(blogs)) - errs := make(chan error, workers) + + g, gctx := errgroup.WithContext(ctx) for i := 0; i < workers; i++ { - go func() { - workerDB, err := storage.OpenDatabase(ctx, db.Path()) + g.Go(func() error { + workerDB, err := storage.OpenDatabase(gctx, db.Path()) if err != nil { - errs <- err - return + return err } defer func() { if err := workerDB.Close(); err != nil { @@ -167,26 +172,30 @@ func (s *Scanner) ScanAllBlogs(ctx context.Context, db *storage.Database, worker } }() for item := range jobs { - result, err := s.ScanBlog(ctx, workerDB, item.Blog) + result, err := s.ScanBlog(gctx, workerDB, item.Blog) if err != nil { - errs <- fmt.Errorf("scan %s: %w", item.Blog.Name, err) - return + return fmt.Errorf("scan %s: %w", item.Blog.Name, err) } results[item.Index] = result } - errs <- nil - }() - } - - for index, blog := range blogs { - jobs <- job{Index: index, Blog: blog} + return nil + }) } - close(jobs) - for i := 0; i < workers; i++ { - if err := <-errs; err != nil { - return nil, err + g.Go(func() error { + defer close(jobs) + for index, blog := range blogs { + select { + case jobs <- job{Index: index, Blog: blog}: + case <-gctx.Done(): + return gctx.Err() + } } + return nil + }) + + if err := g.Wait(); err != nil { + return nil, err } return results, nil diff --git a/internal/storage/database.go b/internal/storage/database.go index 5a01325..d4177ed 100644 --- a/internal/storage/database.go +++ b/internal/storage/database.go @@ -361,7 +361,14 @@ func (db *Database) ListArticles(ctx context.Context, unreadOnly bool, blogID *i query = query.Where(sq.Eq{"blog_id": *blogID}) } if category != nil && *category != "" { - query = query.Where(sq.Like{"categories": "%" + *category + "%"}) + // Categories are stored as comma-separated values. Use delimited + // matching to avoid partial matches (e.g. "AI" matching "FAIR"). + query = query.Where(sq.Or{ + sq.Eq{"categories": *category}, // exact single category + sq.Like{"categories": *category + ",%"}, // starts with + sq.Like{"categories": "%," + *category}, // ends with + sq.Like{"categories": "%," + *category + ",%"}, // middle + }) } rows, err := query.RunWith(db.conn).QueryContext(ctx) From 0b0c115da0870908993b69bb59149d251264a523 Mon Sep 17 00:00:00 2001 From: Julien Tant Date: Fri, 3 Apr 2026 17:25:09 -0700 Subject: [PATCH 4/7] Fix test compatibility with Fetcher struct API Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/rss/rss_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/rss/rss_test.go b/internal/rss/rss_test.go index f7c74a1..dbdf453 100644 --- a/internal/rss/rss_test.go +++ b/internal/rss/rss_test.go @@ -114,7 +114,7 @@ func TestDiscoverFeedURL_XMLContentType(t *testing.T) { server := httptest.NewServer(mux) defer server.Close() - feedURL, err := DiscoverFeedURL(context.Background(), server.URL+"/tag/AI/feed/", 2*time.Second) + feedURL, err := newTestFetcher().DiscoverFeedURL(context.Background(), server.URL+"/tag/AI/feed/") require.NoError(t, err) require.Equal(t, server.URL+"/tag/AI/feed/", feedURL, "should return URL directly for feed content-type") } @@ -138,7 +138,7 @@ func TestDiscoverFeedURL_RelSelf(t *testing.T) { server := httptest.NewServer(mux) defer server.Close() - feedURL, err := DiscoverFeedURL(context.Background(), server.URL, 2*time.Second) + feedURL, err := newTestFetcher().DiscoverFeedURL(context.Background(), server.URL) require.NoError(t, err) require.Equal(t, server.URL+"/my-feed.xml", feedURL, "should discover feed from rel=self link") } From 10c9812938f60ea7110229b33765fb6ca3e9db07 Mon Sep 17 00:00:00 2001 From: Julien Tant Date: Fri, 3 Apr 2026 17:28:36 -0700 Subject: [PATCH 5/7] Store categories as JSON array instead of comma-separated string - Use json.Marshal/Unmarshal for categories serialization - Query with json_each() for exact element matching instead of LIKE Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/storage/database.go | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/internal/storage/database.go b/internal/storage/database.go index d4177ed..ae43def 100644 --- a/internal/storage/database.go +++ b/internal/storage/database.go @@ -3,11 +3,11 @@ package storage import ( "context" "database/sql" + "encoding/json" "errors" "fmt" "os" "path/filepath" - "strings" "time" sq "github.com/Masterminds/squirrel" @@ -221,7 +221,7 @@ func (db *Database) RemoveBlog(ctx context.Context, id int64) (bool, error) { func (db *Database) AddArticle(ctx context.Context, article model.Article) (model.Article, error) { result, err := sq.Insert("articles"). Columns("blog_id", "title", "url", "published_date", "discovered_date", "is_read", "categories"). - Values(article.BlogID, article.Title, article.URL, formatTimePtr(article.PublishedDate), formatTimePtr(article.DiscoveredDate), article.IsRead, categoriesToString(article.Categories)). + Values(article.BlogID, article.Title, article.URL, formatTimePtr(article.PublishedDate), formatTimePtr(article.DiscoveredDate), article.IsRead, categoriesToJSON(article.Categories)). RunWith(db.conn). ExecContext(ctx) if err != nil { @@ -254,7 +254,7 @@ func (db *Database) AddArticlesBulk(ctx context.Context, articles []model.Articl formatTimePtr(article.PublishedDate), formatTimePtr(article.DiscoveredDate), article.IsRead, - categoriesToString(article.Categories), + categoriesToJSON(article.Categories), ) } @@ -361,14 +361,9 @@ func (db *Database) ListArticles(ctx context.Context, unreadOnly bool, blogID *i query = query.Where(sq.Eq{"blog_id": *blogID}) } if category != nil && *category != "" { - // Categories are stored as comma-separated values. Use delimited - // matching to avoid partial matches (e.g. "AI" matching "FAIR"). - query = query.Where(sq.Or{ - sq.Eq{"categories": *category}, // exact single category - sq.Like{"categories": *category + ",%"}, // starts with - sq.Like{"categories": "%," + *category}, // ends with - sq.Like{"categories": "%," + *category + ",%"}, // middle - }) + // Categories are stored as a JSON string array. Use json_each() + // for exact element matching. + query = query.Where("EXISTS (SELECT 1 FROM json_each(categories) WHERE json_each.value = ?)", *category) } rows, err := query.RunWith(db.conn).QueryContext(ctx) @@ -483,7 +478,7 @@ func scanArticle(scanner interface{ Scan(dest ...any) error }) (*model.Article, Title: title, URL: url, IsRead: isRead, - Categories: categoriesFromString(categories), + Categories: categoriesFromJSON(categories), } if publishedDate.Valid { if parsed, err := parseTime(publishedDate.String); err == nil { @@ -527,17 +522,25 @@ func nullIfEmpty(value string) *string { return &value } -func categoriesToString(categories []string) *string { +func categoriesToJSON(categories []string) *string { if len(categories) == 0 { return nil } - s := strings.Join(categories, ",") + b, err := json.Marshal(categories) + if err != nil { + return nil + } + s := string(b) return &s } -func categoriesFromString(s sql.NullString) []string { +func categoriesFromJSON(s sql.NullString) []string { if !s.Valid || s.String == "" { return nil } - return strings.Split(s.String, ",") + var cats []string + if err := json.Unmarshal([]byte(s.String), &cats); err != nil { + return nil + } + return cats } From d9ecfa552a6b9f194a7d1f27191438f892463cb0 Mon Sep 17 00:00:00 2001 From: Julien Tant Date: Fri, 3 Apr 2026 17:31:17 -0700 Subject: [PATCH 6/7] Case-insensitive category filtering with LOWER() + test Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/storage/database.go | 2 +- internal/storage/database_test.go | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/internal/storage/database.go b/internal/storage/database.go index ae43def..ac629dd 100644 --- a/internal/storage/database.go +++ b/internal/storage/database.go @@ -363,7 +363,7 @@ func (db *Database) ListArticles(ctx context.Context, unreadOnly bool, blogID *i if category != nil && *category != "" { // Categories are stored as a JSON string array. Use json_each() // for exact element matching. - query = query.Where("EXISTS (SELECT 1 FROM json_each(categories) WHERE json_each.value = ?)", *category) + query = query.Where("EXISTS (SELECT 1 FROM json_each(categories) WHERE LOWER(json_each.value) = LOWER(?))", *category) } rows, err := query.RunWith(db.conn).QueryContext(ctx) diff --git a/internal/storage/database_test.go b/internal/storage/database_test.go index 7aec4d1..531da09 100644 --- a/internal/storage/database_test.go +++ b/internal/storage/database_test.go @@ -362,6 +362,19 @@ func TestListArticlesFilterByCategory(t *testing.T) { require.NoError(t, err, "list all") require.Len(t, all, 3) + // Case-insensitive match - "go" should match "Go" + cat = "go" + goLower, err := db.ListArticles(ctx, false, nil, &cat) + require.NoError(t, err, "list by category go (lowercase)") + require.Len(t, goLower, 1) + require.Equal(t, "Go Article", goLower[0].Title) + + // Case-insensitive match - "PROGRAMMING" should match "Programming" + cat = "PROGRAMMING" + progUpper, err := db.ListArticles(ctx, false, nil, &cat) + require.NoError(t, err, "list by category PROGRAMMING (uppercase)") + require.Len(t, progUpper, 2) + // Empty string category should return all empty := "" allEmpty, err := db.ListArticles(ctx, false, nil, &empty) From cea7563cecb5c515bd663d3f3305024406f4e7d1 Mon Sep 17 00:00:00 2001 From: Julien Tant Date: Fri, 3 Apr 2026 17:43:28 -0700 Subject: [PATCH 7/7] Propagate JSON serialization errors in category helpers Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/storage/database.go | 38 +++++++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/internal/storage/database.go b/internal/storage/database.go index ac629dd..e9da74c 100644 --- a/internal/storage/database.go +++ b/internal/storage/database.go @@ -219,9 +219,13 @@ func (db *Database) RemoveBlog(ctx context.Context, id int64) (bool, error) { // Article operations func (db *Database) AddArticle(ctx context.Context, article model.Article) (model.Article, error) { + cats, err := categoriesToJSON(article.Categories) + if err != nil { + return article, err + } result, err := sq.Insert("articles"). Columns("blog_id", "title", "url", "published_date", "discovered_date", "is_read", "categories"). - Values(article.BlogID, article.Title, article.URL, formatTimePtr(article.PublishedDate), formatTimePtr(article.DiscoveredDate), article.IsRead, categoriesToJSON(article.Categories)). + Values(article.BlogID, article.Title, article.URL, formatTimePtr(article.PublishedDate), formatTimePtr(article.DiscoveredDate), article.IsRead, cats). RunWith(db.conn). ExecContext(ctx) if err != nil { @@ -247,6 +251,13 @@ func (db *Database) AddArticlesBulk(ctx context.Context, articles []model.Articl insert := sq.Insert("articles"). Columns("blog_id", "title", "url", "published_date", "discovered_date", "is_read", "categories") for _, article := range articles { + cats, err := categoriesToJSON(article.Categories) + if err != nil { + if rerr := tx.Rollback(); rerr != nil { + fmt.Fprintf(os.Stderr, "rollback: %v\n", rerr) + } + return 0, err + } insert = insert.Values( article.BlogID, article.Title, @@ -254,7 +265,7 @@ func (db *Database) AddArticlesBulk(ctx context.Context, articles []model.Articl formatTimePtr(article.PublishedDate), formatTimePtr(article.DiscoveredDate), article.IsRead, - categoriesToJSON(article.Categories), + cats, ) } @@ -472,13 +483,18 @@ func scanArticle(scanner interface{ Scan(dest ...any) error }) (*model.Article, return nil, err } + cats, err := categoriesFromJSON(categories) + if err != nil { + return nil, err + } + article := &model.Article{ ID: id, BlogID: blogID, Title: title, URL: url, IsRead: isRead, - Categories: categoriesFromJSON(categories), + Categories: cats, } if publishedDate.Valid { if parsed, err := parseTime(publishedDate.String); err == nil { @@ -522,25 +538,25 @@ func nullIfEmpty(value string) *string { return &value } -func categoriesToJSON(categories []string) *string { +func categoriesToJSON(categories []string) (*string, error) { if len(categories) == 0 { - return nil + return nil, nil } b, err := json.Marshal(categories) if err != nil { - return nil + return nil, fmt.Errorf("marshal categories: %w", err) } s := string(b) - return &s + return &s, nil } -func categoriesFromJSON(s sql.NullString) []string { +func categoriesFromJSON(s sql.NullString) ([]string, error) { if !s.Valid || s.String == "" { - return nil + return nil, nil } var cats []string if err := json.Unmarshal([]byte(s.String), &cats); err != nil { - return nil + return nil, fmt.Errorf("unmarshal categories: %w", err) } - return cats + return cats, nil }