From 75c6119617b5fcef2b02e9af9ebadffd413797cf Mon Sep 17 00:00:00 2001 From: Julien Tant Date: Fri, 3 Apr 2026 14:58:52 -0700 Subject: [PATCH 1/2] Add SSRF-safe HTTP client with --unsafe-client escape hatch - Use DataDog go-secure-sdk for SSRF protection (blocks private IPs, loopback, link-local, cloud metadata endpoints) - Refactor rss.Fetcher and scraper.Scraper as structs holding *http.Client - Scanner accepts Fetcher + Scraper as dependencies (no global state) - Add --unsafe-client / BLOGWATCHER_UNSAFE_CLIENT flag for local dev - ScanBlog now returns error (hard fail, not soft error string) - E2e test verifies safe client blocks loopback requests - Tests use plain http.Client (no SSRF check needed for httptest) - Update AGENTS.md: ground claims in evidence, not guesses Co-Authored-By: Claude Opus 4.6 (1M context) --- AGENTS.md | 4 ++ e2e/e2e_test.go | 28 ++++++++-- go.mod | 2 +- go.sum | 12 +++-- internal/cli/commands.go | 24 ++++++--- internal/cli/root.go | 1 + internal/rss/rss.go | 27 ++++++---- internal/rss/rss_test.go | 8 ++- internal/scanner/scanner.go | 89 ++++++++++++++++++++------------ internal/scanner/scanner_test.go | 53 ++++++++++++++----- internal/scraper/scraper.go | 15 ++++-- internal/scraper/scraper_test.go | 6 ++- 12 files changed, 193 insertions(+), 76 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index d6c29cf..855e854 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,3 +1,7 @@ +### Before speaking + +Never guess. Ground every claim in actual files, docs, or command output. If you're unsure about a dependency's size, run `go get` and check. If you're unsure about an API, read the source or docs. If you're unsure about behavior, write a test. State what you verified and how — don't speculate. + ### While developing #### Tests, tests, tests! diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index 29caf9b..0dbd930 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -57,11 +57,11 @@ func (c *cliOpts) run(t *testing.T, args []string, opts map[string]string) (stdo var cmdArgs []string var extraEnv []string - // DB path goes before the subcommand (persistent flag). + // Persistent flags / env vars. if c.mode == "flags" { - cmdArgs = append(cmdArgs, "--db", c.dbPath) + cmdArgs = append(cmdArgs, "--db", c.dbPath, "--unsafe-client") } else { - extraEnv = append(extraEnv, "BLOGWATCHER_DB="+c.dbPath) + extraEnv = append(extraEnv, "BLOGWATCHER_DB="+c.dbPath, "BLOGWATCHER_UNSAFE_CLIENT=true") } // Positional args first (includes the subcommand name). @@ -402,6 +402,28 @@ func TestE2E(t *testing.T) { } } +func TestSSRFProtection(t *testing.T) { + baseURL := startTestServer(t) + dbPath := filepath.Join(t.TempDir(), "test.db") + + // Add a blog pointing to the loopback test server WITHOUT --unsafe-client. + // The add command doesn't fetch, so it should succeed. + cmd := exec.CommandContext(context.Background(), binaryPath, + "--db", dbPath, "add", "test-blog", baseURL+"/go/", + "--feed-url", baseURL+"/go/feed.atom") + cmd.Env = append(os.Environ(), "NO_COLOR=1") + out, err := cmd.CombinedOutput() + require.NoError(t, err, "add should succeed: %s", string(out)) + + // Scan WITHOUT --unsafe-client — the safe client should block loopback and fail. + cmd = exec.CommandContext(context.Background(), binaryPath, + "--db", dbPath, "scan") + cmd.Env = append(os.Environ(), "NO_COLOR=1") + out, err = cmd.CombinedOutput() + require.Error(t, err, "scan should fail when SSRF protection blocks loopback") + require.Contains(t, string(out), "is not authorized", "expected SSRF error message") +} + func extractFirstID(t *testing.T, output string) string { t.Helper() re := regexp.MustCompile(`\[(\d+)\]`) diff --git a/go.mod b/go.mod index c2e1ea3..8b8cdd8 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/JulienTant/blogwatcher-cli go 1.26.1 require ( + github.com/DataDog/go-secure-sdk v0.0.7 github.com/Masterminds/squirrel v1.5.4 github.com/PuerkitoBio/goquery v1.12.0 github.com/fatih/color v1.19.0 @@ -44,7 +45,6 @@ require ( golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.35.0 // indirect golang.org/x/tools v0.43.0 // indirect - gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect modernc.org/libc v1.70.0 // indirect modernc.org/mathutil v1.7.1 // indirect diff --git a/go.sum b/go.sum index d244cad..83d1bb7 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/DataDog/go-secure-sdk v0.0.7 h1:FJYJPXXZmxOC0RHWzhsZ3PKtUaF+IgC6fDc/PbEoH6c= +github.com/DataDog/go-secure-sdk v0.0.7/go.mod h1:/fIdMM7LMp7KGouxN9Cnv3UP0NdfP+XZ5R8qyU3D8qk= github.com/Masterminds/squirrel v1.5.4 h1:uUcX/aBc8O7Fg9kaISIUsHXdKuqehiXAMQTYX8afzqM= github.com/Masterminds/squirrel v1.5.4/go.mod h1:NNaOrjSoIDfDA40n7sr2tPNZRfjzjA400rg+riTZj10= github.com/PuerkitoBio/goquery v1.12.0 h1:pAcL4g3WRXekcB9AU/y1mbKez2dbY2AajVhtkO8RIBo= @@ -21,6 +23,8 @@ github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPE github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA= github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -66,8 +70,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4= github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI= @@ -169,8 +173,8 @@ golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s= golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= diff --git a/internal/cli/commands.go b/internal/cli/commands.go index 3e4a504..e356cb6 100644 --- a/internal/cli/commands.go +++ b/internal/cli/commands.go @@ -3,21 +3,34 @@ package cli import ( "bufio" "fmt" + "net/http" "os" "strconv" "strings" "time" + "github.com/DataDog/go-secure-sdk/net/httpclient" "github.com/fatih/color" "github.com/spf13/cobra" "github.com/spf13/viper" "github.com/JulienTant/blogwatcher-cli/internal/controller" "github.com/JulienTant/blogwatcher-cli/internal/model" + "github.com/JulienTant/blogwatcher-cli/internal/rss" "github.com/JulienTant/blogwatcher-cli/internal/scanner" + "github.com/JulienTant/blogwatcher-cli/internal/scraper" "github.com/JulienTant/blogwatcher-cli/internal/storage" ) +const httpTimeout = 30 * time.Second + +func newHTTPClient() *http.Client { + if viper.GetBool("unsafe-client") { + return httpclient.UnSafe(httpclient.WithTimeout(httpTimeout)) + } + return httpclient.Safe(httpclient.WithTimeout(httpTimeout)) +} + func newAddCommand() *cobra.Command { cmd := &cobra.Command{ Use: "add ", @@ -148,8 +161,11 @@ func newScanCommand() *cobra.Command { } }() + client := newHTTPClient() + sc := scanner.NewScanner(rss.NewFetcher(client), scraper.NewScraper(client)) + if len(args) == 1 { - result, err := scanner.ScanBlogByName(cmd.Context(), db, args[0]) + result, err := sc.ScanBlogByName(cmd.Context(), db, args[0]) if err != nil { return err } @@ -173,7 +189,7 @@ func newScanCommand() *cobra.Command { if !silent { cprintf([]color.Attribute{color.FgCyan}, "Scanning %d blog(s)...\n\n", len(blogs)) } - results, err := scanner.ScanAllBlogs(cmd.Context(), db, workers) + results, err := sc.ScanAllBlogs(cmd.Context(), db, workers) if err != nil { return err } @@ -385,10 +401,6 @@ func printScanResult(result scanner.ScanResult) { statusColor = []color.Attribute{color.FgGreen} } cprintf([]color.Attribute{color.FgWhite, color.Bold}, " %s\n", result.BlogName) - if result.Error != "" { - cprintfErr(color.FgRed, " Error: %s\n", result.Error) - return - } if result.Source == "none" { cprintln([]color.Attribute{color.FgYellow}, " No feed or scraper configured") return diff --git a/internal/cli/root.go b/internal/cli/root.go index f822f7e..161840a 100644 --- a/internal/cli/root.go +++ b/internal/cli/root.go @@ -24,6 +24,7 @@ func NewRootCommand() *cobra.Command { rootCmd.SetVersionTemplate("{{.Version}}\n") rootCmd.PersistentFlags().String("db", "", "Path to the SQLite database file (default: ~/.blogwatcher-cli/blogwatcher-cli.db)") + rootCmd.PersistentFlags().Bool("unsafe-client", false, "Disable SSRF protection (allow requests to private/loopback IPs)") rootCmd.AddCommand(newAddCommand()) rootCmd.AddCommand(newRemoveCommand()) diff --git a/internal/rss/rss.go b/internal/rss/rss.go index 1107eb9..87021e8 100644 --- a/internal/rss/rss.go +++ b/internal/rss/rss.go @@ -28,13 +28,22 @@ func (e FeedParseError) Error() string { return e.Message } -func ParseFeed(ctx context.Context, feedURL string, timeout time.Duration) ([]FeedArticle, error) { - client := &http.Client{Timeout: timeout} +// Fetcher fetches and parses RSS/Atom feeds. +type Fetcher struct { + client *http.Client +} + +// NewFetcher creates a Fetcher with the given HTTP client. +func NewFetcher(client *http.Client) *Fetcher { + return &Fetcher{client: client} +} + +func (f *Fetcher) ParseFeed(ctx context.Context, feedURL string) ([]FeedArticle, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, feedURL, nil) if err != nil { return nil, FeedParseError{Message: fmt.Sprintf("failed to create request: %v", err)} } - response, err := client.Do(req) + response, err := f.client.Do(req) if err != nil { return nil, FeedParseError{Message: fmt.Sprintf("failed to fetch feed: %v", err)} } @@ -70,13 +79,12 @@ func ParseFeed(ctx context.Context, feedURL string, timeout time.Duration) ([]Fe return articles, nil } -func DiscoverFeedURL(ctx context.Context, blogURL string, timeout time.Duration) (string, error) { - client := &http.Client{Timeout: timeout} +func (f *Fetcher) DiscoverFeedURL(ctx context.Context, blogURL string) (string, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, blogURL, nil) if err != nil { return "", nil } - response, err := client.Do(req) + response, err := f.client.Do(req) if err != nil { return "", nil } @@ -138,7 +146,7 @@ func DiscoverFeedURL(ctx context.Context, blogURL string, timeout time.Duration) if resolved == "" { continue } - ok, err := isValidFeed(ctx, resolved, timeout) + ok, err := f.isValidFeed(ctx, resolved) if err == nil && ok { return resolved, nil } @@ -147,13 +155,12 @@ func DiscoverFeedURL(ctx context.Context, blogURL string, timeout time.Duration) return "", nil } -func isValidFeed(ctx context.Context, feedURL string, timeout time.Duration) (bool, error) { - client := &http.Client{Timeout: timeout} +func (f *Fetcher) isValidFeed(ctx context.Context, feedURL string) (bool, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, feedURL, nil) if err != nil { return false, err } - response, err := client.Do(req) + response, err := f.client.Do(req) if err != nil { return false, err } diff --git a/internal/rss/rss_test.go b/internal/rss/rss_test.go index 7fefca9..6dcd987 100644 --- a/internal/rss/rss_test.go +++ b/internal/rss/rss_test.go @@ -26,6 +26,10 @@ const sampleFeed = ` ` +func newTestFetcher() *Fetcher { + return NewFetcher(&http.Client{Timeout: 2 * time.Second}) +} + func TestParseFeed(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if _, writeErr := w.Write([]byte(sampleFeed)); writeErr != nil { @@ -35,7 +39,7 @@ func TestParseFeed(t *testing.T) { })) defer server.Close() - articles, err := ParseFeed(context.Background(), server.URL, 2*time.Second) + articles, err := newTestFetcher().ParseFeed(context.Background(), server.URL) require.NoError(t, err, "parse feed") require.Len(t, articles, 2) require.NotNil(t, articles[0].PublishedDate) @@ -58,7 +62,7 @@ func TestDiscoverFeedURL(t *testing.T) { server := httptest.NewServer(mux) defer server.Close() - feedURL, err := DiscoverFeedURL(context.Background(), server.URL, 2*time.Second) + feedURL, err := newTestFetcher().DiscoverFeedURL(context.Background(), server.URL) require.NoError(t, err, "discover feed") require.NotEmpty(t, feedURL, "expected feed url") } diff --git a/internal/scanner/scanner.go b/internal/scanner/scanner.go index 7997588..5271dc5 100644 --- a/internal/scanner/scanner.go +++ b/internal/scanner/scanner.go @@ -17,50 +17,65 @@ type ScanResult struct { NewArticles int TotalFound int Source string - Error string } -func ScanBlog(ctx context.Context, db *storage.Database, blog model.Blog) ScanResult { +// Scanner orchestrates blog scanning using a Fetcher and Scraper. +type Scanner struct { + fetcher *rss.Fetcher + scraper *scraper.Scraper +} + +// NewScanner creates a Scanner with the given fetcher and scraper. +func NewScanner(fetcher *rss.Fetcher, scraper *scraper.Scraper) *Scanner { + return &Scanner{fetcher: fetcher, scraper: scraper} +} + +func (s *Scanner) ScanBlog(ctx context.Context, db *storage.Database, blog model.Blog) (ScanResult, error) { var ( articles []model.Article source = "none" - errText string ) feedURL := blog.FeedURL if feedURL == "" { - if discovered, err := rss.DiscoverFeedURL(ctx, blog.URL, 30*time.Second); err == nil && discovered != "" { + discovered, err := s.fetcher.DiscoverFeedURL(ctx, blog.URL) + if err != nil { + return ScanResult{BlogName: blog.Name}, err + } + if discovered != "" { feedURL = discovered blog.FeedURL = discovered if err := db.UpdateBlog(ctx, blog); err != nil { - fmt.Fprintf(os.Stderr, "update blog: %v\n", err) + return ScanResult{BlogName: blog.Name}, err } } } if feedURL != "" { - feedArticles, err := rss.ParseFeed(ctx, feedURL, 30*time.Second) + feedArticles, err := s.fetcher.ParseFeed(ctx, feedURL) if err != nil { - errText = err.Error() + // If there's a scraper fallback, try it before giving up. + if blog.ScrapeSelector == "" { + return ScanResult{BlogName: blog.Name}, err + } + // Try scraper as fallback. + scrapedArticles, scrapeErr := s.scraper.ScrapeBlog(ctx, blog.URL, blog.ScrapeSelector) + if scrapeErr != nil { + return ScanResult{BlogName: blog.Name}, fmt.Errorf("RSS: %w; Scraper: %w", err, scrapeErr) + } + articles = convertScrapedArticles(blog.ID, scrapedArticles) + source = "scraper" } else { articles = convertFeedArticles(blog.ID, feedArticles) source = "rss" } - } - - if len(articles) == 0 && blog.ScrapeSelector != "" { - scrapedArticles, err := scraper.ScrapeBlog(ctx, blog.URL, blog.ScrapeSelector, 30*time.Second) + } else if blog.ScrapeSelector != "" { + scrapedArticles, err := s.scraper.ScrapeBlog(ctx, blog.URL, blog.ScrapeSelector) if err != nil { - if errText != "" { - errText = fmt.Sprintf("RSS: %s; Scraper: %s", errText, err.Error()) - } else { - errText = err.Error() - } - } else { - articles = convertScrapedArticles(blog.ID, scrapedArticles) - source = "scraper" - errText = "" + return ScanResult{BlogName: blog.Name}, err } + articles = convertScrapedArticles(blog.ID, scrapedArticles) + source = "scraper" } seenURLs := make(map[string]struct{}) @@ -80,7 +95,7 @@ func ScanBlog(ctx context.Context, db *storage.Database, blog model.Blog) ScanRe existing, err := db.GetExistingArticleURLs(ctx, urlList) if err != nil { - errText = err.Error() + return ScanResult{BlogName: blog.Name}, err } discoveredAt := time.Now() @@ -97,14 +112,13 @@ func ScanBlog(ctx context.Context, db *storage.Database, blog model.Blog) ScanRe if len(newArticles) > 0 { count, err := db.AddArticlesBulk(ctx, newArticles) if err != nil { - errText = err.Error() - } else { - newCount = count + return ScanResult{BlogName: blog.Name}, err } + newCount = count } if err := db.UpdateBlogLastScanned(ctx, blog.ID, time.Now()); err != nil { - fmt.Fprintf(os.Stderr, "update last scanned: %v\n", err) + return ScanResult{BlogName: blog.Name}, err } return ScanResult{ @@ -112,11 +126,10 @@ func ScanBlog(ctx context.Context, db *storage.Database, blog model.Blog) ScanRe NewArticles: newCount, TotalFound: len(seenURLs), Source: source, - Error: errText, - } + }, nil } -func ScanAllBlogs(ctx context.Context, db *storage.Database, workers int) ([]ScanResult, error) { +func (s *Scanner) ScanAllBlogs(ctx context.Context, db *storage.Database, workers int) ([]ScanResult, error) { blogs, err := db.ListBlogs(ctx) if err != nil { return nil, err @@ -124,7 +137,11 @@ func ScanAllBlogs(ctx context.Context, db *storage.Database, workers int) ([]Sca if workers <= 1 { results := make([]ScanResult, 0, len(blogs)) for _, blog := range blogs { - results = append(results, ScanBlog(ctx, db, blog)) + result, err := s.ScanBlog(ctx, db, blog) + if err != nil { + return nil, fmt.Errorf("scan %s: %w", blog.Name, err) + } + results = append(results, result) } return results, nil } @@ -150,7 +167,12 @@ func ScanAllBlogs(ctx context.Context, db *storage.Database, workers int) ([]Sca } }() for item := range jobs { - results[item.Index] = ScanBlog(ctx, workerDB, item.Blog) + result, err := s.ScanBlog(ctx, workerDB, item.Blog) + if err != nil { + errs <- fmt.Errorf("scan %s: %w", item.Blog.Name, err) + return + } + results[item.Index] = result } errs <- nil }() @@ -170,7 +192,7 @@ func ScanAllBlogs(ctx context.Context, db *storage.Database, workers int) ([]Sca return results, nil } -func ScanBlogByName(ctx context.Context, db *storage.Database, name string) (*ScanResult, error) { +func (s *Scanner) ScanBlogByName(ctx context.Context, db *storage.Database, name string) (*ScanResult, error) { blog, err := db.GetBlogByName(ctx, name) if err != nil { return nil, err @@ -178,7 +200,10 @@ func ScanBlogByName(ctx context.Context, db *storage.Database, name string) (*Sc if blog == nil { return nil, nil } - result := ScanBlog(ctx, db, *blog) + result, err := s.ScanBlog(ctx, db, *blog) + if err != nil { + return nil, err + } return &result, nil } diff --git a/internal/scanner/scanner_test.go b/internal/scanner/scanner_test.go index 014bba7..dcf0f56 100644 --- a/internal/scanner/scanner_test.go +++ b/internal/scanner/scanner_test.go @@ -2,6 +2,7 @@ package scanner import ( "context" + "fmt" "net/http" "net/http/httptest" "path/filepath" @@ -9,6 +10,8 @@ import ( "time" "github.com/JulienTant/blogwatcher-cli/internal/model" + "github.com/JulienTant/blogwatcher-cli/internal/rss" + "github.com/JulienTant/blogwatcher-cli/internal/scraper" "github.com/JulienTant/blogwatcher-cli/internal/storage" "github.com/stretchr/testify/require" ) @@ -28,6 +31,11 @@ const sampleFeed = ` ` +func newTestScanner() *Scanner { + client := &http.Client{Timeout: 2 * time.Second} + return NewScanner(rss.NewFetcher(client), scraper.NewScraper(client)) +} + func TestScanBlogRSS(t *testing.T) { ctx := context.Background() server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -44,7 +52,8 @@ func TestScanBlogRSS(t *testing.T) { blog, err := db.AddBlog(ctx, model.Blog{Name: "Test", URL: "https://example.com", FeedURL: server.URL}) require.NoError(t, err, "add blog") - result := ScanBlog(ctx, db, blog) + result, scanErr := newTestScanner().ScanBlog(ctx, db, blog) + require.NoError(t, scanErr) require.Equal(t, 2, result.NewArticles) require.Equal(t, "rss", result.Source) @@ -81,31 +90,46 @@ func TestScanBlogScraperFallback(t *testing.T) { blog, err := db.AddBlog(ctx, model.Blog{Name: "Test", URL: server.URL, FeedURL: server.URL + "/feed.xml", ScrapeSelector: "article h2 a"}) require.NoError(t, err, "add blog") - result := ScanBlog(ctx, db, blog) + result, scanErr := newTestScanner().ScanBlog(ctx, db, blog) + require.NoError(t, scanErr) require.Equal(t, "scraper", result.Source) require.Equal(t, 1, result.NewArticles) - require.Empty(t, result.Error) } func TestScanAllBlogsConcurrent(t *testing.T) { ctx := context.Background() - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if _, writeErr := w.Write([]byte(sampleFeed)); writeErr != nil { - http.Error(w, writeErr.Error(), http.StatusInternalServerError) - return - } - })) + + feedTemplate := ` +%s +Post 1https://%s.example.com/1 +Post 2https://%s.example.com/2 +` + + mux := http.NewServeMux() + for _, name := range []string{"a", "b"} { + feed := fmt.Sprintf(feedTemplate, name, name, name) + mux.HandleFunc("/"+name+"/feed", func(w http.ResponseWriter, r *http.Request) { + if _, writeErr := w.Write([]byte(feed)); writeErr != nil { + http.Error(w, writeErr.Error(), http.StatusInternalServerError) + } + }) + } + server := httptest.NewServer(mux) defer server.Close() db := openTestDB(t) defer func() { require.NoError(t, db.Close()) }() - for i, name := range []string{"TestA", "TestB"} { - _, err := db.AddBlog(ctx, model.Blog{Name: name, URL: "https://example.com/" + name, FeedURL: server.URL}) - require.NoError(t, err, "add blog %d", i) + for _, name := range []string{"a", "b"} { + _, err := db.AddBlog(ctx, model.Blog{ + Name: "Test-" + name, + URL: "https://" + name + ".example.com", + FeedURL: server.URL + "/" + name + "/feed", + }) + require.NoError(t, err, "add blog %s", name) } - results, err := ScanAllBlogs(ctx, db, 2) + results, err := newTestScanner().ScanAllBlogs(ctx, db, 2) require.NoError(t, err, "scan all blogs") require.Len(t, results, 2) } @@ -137,7 +161,8 @@ func TestScanBlogRespectsExistingArticles(t *testing.T) { _, err = db.AddArticle(ctx, model.Article{BlogID: blog.ID, Title: "First", URL: "https://example.com/1", DiscoveredDate: ptrTime(time.Now())}) require.NoError(t, err, "add article") - result := ScanBlog(ctx, db, blog) + result, scanErr := newTestScanner().ScanBlog(ctx, db, blog) + require.NoError(t, scanErr) require.Equal(t, 1, result.NewArticles) } diff --git a/internal/scraper/scraper.go b/internal/scraper/scraper.go index b56c1cb..8507bd6 100644 --- a/internal/scraper/scraper.go +++ b/internal/scraper/scraper.go @@ -27,13 +27,22 @@ func (e ScrapeError) Error() string { return e.Message } -func ScrapeBlog(ctx context.Context, blogURL string, selector string, timeout time.Duration) ([]ScrapedArticle, error) { - client := &http.Client{Timeout: timeout} +// Scraper scrapes HTML pages for article links. +type Scraper struct { + client *http.Client +} + +// NewScraper creates a Scraper with the given HTTP client. +func NewScraper(client *http.Client) *Scraper { + return &Scraper{client: client} +} + +func (s *Scraper) ScrapeBlog(ctx context.Context, blogURL string, selector string) ([]ScrapedArticle, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, blogURL, nil) if err != nil { return nil, ScrapeError{Message: fmt.Sprintf("failed to create request: %v", err)} } - response, err := client.Do(req) + response, err := s.client.Do(req) if err != nil { return nil, ScrapeError{Message: fmt.Sprintf("failed to fetch page: %v", err)} } diff --git a/internal/scraper/scraper_test.go b/internal/scraper/scraper_test.go index 6f288b4..52fe7e9 100644 --- a/internal/scraper/scraper_test.go +++ b/internal/scraper/scraper_test.go @@ -10,6 +10,10 @@ import ( "github.com/stretchr/testify/require" ) +func newTestScraper() *Scraper { + return NewScraper(&http.Client{Timeout: 2 * time.Second}) +} + func TestScrapeBlog(t *testing.T) { html := ` @@ -28,7 +32,7 @@ func TestScrapeBlog(t *testing.T) { })) defer server.Close() - articles, err := ScrapeBlog(context.Background(), server.URL, "article h2 a, .post", 2*time.Second) + articles, err := newTestScraper().ScrapeBlog(context.Background(), server.URL, "article h2 a, .post") require.NoError(t, err, "scrape blog") require.Len(t, articles, 2) require.NotEmpty(t, articles[0].URL) From 973c50c2e9177255e27ddf525a3741f481b5581b Mon Sep 17 00:00:00 2001 From: Julien Tant Date: Fri, 3 Apr 2026 15:16:47 -0700 Subject: [PATCH 2/2] Fix review findings: DiscoverFeedURL errors, scanner deadlock, e2e env leak - DiscoverFeedURL now propagates real errors (network, context cancel, parse) instead of swallowing them as ("", nil) - Replace scanner worker pool with errgroup to prevent deadlock when a worker returns early on error while sender is still pushing jobs - Filter BLOGWATCHER_UNSAFE_CLIENT from env in SSRF e2e test to ensure the safe client path is always exercised - Check our own wrapped error message in SSRF test instead of SDK string Co-Authored-By: Claude Opus 4.6 (1M context) --- e2e/e2e_test.go | 22 ++++++++++++++++--- go.mod | 1 + internal/rss/rss.go | 9 ++++---- internal/scanner/scanner.go | 42 +++++++++++++++++++++---------------- 4 files changed, 49 insertions(+), 25 deletions(-) diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index 0dbd930..b4f12ea 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -406,22 +406,38 @@ func TestSSRFProtection(t *testing.T) { baseURL := startTestServer(t) dbPath := filepath.Join(t.TempDir(), "test.db") + // Build a clean env without BLOGWATCHER_UNSAFE_CLIENT to ensure the + // safe client is actually exercised, even if the user's shell has it set. + cleanEnv := filterEnv(os.Environ(), "BLOGWATCHER_UNSAFE_CLIENT") + cleanEnv = append(cleanEnv, "NO_COLOR=1") + // Add a blog pointing to the loopback test server WITHOUT --unsafe-client. // The add command doesn't fetch, so it should succeed. cmd := exec.CommandContext(context.Background(), binaryPath, "--db", dbPath, "add", "test-blog", baseURL+"/go/", "--feed-url", baseURL+"/go/feed.atom") - cmd.Env = append(os.Environ(), "NO_COLOR=1") + cmd.Env = cleanEnv out, err := cmd.CombinedOutput() require.NoError(t, err, "add should succeed: %s", string(out)) // Scan WITHOUT --unsafe-client — the safe client should block loopback and fail. cmd = exec.CommandContext(context.Background(), binaryPath, "--db", dbPath, "scan") - cmd.Env = append(os.Environ(), "NO_COLOR=1") + cmd.Env = cleanEnv out, err = cmd.CombinedOutput() require.Error(t, err, "scan should fail when SSRF protection blocks loopback") - require.Contains(t, string(out), "is not authorized", "expected SSRF error message") + require.Contains(t, string(out), "failed to fetch feed:", "expected our wrapped error message") +} + +func filterEnv(env []string, key string) []string { + prefix := key + "=" + filtered := make([]string, 0, len(env)) + for _, e := range env { + if !strings.HasPrefix(e, prefix) { + filtered = append(filtered, e) + } + } + return filtered } func extractFirstID(t *testing.T, output string) string { diff --git a/go.mod b/go.mod index 8b8cdd8..6738bbf 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/spf13/cobra v1.10.2 github.com/spf13/viper v1.21.0 github.com/stretchr/testify v1.11.1 + golang.org/x/sync v0.20.0 modernc.org/sqlite v1.48.1 ) diff --git a/internal/rss/rss.go b/internal/rss/rss.go index 87021e8..b8d08a5 100644 --- a/internal/rss/rss.go +++ b/internal/rss/rss.go @@ -82,11 +82,11 @@ func (f *Fetcher) ParseFeed(ctx context.Context, feedURL string) ([]FeedArticle, func (f *Fetcher) DiscoverFeedURL(ctx context.Context, blogURL string) (string, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, blogURL, nil) if err != nil { - return "", nil + return "", fmt.Errorf("discover feed: %w", err) } response, err := f.client.Do(req) if err != nil { - return "", nil + return "", fmt.Errorf("discover feed: %w", err) } defer func() { if err := response.Body.Close(); err != nil { @@ -94,17 +94,18 @@ func (f *Fetcher) DiscoverFeedURL(ctx context.Context, blogURL string) (string, } }() if response.StatusCode < 200 || response.StatusCode >= 300 { + // Not-found / bad status is not an error — just means no feed at this URL. return "", nil } base, err := url.Parse(blogURL) if err != nil { - return "", nil + return "", fmt.Errorf("discover feed: %w", err) } doc, err := goquery.NewDocumentFromReader(response.Body) if err != nil { - return "", nil + return "", fmt.Errorf("discover feed: parse HTML: %w", err) } feedTypes := []string{ diff --git a/internal/scanner/scanner.go b/internal/scanner/scanner.go index 5271dc5..a5ee086 100644 --- a/internal/scanner/scanner.go +++ b/internal/scanner/scanner.go @@ -6,6 +6,8 @@ import ( "os" "time" + "golang.org/x/sync/errgroup" + "github.com/JulienTant/blogwatcher-cli/internal/model" "github.com/JulienTant/blogwatcher-cli/internal/rss" "github.com/JulienTant/blogwatcher-cli/internal/scraper" @@ -152,14 +154,14 @@ func (s *Scanner) ScanAllBlogs(ctx context.Context, db *storage.Database, worker } jobs := make(chan job) results := make([]ScanResult, len(blogs)) - errs := make(chan error, workers) + + g, gctx := errgroup.WithContext(ctx) for i := 0; i < workers; i++ { - go func() { - workerDB, err := storage.OpenDatabase(ctx, db.Path()) + g.Go(func() error { + workerDB, err := storage.OpenDatabase(gctx, db.Path()) if err != nil { - errs <- err - return + return err } defer func() { if err := workerDB.Close(); err != nil { @@ -167,26 +169,30 @@ func (s *Scanner) ScanAllBlogs(ctx context.Context, db *storage.Database, worker } }() for item := range jobs { - result, err := s.ScanBlog(ctx, workerDB, item.Blog) + result, err := s.ScanBlog(gctx, workerDB, item.Blog) if err != nil { - errs <- fmt.Errorf("scan %s: %w", item.Blog.Name, err) - return + return fmt.Errorf("scan %s: %w", item.Blog.Name, err) } results[item.Index] = result } - errs <- nil - }() - } - - for index, blog := range blogs { - jobs <- job{Index: index, Blog: blog} + return nil + }) } - close(jobs) - for i := 0; i < workers; i++ { - if err := <-errs; err != nil { - return nil, err + g.Go(func() error { + defer close(jobs) + for index, blog := range blogs { + select { + case jobs <- job{Index: index, Blog: blog}: + case <-gctx.Done(): + return gctx.Err() + } } + return nil + }) + + if err := g.Wait(); err != nil { + return nil, err } return results, nil