diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index 455bbc9..827d855 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -410,22 +410,38 @@ func TestSSRFProtection(t *testing.T) { baseURL := startTestServer(t) dbPath := filepath.Join(t.TempDir(), "test.db") + // Build a clean env without BLOGWATCHER_UNSAFE_CLIENT to ensure the + // safe client is actually exercised, even if the user's shell has it set. + cleanEnv := filterEnv(os.Environ(), "BLOGWATCHER_UNSAFE_CLIENT") + cleanEnv = append(cleanEnv, "NO_COLOR=1") + // Add a blog pointing to the loopback test server WITHOUT --unsafe-client. // The add command doesn't fetch, so it should succeed. cmd := exec.CommandContext(context.Background(), binaryPath, "--db", dbPath, "add", "test-blog", baseURL+"/go/", "--feed-url", baseURL+"/go/feed.atom") - cmd.Env = append(os.Environ(), "NO_COLOR=1") + cmd.Env = cleanEnv out, err := cmd.CombinedOutput() require.NoError(t, err, "add should succeed: %s", string(out)) // Scan WITHOUT --unsafe-client — the safe client should block loopback and fail. cmd = exec.CommandContext(context.Background(), binaryPath, "--db", dbPath, "scan") - cmd.Env = append(os.Environ(), "NO_COLOR=1") + cmd.Env = cleanEnv out, err = cmd.CombinedOutput() require.Error(t, err, "scan should fail when SSRF protection blocks loopback") - require.Contains(t, string(out), "is not authorized", "expected SSRF error message") + require.Contains(t, string(out), "failed to fetch feed:", "expected our wrapped error message") +} + +func filterEnv(env []string, key string) []string { + prefix := key + "=" + filtered := make([]string, 0, len(env)) + for _, e := range env { + if !strings.HasPrefix(e, prefix) { + filtered = append(filtered, e) + } + } + return filtered } func extractFirstID(t *testing.T, output string) string { diff --git a/internal/rss/rss.go b/internal/rss/rss.go index fefa98f..925993f 100644 --- a/internal/rss/rss.go +++ b/internal/rss/rss.go @@ -85,11 +85,11 @@ func (f *Fetcher) ParseFeed(ctx context.Context, feedURL string) ([]FeedArticle, func (f *Fetcher) DiscoverFeedURL(ctx context.Context, blogURL string) (string, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, blogURL, nil) if err != nil { - return "", nil + return "", fmt.Errorf("discover feed: %w", err) } response, err := f.client.Do(req) if err != nil { - return "", nil + return "", fmt.Errorf("discover feed: %w", err) } defer func() { if err := response.Body.Close(); err != nil { @@ -97,6 +97,7 @@ func (f *Fetcher) DiscoverFeedURL(ctx context.Context, blogURL string) (string, } }() if response.StatusCode < 200 || response.StatusCode >= 300 { + // Not-found / bad status is not an error — just means no feed at this URL. return "", nil } @@ -112,12 +113,12 @@ func (f *Fetcher) DiscoverFeedURL(ctx context.Context, blogURL string) (string, base, err := url.Parse(blogURL) if err != nil { - return "", nil + return "", fmt.Errorf("discover feed: %w", err) } doc, err := goquery.NewDocumentFromReader(response.Body) if err != nil { - return "", nil + return "", fmt.Errorf("discover feed: parse HTML: %w", err) } feedTypes := []string{