diff --git a/AGENTS.md b/AGENTS.md
index d6c29cf..855e854 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -1,3 +1,7 @@
+### Before speaking
+
+Never guess. Ground every claim in actual files, docs, or command output. If you're unsure about a dependency's size, run `go get` and check. If you're unsure about an API, read the source or docs. If you're unsure about behavior, write a test. State what you verified and how — don't speculate.
+
### While developing
#### Tests, tests, tests!
diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go
index 29caf9b..455bbc9 100644
--- a/e2e/e2e_test.go
+++ b/e2e/e2e_test.go
@@ -57,11 +57,11 @@ func (c *cliOpts) run(t *testing.T, args []string, opts map[string]string) (stdo
var cmdArgs []string
var extraEnv []string
- // DB path goes before the subcommand (persistent flag).
+ // Persistent flags / env vars.
if c.mode == "flags" {
- cmdArgs = append(cmdArgs, "--db", c.dbPath)
+ cmdArgs = append(cmdArgs, "--db", c.dbPath, "--unsafe-client")
} else {
- extraEnv = append(extraEnv, "BLOGWATCHER_DB="+c.dbPath)
+ extraEnv = append(extraEnv, "BLOGWATCHER_DB="+c.dbPath, "BLOGWATCHER_UNSAFE_CLIENT=true")
}
// Positional args first (includes the subcommand name).
@@ -355,6 +355,10 @@ func TestE2E(t *testing.T) {
out = c.ok(t, []string{"articles"}, map[string]string{"blog": "go-blog"})
checkOutput(t, "12_articles_filter_blog", out, baseURL)
+ // ── Articles filtered by category ──
+ out = c.ok(t, []string{"articles"}, map[string]string{"category": "Engineering"})
+ checkOutput(t, "12b_articles_filter_category", out, baseURL)
+
// ── Read / unread cycle ──
articlesOut := c.ok(t, []string{"articles"}, nil)
id := extractFirstID(t, articlesOut)
@@ -402,6 +406,28 @@ func TestE2E(t *testing.T) {
}
}
+func TestSSRFProtection(t *testing.T) {
+ baseURL := startTestServer(t)
+ dbPath := filepath.Join(t.TempDir(), "test.db")
+
+ // Add a blog pointing to the loopback test server WITHOUT --unsafe-client.
+ // The add command doesn't fetch, so it should succeed.
+ cmd := exec.CommandContext(context.Background(), binaryPath,
+ "--db", dbPath, "add", "test-blog", baseURL+"/go/",
+ "--feed-url", baseURL+"/go/feed.atom")
+ cmd.Env = append(os.Environ(), "NO_COLOR=1")
+ out, err := cmd.CombinedOutput()
+ require.NoError(t, err, "add should succeed: %s", string(out))
+
+ // Scan WITHOUT --unsafe-client — the safe client should block loopback and fail.
+ cmd = exec.CommandContext(context.Background(), binaryPath,
+ "--db", dbPath, "scan")
+ cmd.Env = append(os.Environ(), "NO_COLOR=1")
+ out, err = cmd.CombinedOutput()
+ require.Error(t, err, "scan should fail when SSRF protection blocks loopback")
+ require.Contains(t, string(out), "is not authorized", "expected SSRF error message")
+}
+
func extractFirstID(t *testing.T, output string) string {
t.Helper()
re := regexp.MustCompile(`\[(\d+)\]`)
diff --git a/e2e/expected/11_articles_unread.txt b/e2e/expected/11_articles_unread.txt
index e8ec68a..af246cb 100644
--- a/e2e/expected/11_articles_unread.txt
+++ b/e2e/expected/11_articles_unread.txt
@@ -12,11 +12,13 @@ Unread articles (11):
Blog: github-blog
URL: https://github.blog/news-insights/github-copilot-the-agent-awakens/
Published: 2026-04-01
+ Categories: AI, Copilot
[ID] [new] How we built the new GitHub Issues
Blog: github-blog
URL: https://github.blog/engineering/how-we-built-the-new-github-issues/
Published: 2026-03-28
+ Categories: Engineering
[ID] [new] More powerful Go execution traces
Blog: go-blog
@@ -36,6 +38,7 @@ Unread articles (11):
Blog: github-blog
URL: https://github.blog/engineering/the-uphill-climb-of-making-diff-lines-performant/
Published: 2026-04-03
+ Categories: Engineering, Performance
[ID] [new] Type Construction and Cycle Detection
Blog: go-blog
diff --git a/e2e/expected/12b_articles_filter_category.txt b/e2e/expected/12b_articles_filter_category.txt
new file mode 100644
index 0000000..85d32f9
--- /dev/null
+++ b/e2e/expected/12b_articles_filter_category.txt
@@ -0,0 +1,13 @@
+Unread articles (2):
+
+ [ID] [new] How we built the new GitHub Issues
+ Blog: github-blog
+ URL: https://github.blog/engineering/how-we-built-the-new-github-issues/
+ Published: 2026-03-28
+ Categories: Engineering
+
+ [ID] [new] The uphill climb of making diff lines performant
+ Blog: github-blog
+ URL: https://github.blog/engineering/the-uphill-climb-of-making-diff-lines-performant/
+ Published: 2026-04-03
+ Categories: Engineering, Performance
diff --git a/e2e/expected/20_articles_all.txt b/e2e/expected/20_articles_all.txt
index b0b9f6c..1507d43 100644
--- a/e2e/expected/20_articles_all.txt
+++ b/e2e/expected/20_articles_all.txt
@@ -12,11 +12,13 @@ All articles (11):
Blog: github-blog
URL: https://github.blog/news-insights/github-copilot-the-agent-awakens/
Published: 2026-04-01
+ Categories: AI, Copilot
[ID] [read] How we built the new GitHub Issues
Blog: github-blog
URL: https://github.blog/engineering/how-we-built-the-new-github-issues/
Published: 2026-03-28
+ Categories: Engineering
[ID] [read] More powerful Go execution traces
Blog: go-blog
@@ -36,6 +38,7 @@ All articles (11):
Blog: github-blog
URL: https://github.blog/engineering/the-uphill-climb-of-making-diff-lines-performant/
Published: 2026-04-03
+ Categories: Engineering, Performance
[ID] [read] Type Construction and Cycle Detection
Blog: go-blog
diff --git a/e2e/testdata/github_blog.rss b/e2e/testdata/github_blog.rss
index 3977029..02a4ab4 100644
--- a/e2e/testdata/github_blog.rss
+++ b/e2e/testdata/github_blog.rss
@@ -15,6 +15,8 @@
https://github.blog/engineering/the-uphill-climb-of-making-diff-lines-performant/
Fri, 03 Apr 2026 16:00:00 +0000
+<category>Engineering</category>
+<category>Performance</category>
-
@@ -22,6 +24,8 @@
https://github.blog/news-insights/github-copilot-the-agent-awakens/
Wed, 01 Apr 2026 12:00:00 +0000
+<category>AI</category>
+<category>Copilot</category>
-
@@ -29,6 +33,7 @@
https://github.blog/engineering/how-we-built-the-new-github-issues/
Mon, 28 Mar 2026 10:00:00 +0000
+<category>Engineering</category>
diff --git a/go.mod b/go.mod
index c2e1ea3..6738bbf 100644
--- a/go.mod
+++ b/go.mod
@@ -3,6 +3,7 @@ module github.com/JulienTant/blogwatcher-cli
go 1.26.1
require (
+ github.com/DataDog/go-secure-sdk v0.0.7
github.com/Masterminds/squirrel v1.5.4
github.com/PuerkitoBio/goquery v1.12.0
github.com/fatih/color v1.19.0
@@ -11,6 +12,7 @@ require (
github.com/spf13/cobra v1.10.2
github.com/spf13/viper v1.21.0
github.com/stretchr/testify v1.11.1
+ golang.org/x/sync v0.20.0
modernc.org/sqlite v1.48.1
)
@@ -44,7 +46,6 @@ require (
golang.org/x/sys v0.42.0 // indirect
golang.org/x/text v0.35.0 // indirect
golang.org/x/tools v0.43.0 // indirect
- gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/libc v1.70.0 // indirect
modernc.org/mathutil v1.7.1 // indirect
diff --git a/go.sum b/go.sum
index d244cad..83d1bb7 100644
--- a/go.sum
+++ b/go.sum
@@ -1,3 +1,5 @@
+github.com/DataDog/go-secure-sdk v0.0.7 h1:FJYJPXXZmxOC0RHWzhsZ3PKtUaF+IgC6fDc/PbEoH6c=
+github.com/DataDog/go-secure-sdk v0.0.7/go.mod h1:/fIdMM7LMp7KGouxN9Cnv3UP0NdfP+XZ5R8qyU3D8qk=
github.com/Masterminds/squirrel v1.5.4 h1:uUcX/aBc8O7Fg9kaISIUsHXdKuqehiXAMQTYX8afzqM=
github.com/Masterminds/squirrel v1.5.4/go.mod h1:NNaOrjSoIDfDA40n7sr2tPNZRfjzjA400rg+riTZj10=
github.com/PuerkitoBio/goquery v1.12.0 h1:pAcL4g3WRXekcB9AU/y1mbKez2dbY2AajVhtkO8RIBo=
@@ -21,6 +23,8 @@ github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPE
github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA=
github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE=
+github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
+github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
@@ -66,8 +70,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
-github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
-github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
+github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
+github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4=
github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI=
@@ -169,8 +173,8 @@ golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
-gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
diff --git a/internal/cli/commands.go b/internal/cli/commands.go
index 3e4a504..4c2fac3 100644
--- a/internal/cli/commands.go
+++ b/internal/cli/commands.go
@@ -3,21 +3,34 @@ package cli
import (
"bufio"
"fmt"
+ "net/http"
"os"
"strconv"
"strings"
"time"
+ "github.com/DataDog/go-secure-sdk/net/httpclient"
"github.com/fatih/color"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/JulienTant/blogwatcher-cli/internal/controller"
"github.com/JulienTant/blogwatcher-cli/internal/model"
+ "github.com/JulienTant/blogwatcher-cli/internal/rss"
"github.com/JulienTant/blogwatcher-cli/internal/scanner"
+ "github.com/JulienTant/blogwatcher-cli/internal/scraper"
"github.com/JulienTant/blogwatcher-cli/internal/storage"
)
+const httpTimeout = 30 * time.Second
+
+func newHTTPClient() *http.Client {
+ if viper.GetBool("unsafe-client") {
+ return httpclient.UnSafe(httpclient.WithTimeout(httpTimeout))
+ }
+ return httpclient.Safe(httpclient.WithTimeout(httpTimeout))
+}
+
func newAddCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "add ",
@@ -148,8 +161,11 @@ func newScanCommand() *cobra.Command {
}
}()
+ client := newHTTPClient()
+ sc := scanner.NewScanner(rss.NewFetcher(client), scraper.NewScraper(client))
+
if len(args) == 1 {
- result, err := scanner.ScanBlogByName(cmd.Context(), db, args[0])
+ result, err := sc.ScanBlogByName(cmd.Context(), db, args[0])
if err != nil {
return err
}
@@ -173,7 +189,7 @@ func newScanCommand() *cobra.Command {
if !silent {
cprintf([]color.Attribute{color.FgCyan}, "Scanning %d blog(s)...\n\n", len(blogs))
}
- results, err := scanner.ScanAllBlogs(cmd.Context(), db, workers)
+ results, err := sc.ScanAllBlogs(cmd.Context(), db, workers)
if err != nil {
return err
}
@@ -221,7 +237,7 @@ func newArticlesCommand() *cobra.Command {
fmt.Fprintf(os.Stderr, "close db: %v\n", err)
}
}()
- articles, blogNames, err := controller.GetArticles(cmd.Context(), db, showAll, viper.GetString("blog"))
+ articles, blogNames, err := controller.GetArticles(cmd.Context(), db, showAll, viper.GetString("blog"), viper.GetString("category"))
if err != nil {
printError(err)
return markError(err)
@@ -249,6 +265,7 @@ func newArticlesCommand() *cobra.Command {
cmd.Flags().BoolP("all", "a", false, "Show all articles (including read)")
cmd.Flags().StringP("blog", "b", "", "Filter by blog name")
+ cmd.Flags().StringP("category", "c", "", "Filter by category")
return cmd
}
@@ -304,7 +321,7 @@ func newReadAllCommand() *cobra.Command {
}
}()
- articles, _, err := controller.GetArticles(cmd.Context(), db, false, blogName)
+ articles, _, err := controller.GetArticles(cmd.Context(), db, false, blogName, "")
if err != nil {
printError(err)
return markError(err)
@@ -385,10 +402,6 @@ func printScanResult(result scanner.ScanResult) {
statusColor = []color.Attribute{color.FgGreen}
}
cprintf([]color.Attribute{color.FgWhite, color.Bold}, " %s\n", result.BlogName)
- if result.Error != "" {
- cprintfErr(color.FgRed, " Error: %s\n", result.Error)
- return
- }
if result.Source == "none" {
cprintln([]color.Attribute{color.FgYellow}, " No feed or scraper configured")
return
@@ -413,6 +426,9 @@ func printArticle(article model.Article, blogName string) {
if article.PublishedDate != nil {
fmt.Printf(" Published: %s\n", article.PublishedDate.Format("2006-01-02"))
}
+ if len(article.Categories) > 0 {
+ fmt.Printf(" Categories: %s\n", strings.Join(article.Categories, ", "))
+ }
fmt.Println()
}
diff --git a/internal/cli/root.go b/internal/cli/root.go
index f822f7e..161840a 100644
--- a/internal/cli/root.go
+++ b/internal/cli/root.go
@@ -24,6 +24,7 @@ func NewRootCommand() *cobra.Command {
rootCmd.SetVersionTemplate("{{.Version}}\n")
rootCmd.PersistentFlags().String("db", "", "Path to the SQLite database file (default: ~/.blogwatcher-cli/blogwatcher-cli.db)")
+ rootCmd.PersistentFlags().Bool("unsafe-client", false, "Disable SSRF protection (allow requests to private/loopback IPs)")
rootCmd.AddCommand(newAddCommand())
rootCmd.AddCommand(newRemoveCommand())
diff --git a/internal/controller/controller.go b/internal/controller/controller.go
index cea763a..62b4c53 100644
--- a/internal/controller/controller.go
+++ b/internal/controller/controller.go
@@ -66,7 +66,7 @@ func RemoveBlog(ctx context.Context, db *storage.Database, name string) error {
return err
}
-func GetArticles(ctx context.Context, db *storage.Database, showAll bool, blogName string) ([]model.Article, map[int64]string, error) {
+func GetArticles(ctx context.Context, db *storage.Database, showAll bool, blogName string, category string) ([]model.Article, map[int64]string, error) {
var blogID *int64
if blogName != "" {
blog, err := db.GetBlogByName(ctx, blogName)
@@ -79,7 +79,12 @@ func GetArticles(ctx context.Context, db *storage.Database, showAll bool, blogNa
blogID = &blog.ID
}
- articles, err := db.ListArticles(ctx, !showAll, blogID)
+ var categoryPtr *string
+ if category != "" {
+ categoryPtr = &category
+ }
+
+ articles, err := db.ListArticles(ctx, !showAll, blogID, categoryPtr)
if err != nil {
return nil, nil, err
}
@@ -125,7 +130,7 @@ func MarkAllArticlesRead(ctx context.Context, db *storage.Database, blogName str
blogID = &blog.ID
}
- articles, err := db.ListArticles(ctx, true, blogID)
+ articles, err := db.ListArticles(ctx, true, blogID, nil)
if err != nil {
return nil, err
}
diff --git a/internal/controller/controller_test.go b/internal/controller/controller_test.go
index 7e6b5c0..5cea5a6 100644
--- a/internal/controller/controller_test.go
+++ b/internal/controller/controller_test.go
@@ -57,15 +57,40 @@ func TestGetArticlesFilters(t *testing.T) {
_, err = db.AddArticle(ctx, model.Article{BlogID: blog.ID, Title: "Title", URL: "https://example.com/1"})
require.NoError(t, err, "add article")
- articles, blogNames, err := GetArticles(ctx, db, false, "")
+ articles, blogNames, err := GetArticles(ctx, db, false, "", "")
require.NoError(t, err, "get articles")
require.Len(t, articles, 1)
require.Equal(t, blog.Name, blogNames[blog.ID])
- _, _, err = GetArticles(ctx, db, false, "Missing")
+ _, _, err = GetArticles(ctx, db, false, "Missing", "")
require.Error(t, err, "expected blog not found error")
}
+func TestGetArticlesFilterByCategory(t *testing.T) {
+ ctx := context.Background()
+ db := openTestDB(t)
+ defer func() { require.NoError(t, db.Close()) }()
+
+ blog, err := AddBlog(ctx, db, "Test", "https://example.com", "", "")
+ require.NoError(t, err, "add blog")
+
+ _, err = db.AddArticle(ctx, model.Article{BlogID: blog.ID, Title: "Go Post", URL: "https://example.com/1", Categories: []string{"Go", "Programming"}})
+ require.NoError(t, err, "add article")
+ _, err = db.AddArticle(ctx, model.Article{BlogID: blog.ID, Title: "Rust Post", URL: "https://example.com/2", Categories: []string{"Rust"}})
+ require.NoError(t, err, "add article")
+
+ // Filter by Go
+ articles, _, err := GetArticles(ctx, db, false, "", "Go")
+ require.NoError(t, err, "get articles by category")
+ require.Len(t, articles, 1)
+ require.Equal(t, "Go Post", articles[0].Title)
+
+ // No filter returns all
+ all, _, err := GetArticles(ctx, db, false, "", "")
+ require.NoError(t, err, "get all articles")
+ require.Len(t, all, 2)
+}
+
func openTestDB(t *testing.T) *storage.Database {
t.Helper()
path := filepath.Join(t.TempDir(), "blogwatcher-cli.db")
diff --git a/internal/model/model.go b/internal/model/model.go
index dd0a2d8..07d0fee 100644
--- a/internal/model/model.go
+++ b/internal/model/model.go
@@ -19,4 +19,5 @@ type Article struct {
PublishedDate *time.Time
DiscoveredDate *time.Time
IsRead bool
+ Categories []string
}
diff --git a/internal/rss/rss.go b/internal/rss/rss.go
index 3dd7b85..fefa98f 100644
--- a/internal/rss/rss.go
+++ b/internal/rss/rss.go
@@ -19,6 +19,7 @@ type FeedArticle struct {
Title string
URL string
PublishedDate *time.Time
+ Categories []string
}
type FeedParseError struct {
@@ -29,13 +30,22 @@ func (e FeedParseError) Error() string {
return e.Message
}
-func ParseFeed(ctx context.Context, feedURL string, timeout time.Duration) ([]FeedArticle, error) {
- client := &http.Client{Timeout: timeout}
+// Fetcher fetches and parses RSS/Atom feeds.
+type Fetcher struct {
+ client *http.Client
+}
+
+// NewFetcher creates a Fetcher with the given HTTP client.
+func NewFetcher(client *http.Client) *Fetcher {
+ return &Fetcher{client: client}
+}
+
+func (f *Fetcher) ParseFeed(ctx context.Context, feedURL string) ([]FeedArticle, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, feedURL, nil)
if err != nil {
return nil, FeedParseError{Message: fmt.Sprintf("failed to create request: %v", err)}
}
- response, err := client.Do(req)
+ response, err := f.client.Do(req)
if err != nil {
return nil, FeedParseError{Message: fmt.Sprintf("failed to fetch feed: %v", err)}
}
@@ -65,19 +75,19 @@ func ParseFeed(ctx context.Context, feedURL string, timeout time.Duration) ([]Fe
Title: title,
URL: link,
PublishedDate: pickPublishedDate(item),
+ Categories: item.Categories,
})
}
return articles, nil
}
-func DiscoverFeedURL(ctx context.Context, blogURL string, timeout time.Duration) (string, error) {
- client := &http.Client{Timeout: timeout}
+func (f *Fetcher) DiscoverFeedURL(ctx context.Context, blogURL string) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, blogURL, nil)
if err != nil {
return "", nil
}
- response, err := client.Do(req)
+ response, err := f.client.Do(req)
if err != nil {
return "", nil
}
@@ -153,7 +163,7 @@ func DiscoverFeedURL(ctx context.Context, blogURL string, timeout time.Duration)
if resolved == "" {
continue
}
- ok, err := isValidFeed(ctx, resolved, timeout)
+ ok, err := f.isValidFeed(ctx, resolved)
if err == nil && ok {
return resolved, nil
}
@@ -162,13 +172,12 @@ func DiscoverFeedURL(ctx context.Context, blogURL string, timeout time.Duration)
return "", nil
}
-func isValidFeed(ctx context.Context, feedURL string, timeout time.Duration) (bool, error) {
- client := &http.Client{Timeout: timeout}
+func (f *Fetcher) isValidFeed(ctx context.Context, feedURL string) (bool, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, feedURL, nil)
if err != nil {
return false, err
}
- response, err := client.Do(req)
+ response, err := f.client.Do(req)
if err != nil {
return false, err
}
diff --git a/internal/rss/rss_test.go b/internal/rss/rss_test.go
index 69e8a22..dbdf453 100644
--- a/internal/rss/rss_test.go
+++ b/internal/rss/rss_test.go
@@ -26,6 +26,10 @@ const sampleFeed = `
`
+func newTestFetcher() *Fetcher {
+ return NewFetcher(&http.Client{Timeout: 2 * time.Second})
+}
+
func TestParseFeed(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if _, writeErr := w.Write([]byte(sampleFeed)); writeErr != nil {
@@ -35,12 +39,46 @@ func TestParseFeed(t *testing.T) {
}))
defer server.Close()
- articles, err := ParseFeed(context.Background(), server.URL, 2*time.Second)
+ articles, err := newTestFetcher().ParseFeed(context.Background(), server.URL)
require.NoError(t, err, "parse feed")
require.Len(t, articles, 2)
require.NotNil(t, articles[0].PublishedDate)
}
+func TestParseFeedWithCategories(t *testing.T) {
+	feedWithCategories := `<?xml version="1.0" encoding="UTF-8"?>
+<rss version="2.0">
+<channel>
+<title>Example Feed</title>
+<item>
+<title>Tagged Post</title>
+<link>https://example.com/tagged</link>
+<category>AI</category>
+<category>Machine Learning</category>
+</item>
+<item>
+<title>Plain Post</title>
+<link>https://example.com/plain</link>
+</item>
+</channel>
+</rss>`
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if _, writeErr := w.Write([]byte(feedWithCategories)); writeErr != nil {
+ http.Error(w, writeErr.Error(), http.StatusInternalServerError)
+ return
+ }
+ }))
+ defer server.Close()
+
+ articles, err := newTestFetcher().ParseFeed(context.Background(), server.URL)
+ require.NoError(t, err, "parse feed")
+ require.Len(t, articles, 2)
+
+ require.Equal(t, []string{"AI", "Machine Learning"}, articles[0].Categories)
+ require.Nil(t, articles[1].Categories)
+}
+
func TestDiscoverFeedURL(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
@@ -58,7 +96,7 @@ func TestDiscoverFeedURL(t *testing.T) {
server := httptest.NewServer(mux)
defer server.Close()
- feedURL, err := DiscoverFeedURL(context.Background(), server.URL, 2*time.Second)
+ feedURL, err := newTestFetcher().DiscoverFeedURL(context.Background(), server.URL)
require.NoError(t, err, "discover feed")
require.NotEmpty(t, feedURL, "expected feed url")
}
@@ -76,7 +114,7 @@ func TestDiscoverFeedURL_XMLContentType(t *testing.T) {
server := httptest.NewServer(mux)
defer server.Close()
- feedURL, err := DiscoverFeedURL(context.Background(), server.URL+"/tag/AI/feed/", 2*time.Second)
+ feedURL, err := newTestFetcher().DiscoverFeedURL(context.Background(), server.URL+"/tag/AI/feed/")
require.NoError(t, err)
require.Equal(t, server.URL+"/tag/AI/feed/", feedURL, "should return URL directly for feed content-type")
}
@@ -100,7 +138,7 @@ func TestDiscoverFeedURL_RelSelf(t *testing.T) {
server := httptest.NewServer(mux)
defer server.Close()
- feedURL, err := DiscoverFeedURL(context.Background(), server.URL, 2*time.Second)
+ feedURL, err := newTestFetcher().DiscoverFeedURL(context.Background(), server.URL)
require.NoError(t, err)
require.Equal(t, server.URL+"/my-feed.xml", feedURL, "should discover feed from rel=self link")
}
diff --git a/internal/scanner/scanner.go b/internal/scanner/scanner.go
index 7997588..accdd99 100644
--- a/internal/scanner/scanner.go
+++ b/internal/scanner/scanner.go
@@ -6,6 +6,8 @@ import (
"os"
"time"
+ "golang.org/x/sync/errgroup"
+
"github.com/JulienTant/blogwatcher-cli/internal/model"
"github.com/JulienTant/blogwatcher-cli/internal/rss"
"github.com/JulienTant/blogwatcher-cli/internal/scraper"
@@ -17,50 +19,68 @@ type ScanResult struct {
NewArticles int
TotalFound int
Source string
- Error string
}
-func ScanBlog(ctx context.Context, db *storage.Database, blog model.Blog) ScanResult {
+// Scanner orchestrates blog scanning using a Fetcher and Scraper.
+type Scanner struct {
+ fetcher *rss.Fetcher
+ scraper *scraper.Scraper
+}
+
+// NewScanner creates a Scanner with the given fetcher and scraper.
+func NewScanner(fetcher *rss.Fetcher, scraper *scraper.Scraper) *Scanner {
+ return &Scanner{fetcher: fetcher, scraper: scraper}
+}
+
+func (s *Scanner) ScanBlog(ctx context.Context, db *storage.Database, blog model.Blog) (ScanResult, error) {
var (
articles []model.Article
source = "none"
- errText string
)
feedURL := blog.FeedURL
if feedURL == "" {
- if discovered, err := rss.DiscoverFeedURL(ctx, blog.URL, 30*time.Second); err == nil && discovered != "" {
+ discovered, err := s.fetcher.DiscoverFeedURL(ctx, blog.URL)
+ if err != nil {
+ return ScanResult{BlogName: blog.Name}, err
+ }
+ if discovered != "" {
feedURL = discovered
- blog.FeedURL = discovered
- if err := db.UpdateBlog(ctx, blog); err != nil {
- fmt.Fprintf(os.Stderr, "update blog: %v\n", err)
- }
}
}
if feedURL != "" {
- feedArticles, err := rss.ParseFeed(ctx, feedURL, 30*time.Second)
+ feedArticles, err := s.fetcher.ParseFeed(ctx, feedURL)
if err != nil {
- errText = err.Error()
+ // If there's a scraper fallback, try it before giving up.
+ if blog.ScrapeSelector == "" {
+ return ScanResult{BlogName: blog.Name}, err
+ }
+ // Try scraper as fallback.
+ scrapedArticles, scrapeErr := s.scraper.ScrapeBlog(ctx, blog.URL, blog.ScrapeSelector)
+ if scrapeErr != nil {
+ return ScanResult{BlogName: blog.Name}, fmt.Errorf("RSS: %w; Scraper: %w", err, scrapeErr)
+ }
+ articles = convertScrapedArticles(blog.ID, scrapedArticles)
+ source = "scraper"
} else {
articles = convertFeedArticles(blog.ID, feedArticles)
source = "rss"
+ // Persist discovered feed URL only after successful parse.
+ if blog.FeedURL != feedURL {
+ blog.FeedURL = feedURL
+ if err := db.UpdateBlog(ctx, blog); err != nil {
+ return ScanResult{BlogName: blog.Name}, err
+ }
+ }
}
- }
-
- if len(articles) == 0 && blog.ScrapeSelector != "" {
- scrapedArticles, err := scraper.ScrapeBlog(ctx, blog.URL, blog.ScrapeSelector, 30*time.Second)
+ } else if blog.ScrapeSelector != "" {
+ scrapedArticles, err := s.scraper.ScrapeBlog(ctx, blog.URL, blog.ScrapeSelector)
if err != nil {
- if errText != "" {
- errText = fmt.Sprintf("RSS: %s; Scraper: %s", errText, err.Error())
- } else {
- errText = err.Error()
- }
- } else {
- articles = convertScrapedArticles(blog.ID, scrapedArticles)
- source = "scraper"
- errText = ""
+ return ScanResult{BlogName: blog.Name}, err
}
+ articles = convertScrapedArticles(blog.ID, scrapedArticles)
+ source = "scraper"
}
seenURLs := make(map[string]struct{})
@@ -80,7 +100,7 @@ func ScanBlog(ctx context.Context, db *storage.Database, blog model.Blog) ScanRe
existing, err := db.GetExistingArticleURLs(ctx, urlList)
if err != nil {
- errText = err.Error()
+ return ScanResult{BlogName: blog.Name}, err
}
discoveredAt := time.Now()
@@ -97,14 +117,13 @@ func ScanBlog(ctx context.Context, db *storage.Database, blog model.Blog) ScanRe
if len(newArticles) > 0 {
count, err := db.AddArticlesBulk(ctx, newArticles)
if err != nil {
- errText = err.Error()
- } else {
- newCount = count
+ return ScanResult{BlogName: blog.Name}, err
}
+ newCount = count
}
if err := db.UpdateBlogLastScanned(ctx, blog.ID, time.Now()); err != nil {
- fmt.Fprintf(os.Stderr, "update last scanned: %v\n", err)
+ return ScanResult{BlogName: blog.Name}, err
}
return ScanResult{
@@ -112,11 +131,10 @@ func ScanBlog(ctx context.Context, db *storage.Database, blog model.Blog) ScanRe
NewArticles: newCount,
TotalFound: len(seenURLs),
Source: source,
- Error: errText,
- }
+ }, nil
}
-func ScanAllBlogs(ctx context.Context, db *storage.Database, workers int) ([]ScanResult, error) {
+func (s *Scanner) ScanAllBlogs(ctx context.Context, db *storage.Database, workers int) ([]ScanResult, error) {
blogs, err := db.ListBlogs(ctx)
if err != nil {
return nil, err
@@ -124,7 +142,11 @@ func ScanAllBlogs(ctx context.Context, db *storage.Database, workers int) ([]Sca
if workers <= 1 {
results := make([]ScanResult, 0, len(blogs))
for _, blog := range blogs {
- results = append(results, ScanBlog(ctx, db, blog))
+ result, err := s.ScanBlog(ctx, db, blog)
+ if err != nil {
+ return nil, fmt.Errorf("scan %s: %w", blog.Name, err)
+ }
+ results = append(results, result)
}
return results, nil
}
@@ -135,14 +157,14 @@ func ScanAllBlogs(ctx context.Context, db *storage.Database, workers int) ([]Sca
}
jobs := make(chan job)
results := make([]ScanResult, len(blogs))
- errs := make(chan error, workers)
+
+ g, gctx := errgroup.WithContext(ctx)
for i := 0; i < workers; i++ {
- go func() {
- workerDB, err := storage.OpenDatabase(ctx, db.Path())
+ g.Go(func() error {
+ workerDB, err := storage.OpenDatabase(gctx, db.Path())
if err != nil {
- errs <- err
- return
+ return err
}
defer func() {
if err := workerDB.Close(); err != nil {
@@ -150,27 +172,36 @@ func ScanAllBlogs(ctx context.Context, db *storage.Database, workers int) ([]Sca
}
}()
for item := range jobs {
- results[item.Index] = ScanBlog(ctx, workerDB, item.Blog)
+ result, err := s.ScanBlog(gctx, workerDB, item.Blog)
+ if err != nil {
+ return fmt.Errorf("scan %s: %w", item.Blog.Name, err)
+ }
+ results[item.Index] = result
}
- errs <- nil
- }()
- }
-
- for index, blog := range blogs {
- jobs <- job{Index: index, Blog: blog}
+ return nil
+ })
}
- close(jobs)
- for i := 0; i < workers; i++ {
- if err := <-errs; err != nil {
- return nil, err
+ g.Go(func() error {
+ defer close(jobs)
+ for index, blog := range blogs {
+ select {
+ case jobs <- job{Index: index, Blog: blog}:
+ case <-gctx.Done():
+ return gctx.Err()
+ }
}
+ return nil
+ })
+
+ if err := g.Wait(); err != nil {
+ return nil, err
}
return results, nil
}
-func ScanBlogByName(ctx context.Context, db *storage.Database, name string) (*ScanResult, error) {
+func (s *Scanner) ScanBlogByName(ctx context.Context, db *storage.Database, name string) (*ScanResult, error) {
blog, err := db.GetBlogByName(ctx, name)
if err != nil {
return nil, err
@@ -178,7 +209,10 @@ func ScanBlogByName(ctx context.Context, db *storage.Database, name string) (*Sc
if blog == nil {
return nil, nil
}
- result := ScanBlog(ctx, db, *blog)
+ result, err := s.ScanBlog(ctx, db, *blog)
+ if err != nil {
+ return nil, err
+ }
return &result, nil
}
@@ -191,6 +225,7 @@ func convertFeedArticles(blogID int64, articles []rss.FeedArticle) []model.Artic
URL: article.URL,
PublishedDate: article.PublishedDate,
IsRead: false,
+ Categories: article.Categories,
})
}
return result
diff --git a/internal/scanner/scanner_test.go b/internal/scanner/scanner_test.go
index 014bba7..374a4cf 100644
--- a/internal/scanner/scanner_test.go
+++ b/internal/scanner/scanner_test.go
@@ -2,6 +2,7 @@ package scanner
import (
"context"
+ "fmt"
"net/http"
"net/http/httptest"
"path/filepath"
@@ -9,6 +10,8 @@ import (
"time"
"github.com/JulienTant/blogwatcher-cli/internal/model"
+ "github.com/JulienTant/blogwatcher-cli/internal/rss"
+ "github.com/JulienTant/blogwatcher-cli/internal/scraper"
"github.com/JulienTant/blogwatcher-cli/internal/storage"
"github.com/stretchr/testify/require"
)
@@ -28,6 +31,11 @@ const sampleFeed = `
`
+func newTestScanner() *Scanner {
+ client := &http.Client{Timeout: 2 * time.Second}
+ return NewScanner(rss.NewFetcher(client), scraper.NewScraper(client))
+}
+
func TestScanBlogRSS(t *testing.T) {
ctx := context.Background()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -44,11 +52,12 @@ func TestScanBlogRSS(t *testing.T) {
blog, err := db.AddBlog(ctx, model.Blog{Name: "Test", URL: "https://example.com", FeedURL: server.URL})
require.NoError(t, err, "add blog")
- result := ScanBlog(ctx, db, blog)
+ result, scanErr := newTestScanner().ScanBlog(ctx, db, blog)
+ require.NoError(t, scanErr)
require.Equal(t, 2, result.NewArticles)
require.Equal(t, "rss", result.Source)
- articles, err := db.ListArticles(ctx, false, nil)
+ articles, err := db.ListArticles(ctx, false, nil, nil)
require.NoError(t, err, "list articles")
require.Len(t, articles, 2)
}
@@ -81,31 +90,46 @@ func TestScanBlogScraperFallback(t *testing.T) {
blog, err := db.AddBlog(ctx, model.Blog{Name: "Test", URL: server.URL, FeedURL: server.URL + "/feed.xml", ScrapeSelector: "article h2 a"})
require.NoError(t, err, "add blog")
- result := ScanBlog(ctx, db, blog)
+ result, scanErr := newTestScanner().ScanBlog(ctx, db, blog)
+ require.NoError(t, scanErr)
require.Equal(t, "scraper", result.Source)
require.Equal(t, 1, result.NewArticles)
- require.Empty(t, result.Error)
}
func TestScanAllBlogsConcurrent(t *testing.T) {
ctx := context.Background()
- server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if _, writeErr := w.Write([]byte(sampleFeed)); writeErr != nil {
- http.Error(w, writeErr.Error(), http.StatusInternalServerError)
- return
- }
- }))
+
+	feedTemplate := `<?xml version="1.0"?><rss version="2.0"><channel>
+<title>%s</title>
+<item><title>Post 1</title><link>https://%s.example.com/1</link></item>
+<item><title>Post 2</title><link>https://%s.example.com/2</link></item>
+</channel></rss>`
+
+ mux := http.NewServeMux()
+ for _, name := range []string{"a", "b"} {
+ feed := fmt.Sprintf(feedTemplate, name, name, name)
+ mux.HandleFunc("/"+name+"/feed", func(w http.ResponseWriter, r *http.Request) {
+ if _, writeErr := w.Write([]byte(feed)); writeErr != nil {
+ http.Error(w, writeErr.Error(), http.StatusInternalServerError)
+ }
+ })
+ }
+ server := httptest.NewServer(mux)
defer server.Close()
db := openTestDB(t)
defer func() { require.NoError(t, db.Close()) }()
- for i, name := range []string{"TestA", "TestB"} {
- _, err := db.AddBlog(ctx, model.Blog{Name: name, URL: "https://example.com/" + name, FeedURL: server.URL})
- require.NoError(t, err, "add blog %d", i)
+ for _, name := range []string{"a", "b"} {
+ _, err := db.AddBlog(ctx, model.Blog{
+ Name: "Test-" + name,
+ URL: "https://" + name + ".example.com",
+ FeedURL: server.URL + "/" + name + "/feed",
+ })
+ require.NoError(t, err, "add blog %s", name)
}
- results, err := ScanAllBlogs(ctx, db, 2)
+ results, err := newTestScanner().ScanAllBlogs(ctx, db, 2)
require.NoError(t, err, "scan all blogs")
require.Len(t, results, 2)
}
@@ -137,10 +161,69 @@ func TestScanBlogRespectsExistingArticles(t *testing.T) {
_, err = db.AddArticle(ctx, model.Article{BlogID: blog.ID, Title: "First", URL: "https://example.com/1", DiscoveredDate: ptrTime(time.Now())})
require.NoError(t, err, "add article")
- result := ScanBlog(ctx, db, blog)
+ result, scanErr := newTestScanner().ScanBlog(ctx, db, blog)
+ require.NoError(t, scanErr)
require.Equal(t, 1, result.NewArticles)
}
+func TestScanBlogRSSWithCategories(t *testing.T) {
+ ctx := context.Background()
+	feedWithCategories := `<?xml version="1.0"?>
+<rss version="2.0">
+<channel>
+<title>Example Feed</title>
+<item>
+<title>First</title>
+<link>https://example.com/1</link>
+<category>Go</category>
+<category>Programming</category>
+</item>
+<item>
+<title>Second</title>
+<link>https://example.com/2</link>
+</item>
+</channel>
+</rss>`
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if _, writeErr := w.Write([]byte(feedWithCategories)); writeErr != nil {
+ http.Error(w, writeErr.Error(), http.StatusInternalServerError)
+ return
+ }
+ }))
+ defer server.Close()
+
+ db := openTestDB(t)
+ defer func() { require.NoError(t, db.Close()) }()
+
+ blog, err := db.AddBlog(ctx, model.Blog{Name: "Test", URL: "https://example.com", FeedURL: server.URL})
+ require.NoError(t, err, "add blog")
+
+ result, scanErr := newTestScanner().ScanBlog(ctx, db, blog)
+ require.NoError(t, scanErr)
+ require.Equal(t, 2, result.NewArticles)
+
+ articles, err := db.ListArticles(ctx, false, nil, nil)
+ require.NoError(t, err, "list articles")
+ require.Len(t, articles, 2)
+
+ // Find the article with categories
+ var withCat *model.Article
+ var withoutCat *model.Article
+ for i := range articles {
+ if articles[i].Title == "First" {
+ withCat = &articles[i]
+ } else {
+ withoutCat = &articles[i]
+ }
+ }
+ require.NotNil(t, withCat)
+ require.Equal(t, []string{"Go", "Programming"}, withCat.Categories)
+
+ require.NotNil(t, withoutCat)
+ require.Nil(t, withoutCat.Categories)
+}
+
func ptrTime(value time.Time) *time.Time {
return &value
}
diff --git a/internal/scraper/scraper.go b/internal/scraper/scraper.go
index b56c1cb..8507bd6 100644
--- a/internal/scraper/scraper.go
+++ b/internal/scraper/scraper.go
@@ -27,13 +27,22 @@ func (e ScrapeError) Error() string {
return e.Message
}
-func ScrapeBlog(ctx context.Context, blogURL string, selector string, timeout time.Duration) ([]ScrapedArticle, error) {
- client := &http.Client{Timeout: timeout}
+// Scraper scrapes HTML pages for article links.
+type Scraper struct {
+ client *http.Client
+}
+
+// NewScraper creates a Scraper with the given HTTP client.
+func NewScraper(client *http.Client) *Scraper {
+ return &Scraper{client: client}
+}
+
+func (s *Scraper) ScrapeBlog(ctx context.Context, blogURL string, selector string) ([]ScrapedArticle, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, blogURL, nil)
if err != nil {
return nil, ScrapeError{Message: fmt.Sprintf("failed to create request: %v", err)}
}
- response, err := client.Do(req)
+ response, err := s.client.Do(req)
if err != nil {
return nil, ScrapeError{Message: fmt.Sprintf("failed to fetch page: %v", err)}
}
diff --git a/internal/scraper/scraper_test.go b/internal/scraper/scraper_test.go
index 6f288b4..52fe7e9 100644
--- a/internal/scraper/scraper_test.go
+++ b/internal/scraper/scraper_test.go
@@ -10,6 +10,10 @@ import (
"github.com/stretchr/testify/require"
)
+func newTestScraper() *Scraper {
+ return NewScraper(&http.Client{Timeout: 2 * time.Second})
+}
+
func TestScrapeBlog(t *testing.T) {
html := `
@@ -28,7 +32,7 @@ func TestScrapeBlog(t *testing.T) {
}))
defer server.Close()
- articles, err := ScrapeBlog(context.Background(), server.URL, "article h2 a, .post", 2*time.Second)
+ articles, err := newTestScraper().ScrapeBlog(context.Background(), server.URL, "article h2 a, .post")
require.NoError(t, err, "scrape blog")
require.Len(t, articles, 2)
require.NotEmpty(t, articles[0].URL)
diff --git a/internal/storage/database.go b/internal/storage/database.go
index a2c2ec0..e9da74c 100644
--- a/internal/storage/database.go
+++ b/internal/storage/database.go
@@ -3,6 +3,7 @@ package storage
import (
"context"
"database/sql"
+ "encoding/json"
"errors"
"fmt"
"os"
@@ -218,9 +219,13 @@ func (db *Database) RemoveBlog(ctx context.Context, id int64) (bool, error) {
// Article operations
func (db *Database) AddArticle(ctx context.Context, article model.Article) (model.Article, error) {
+ cats, err := categoriesToJSON(article.Categories)
+ if err != nil {
+ return article, err
+ }
result, err := sq.Insert("articles").
- Columns("blog_id", "title", "url", "published_date", "discovered_date", "is_read").
- Values(article.BlogID, article.Title, article.URL, formatTimePtr(article.PublishedDate), formatTimePtr(article.DiscoveredDate), article.IsRead).
+ Columns("blog_id", "title", "url", "published_date", "discovered_date", "is_read", "categories").
+ Values(article.BlogID, article.Title, article.URL, formatTimePtr(article.PublishedDate), formatTimePtr(article.DiscoveredDate), article.IsRead, cats).
RunWith(db.conn).
ExecContext(ctx)
if err != nil {
@@ -244,8 +249,15 @@ func (db *Database) AddArticlesBulk(ctx context.Context, articles []model.Articl
}
insert := sq.Insert("articles").
- Columns("blog_id", "title", "url", "published_date", "discovered_date", "is_read")
+ Columns("blog_id", "title", "url", "published_date", "discovered_date", "is_read", "categories")
for _, article := range articles {
+ cats, err := categoriesToJSON(article.Categories)
+ if err != nil {
+ if rerr := tx.Rollback(); rerr != nil {
+ fmt.Fprintf(os.Stderr, "rollback: %v\n", rerr)
+ }
+ return 0, err
+ }
insert = insert.Values(
article.BlogID,
article.Title,
@@ -253,6 +265,7 @@ func (db *Database) AddArticlesBulk(ctx context.Context, articles []model.Articl
formatTimePtr(article.PublishedDate),
formatTimePtr(article.DiscoveredDate),
article.IsRead,
+ cats,
)
}
@@ -271,7 +284,7 @@ func (db *Database) AddArticlesBulk(ctx context.Context, articles []model.Articl
}
func (db *Database) GetArticle(ctx context.Context, id int64) (*model.Article, error) {
- row := sq.Select("id", "blog_id", "title", "url", "published_date", "discovered_date", "is_read").
+ row := sq.Select("id", "blog_id", "title", "url", "published_date", "discovered_date", "is_read", "categories").
From("articles").
Where(sq.Eq{"id": id}).
RunWith(db.conn).
@@ -280,7 +293,7 @@ func (db *Database) GetArticle(ctx context.Context, id int64) (*model.Article, e
}
func (db *Database) GetArticleByURL(ctx context.Context, url string) (*model.Article, error) {
- row := sq.Select("id", "blog_id", "title", "url", "published_date", "discovered_date", "is_read").
+ row := sq.Select("id", "blog_id", "title", "url", "published_date", "discovered_date", "is_read", "categories").
From("articles").
Where(sq.Eq{"url": url}).
RunWith(db.conn).
@@ -347,8 +360,8 @@ func (db *Database) GetExistingArticleURLs(ctx context.Context, urls []string) (
return result, nil
}
-func (db *Database) ListArticles(ctx context.Context, unreadOnly bool, blogID *int64) ([]model.Article, error) {
- query := sq.Select("id", "blog_id", "title", "url", "published_date", "discovered_date", "is_read").
+func (db *Database) ListArticles(ctx context.Context, unreadOnly bool, blogID *int64, category *string) ([]model.Article, error) {
+ query := sq.Select("id", "blog_id", "title", "url", "published_date", "discovered_date", "is_read", "categories").
From("articles").
OrderBy("discovered_date DESC")
@@ -358,6 +371,11 @@ func (db *Database) ListArticles(ctx context.Context, unreadOnly bool, blogID *i
if blogID != nil {
query = query.Where(sq.Eq{"blog_id": *blogID})
}
+ if category != nil && *category != "" {
+ // Categories are stored as a JSON string array. Use json_each()
+ // for exact element matching.
+ query = query.Where("EXISTS (SELECT 1 FROM json_each(categories) WHERE LOWER(json_each.value) = LOWER(?))", *category)
+ }
rows, err := query.RunWith(db.conn).QueryContext(ctx)
if err != nil {
@@ -456,20 +474,27 @@ func scanArticle(scanner interface{ Scan(dest ...any) error }) (*model.Article,
publishedDate sql.NullString
discovered sql.NullString
isRead bool
+ categories sql.NullString
)
- if err := scanner.Scan(&id, &blogID, &title, &url, &publishedDate, &discovered, &isRead); err != nil {
+ if err := scanner.Scan(&id, &blogID, &title, &url, &publishedDate, &discovered, &isRead, &categories); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return nil, err
}
+ cats, err := categoriesFromJSON(categories)
+ if err != nil {
+ return nil, err
+ }
+
article := &model.Article{
- ID: id,
- BlogID: blogID,
- Title: title,
- URL: url,
- IsRead: isRead,
+ ID: id,
+ BlogID: blogID,
+ Title: title,
+ URL: url,
+ IsRead: isRead,
+ Categories: cats,
}
if publishedDate.Valid {
if parsed, err := parseTime(publishedDate.String); err == nil {
@@ -512,3 +537,26 @@ func nullIfEmpty(value string) *string {
}
return &value
}
+
+func categoriesToJSON(categories []string) (*string, error) {
+ if len(categories) == 0 {
+ return nil, nil
+ }
+ b, err := json.Marshal(categories)
+ if err != nil {
+ return nil, fmt.Errorf("marshal categories: %w", err)
+ }
+ s := string(b)
+ return &s, nil
+}
+
+func categoriesFromJSON(s sql.NullString) ([]string, error) {
+ if !s.Valid || s.String == "" {
+ return nil, nil
+ }
+ var cats []string
+ if err := json.Unmarshal([]byte(s.String), &cats); err != nil {
+ return nil, fmt.Errorf("unmarshal categories: %w", err)
+ }
+ return cats, nil
+}
diff --git a/internal/storage/database_test.go b/internal/storage/database_test.go
index 29f03e4..531da09 100644
--- a/internal/storage/database_test.go
+++ b/internal/storage/database_test.go
@@ -39,7 +39,7 @@ func TestDatabaseCreatesFileAndCRUD(t *testing.T) {
require.NoError(t, err, "add articles bulk")
require.Equal(t, 2, count)
- list, err := db.ListArticles(ctx, false, nil)
+ list, err := db.ListArticles(ctx, false, nil, nil)
require.NoError(t, err, "list articles")
require.Len(t, list, 2)
@@ -190,17 +190,17 @@ func TestListArticlesFiltersAndOrdering(t *testing.T) {
_, err = db.MarkArticleRead(ctx, first.ID)
require.NoError(t, err, "mark read")
- all, err := db.ListArticles(ctx, false, nil)
+ all, err := db.ListArticles(ctx, false, nil, nil)
require.NoError(t, err, "list articles")
require.Len(t, all, 3)
require.Equal(t, second.ID, all[0].ID, "expected newest article first")
- unread, err := db.ListArticles(ctx, true, nil)
+ unread, err := db.ListArticles(ctx, true, nil, nil)
require.NoError(t, err, "list unread")
require.Len(t, unread, 2)
blogID := blogB.ID
- filtered, err := db.ListArticles(ctx, false, &blogID)
+ filtered, err := db.ListArticles(ctx, false, &blogID, nil)
require.NoError(t, err, "list by blog")
require.Len(t, filtered, 1)
require.Equal(t, blogB.ID, filtered[0].BlogID)
@@ -231,7 +231,7 @@ func TestBulkInsertDuplicateRollbackAndEmpty(t *testing.T) {
_, err = db.AddArticlesBulk(ctx, dupArticles)
require.Error(t, err, "expected bulk insert to fail on duplicate url")
- articles, err := db.ListArticles(ctx, false, nil)
+ articles, err := db.ListArticles(ctx, false, nil, nil)
require.NoError(t, err, "list articles")
require.Len(t, articles, 1, "expected rollback on duplicate")
}
@@ -269,3 +269,156 @@ func TestLookupHelpers(t *testing.T) {
require.NoError(t, err)
require.False(t, exists)
}
+
+func TestArticleCategoriesRoundTrip(t *testing.T) {
+ ctx := context.Background()
+ db := openTestDB(t)
+ defer func() { require.NoError(t, db.Close()) }()
+
+ blog, err := db.AddBlog(ctx, model.Blog{Name: "Test", URL: "https://example.com"})
+ require.NoError(t, err, "add blog")
+
+ // Article with categories
+ article, err := db.AddArticle(ctx, model.Article{
+ BlogID: blog.ID,
+ Title: "Categorized",
+ URL: "https://example.com/cat",
+ Categories: []string{"Go", "Programming"},
+ })
+ require.NoError(t, err, "add article with categories")
+
+ fetched, err := db.GetArticle(ctx, article.ID)
+ require.NoError(t, err, "get article")
+ require.NotNil(t, fetched)
+ require.Equal(t, []string{"Go", "Programming"}, fetched.Categories)
+
+ // Article without categories
+ articleNoCat, err := db.AddArticle(ctx, model.Article{
+ BlogID: blog.ID,
+ Title: "No Category",
+ URL: "https://example.com/nocat",
+ })
+ require.NoError(t, err, "add article without categories")
+
+ fetchedNoCat, err := db.GetArticle(ctx, articleNoCat.ID)
+ require.NoError(t, err, "get article")
+ require.NotNil(t, fetchedNoCat)
+ require.Nil(t, fetchedNoCat.Categories)
+}
+
+func TestListArticlesFilterByCategory(t *testing.T) {
+ ctx := context.Background()
+ db := openTestDB(t)
+ defer func() { require.NoError(t, db.Close()) }()
+
+ blog, err := db.AddBlog(ctx, model.Blog{Name: "Test", URL: "https://example.com"})
+ require.NoError(t, err, "add blog")
+
+ t1 := time.Date(2024, 1, 1, 10, 0, 0, 0, time.UTC)
+ t2 := time.Date(2024, 1, 1, 11, 0, 0, 0, time.UTC)
+ t3 := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
+
+ _, err = db.AddArticle(ctx, model.Article{
+ BlogID: blog.ID,
+ Title: "Go Article",
+ URL: "https://example.com/go",
+ DiscoveredDate: &t1,
+ Categories: []string{"Go", "Programming"},
+ })
+ require.NoError(t, err, "add go article")
+
+ _, err = db.AddArticle(ctx, model.Article{
+ BlogID: blog.ID,
+ Title: "Rust Article",
+ URL: "https://example.com/rust",
+ DiscoveredDate: &t2,
+ Categories: []string{"Rust", "Programming"},
+ })
+ require.NoError(t, err, "add rust article")
+
+ _, err = db.AddArticle(ctx, model.Article{
+ BlogID: blog.ID,
+ Title: "No Category",
+ URL: "https://example.com/nocat",
+ DiscoveredDate: &t3,
+ })
+ require.NoError(t, err, "add no-cat article")
+
+ // Filter by "Go" - should return only the Go article
+ cat := "Go"
+ goArticles, err := db.ListArticles(ctx, false, nil, &cat)
+ require.NoError(t, err, "list by category Go")
+ require.Len(t, goArticles, 1)
+ require.Equal(t, "Go Article", goArticles[0].Title)
+
+ // Filter by "Programming" - should return both categorized articles
+ cat = "Programming"
+ progArticles, err := db.ListArticles(ctx, false, nil, &cat)
+ require.NoError(t, err, "list by category Programming")
+ require.Len(t, progArticles, 2)
+
+ // No filter - should return all 3
+ all, err := db.ListArticles(ctx, false, nil, nil)
+ require.NoError(t, err, "list all")
+ require.Len(t, all, 3)
+
+ // Case-insensitive match - "go" should match "Go"
+ cat = "go"
+ goLower, err := db.ListArticles(ctx, false, nil, &cat)
+ require.NoError(t, err, "list by category go (lowercase)")
+ require.Len(t, goLower, 1)
+ require.Equal(t, "Go Article", goLower[0].Title)
+
+ // Case-insensitive match - "PROGRAMMING" should match "Programming"
+ cat = "PROGRAMMING"
+ progUpper, err := db.ListArticles(ctx, false, nil, &cat)
+ require.NoError(t, err, "list by category PROGRAMMING (uppercase)")
+ require.Len(t, progUpper, 2)
+
+ // Empty string category should return all
+ empty := ""
+ allEmpty, err := db.ListArticles(ctx, false, nil, &empty)
+ require.NoError(t, err, "list with empty category")
+ require.Len(t, allEmpty, 3)
+}
+
+func TestBulkInsertWithCategories(t *testing.T) {
+ ctx := context.Background()
+ db := openTestDB(t)
+ defer func() { require.NoError(t, db.Close()) }()
+
+ blog, err := db.AddBlog(ctx, model.Blog{Name: "Test", URL: "https://example.com"})
+ require.NoError(t, err, "add blog")
+
+ articles := []model.Article{
+ {BlogID: blog.ID, Title: "One", URL: "https://example.com/1", Categories: []string{"AI", "ML"}},
+ {BlogID: blog.ID, Title: "Two", URL: "https://example.com/2"},
+ }
+ count, err := db.AddArticlesBulk(ctx, articles)
+ require.NoError(t, err, "bulk insert")
+ require.Equal(t, 2, count)
+
+ list, err := db.ListArticles(ctx, false, nil, nil)
+ require.NoError(t, err, "list articles")
+ require.Len(t, list, 2)
+
+ // Find the one with categories (order is by discovered_date DESC, both nil so order may vary)
+ var withCat *model.Article
+ for i := range list {
+ if list[i].Title == "One" {
+ withCat = &list[i]
+ break
+ }
+ }
+ require.NotNil(t, withCat, "expected article with categories")
+ require.Equal(t, []string{"AI", "ML"}, withCat.Categories)
+}
+
+func openTestDB(t *testing.T) *Database {
+ t.Helper()
+ ctx := context.Background()
+ path := filepath.Join(t.TempDir(), "blogwatcher-cli.db")
+ db, err := OpenDatabase(ctx, path)
+ require.NoError(t, err, "open database")
+ return db
+}
diff --git a/internal/storage/migrations/000002_add_categories.down.sql b/internal/storage/migrations/000002_add_categories.down.sql
new file mode 100644
index 0000000..0316e9b
--- /dev/null
+++ b/internal/storage/migrations/000002_add_categories.down.sql
@@ -0,0 +1,18 @@
+-- SQLite does not support DROP COLUMN prior to 3.35.0.
+-- Recreate the table without the categories column.
+CREATE TABLE articles_backup (
+ id INTEGER PRIMARY KEY,
+ blog_id INTEGER NOT NULL,
+ title TEXT NOT NULL,
+ url TEXT NOT NULL UNIQUE,
+ published_date TIMESTAMP,
+ discovered_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ is_read BOOLEAN DEFAULT FALSE,
+ FOREIGN KEY (blog_id) REFERENCES blogs(id)
+);
+
+INSERT INTO articles_backup SELECT id, blog_id, title, url, published_date, discovered_date, is_read FROM articles;
+
+DROP TABLE articles;
+
+ALTER TABLE articles_backup RENAME TO articles;
diff --git a/internal/storage/migrations/000002_add_categories.up.sql b/internal/storage/migrations/000002_add_categories.up.sql
new file mode 100644
index 0000000..d3f8dbc
--- /dev/null
+++ b/internal/storage/migrations/000002_add_categories.up.sql
@@ -0,0 +1 @@
+ALTER TABLE articles ADD COLUMN categories TEXT;