diff --git a/.golangci.yml b/.golangci.yml index 017f601..7962c77 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -458,6 +458,8 @@ linters: linters: [revive, staticcheck] - text: 'shadow: declaration of "err" shadows declaration' linters: ["govet"] + - text: 'shadow: declaration of "logger" shadows declaration' + linters: ["govet"] - path: '_test\.go' linters: - bodyclose diff --git a/cmd/sfptcd/main.go b/cmd/sfptcd/main.go index a75e0b9..b20e77b 100644 --- a/cmd/sfptcd/main.go +++ b/cmd/sfptcd/main.go @@ -3,42 +3,62 @@ package main import ( "context" "log/slog" + "net" "net/http" "os" + "strings" "time" "github.com/alecthomas/kong" "github.com/block/sfptc/internal/config" + "github.com/block/sfptc/internal/httputil" "github.com/block/sfptc/internal/logging" ) var cli struct { - Config *os.File `hcl:"-" help:"Configuration file path." placeholder:"PATH" required:""` + Config *os.File `hcl:"-" help:"Configuration file path." placeholder:"PATH" required:"" default:"sfptc.hcl"` Bind string `hcl:"bind" default:"127.0.0.1:8080" help:"Bind address for the server."` LoggingConfig logging.Config `embed:"" prefix:"log-"` } func main() { - kctx := kong.Parse(&cli) + kctx := kong.Parse(&cli, kong.DefaultEnvars("SFPTC")) ctx := context.Background() logger, ctx := logging.Configure(ctx, cli.LoggingConfig) mux := http.NewServeMux() - err := config.Load(ctx, cli.Config, mux) + err := config.Load(ctx, cli.Config, mux, parseEnvars()) kctx.FatalIfErrorf(err) logger.InfoContext(ctx, "Starting sfptcd", slog.String("bind", cli.Bind)) server := &http.Server{ Addr: cli.Bind, - Handler: mux, + Handler: httputil.LoggingMiddleware(mux), ReadTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second, ReadHeaderTimeout: 10 * time.Second, + BaseContext: func(net.Listener) context.Context { + return ctx + }, + ConnContext: func(ctx context.Context, c net.Conn) context.Context { + return logging.ContextWithLogger(ctx, logger.With("client", c.RemoteAddr().String())) + }, } + err = server.ListenAndServe() kctx.FatalIfErrorf(err) } + +func parseEnvars() map[string]string { + envars := map[string]string{} + for _, env := range os.Environ() { + if key, value, ok := strings.Cut(env, "="); ok { + envars[key] = value + } + } + return envars +} diff --git a/internal/cache/http.go b/internal/cache/http.go index 9d2e1bb..61ee00c 100644 --- a/internal/cache/http.go +++ b/internal/cache/http.go @@ -1,7 +1,6 @@ package cache import ( - "fmt" "io" "maps" "net/http" @@ -9,23 +8,9 @@ import ( "os" "github.com/alecthomas/errors" -) - -type HTTPError struct { - status int - err error -} -func (h HTTPError) Error() string { return fmt.Sprintf("%d: %s", h.status, h.err) } -func (h HTTPError) Unwrap() error { return h.err } -func (h HTTPError) StatusCode() int { return h.status } - -func HTTPErrorf(status int, format string, args ...any) error { - return HTTPError{ - status: status, - err: errors.Errorf(format, args...), - } -} + "github.com/block/sfptc/internal/httputil" +) // Fetch retrieves a response from cache or fetches from the request URL and caches it. // The response is streamed without buffering. Returns HTTPError for semantic errors. @@ -49,12 +34,19 @@ func Fetch(client *http.Client, r *http.Request, c Cache) (*http.Response, error }, nil } if !errors.Is(err, os.ErrNotExist) { - return nil, HTTPErrorf(http.StatusInternalServerError, "failed to open cache: %w", err) + return nil, httputil.Errorf(http.StatusInternalServerError, "failed to open cache: %w", err) } + return FetchDirect(client, r, c, key) +} + +// FetchDirect fetches and caches the given URL without checking the cache first. +// The response is streamed without buffering. Returns HTTPError for semantic errors. +// The caller must close the response body. +func FetchDirect(client *http.Client, r *http.Request, c Cache, key Key) (*http.Response, error) { resp, err := client.Do(r) //nolint:bodyclose // Body is returned to caller if err != nil { - return nil, HTTPErrorf(http.StatusBadGateway, "failed to fetch: %w", err) + return nil, httputil.Errorf(http.StatusBadGateway, "failed to fetch: %w", err) } if resp.StatusCode != http.StatusOK { @@ -65,7 +57,7 @@ func Fetch(client *http.Client, r *http.Request, c Cache) (*http.Response, error cw, err := c.Create(r.Context(), key, responseHeaders, 0) if err != nil { _ = resp.Body.Close() - return nil, HTTPErrorf(http.StatusInternalServerError, "failed to create cache entry: %w", err) + return nil, httputil.Errorf(http.StatusInternalServerError, "failed to create cache entry: %w", err) } originalBody := resp.Body diff --git a/internal/cache/remote.go b/internal/cache/remote.go index 039e7c6..46e43d7 100644 --- a/internal/cache/remote.go +++ b/internal/cache/remote.go @@ -24,7 +24,7 @@ var _ Cache = (*Remote)(nil) // NewRemote creates a new remote cache client. func NewRemote(baseURL string) *Remote { return &Remote{ - baseURL: baseURL, + baseURL: baseURL + "/api/v1/object", client: &http.Client{}, } } diff --git a/internal/cache/remote_test.go b/internal/cache/remote_test.go index fe6b46e..cc9a179 100644 --- a/internal/cache/remote_test.go +++ b/internal/cache/remote_test.go @@ -2,6 +2,7 @@ package cache_test import ( "log/slog" + "net/http" "net/http/httptest" "testing" "time" @@ -24,9 +25,10 @@ func TestRemoteClient(t *testing.T) { assert.NoError(t, err) t.Cleanup(func() { memCache.Close() }) - server, err := strategy.NewDefault(ctx, strategy.DefaultConfig{}, memCache) + mux := http.NewServeMux() + _, err = strategy.NewAPIV1(ctx, struct{}{}, memCache, mux) assert.NoError(t, err) - ts := httptest.NewServer(server) + ts := httptest.NewServer(mux) t.Cleanup(ts.Close) client := cache.NewRemote(ts.URL) diff --git a/internal/config/config.go b/internal/config/config.go index 00b1392..da89c38 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -4,8 +4,9 @@ package config import ( "context" "io" + "log/slog" "net/http" - "strings" + "os" "github.com/alecthomas/errors" "github.com/alecthomas/hcl/v2" @@ -15,17 +16,36 @@ import ( "github.com/block/sfptc/internal/strategy" ) +type loggingMux struct { + logger *slog.Logger + mux *http.ServeMux +} + +func (l *loggingMux) Handle(pattern string, handler http.Handler) { + l.logger.Debug("Registered strategy handler", "pattern", pattern) + l.mux.Handle(pattern, handler) +} + +func (l *loggingMux) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) { + l.logger.Debug("Registered strategy handler", "pattern", pattern) + l.mux.HandleFunc(pattern, handler) +} + +var _ strategy.Mux = (*loggingMux)(nil) + // Load HCL configuration and uses that to construct the cache backend, and proxy strategies. -func Load(ctx context.Context, r io.Reader, mux *http.ServeMux) error { +func Load(ctx context.Context, r io.Reader, mux *http.ServeMux, vars map[string]string) error { logger := logging.FromContext(ctx) ast, err := hcl.Parse(r) if err != nil { return errors.WithStack(err) } + expandVars(ast, vars) + strategyCandidates := []*hcl.Block{ - // Always enable the default strategy - {Name: "default", Labels: []string{"/api/v1/"}}, + // Always enable the default API strategy + {Name: "apiv1"}, } // First pass, instantiate caches @@ -56,19 +76,27 @@ func Load(ctx context.Context, r io.Reader, mux *http.ServeMux) error { // Second pass, instantiate strategies and bind them to the mux. for _, block := range strategyCandidates { - if len(block.Labels) != 1 { - return errors.Errorf("%s: block must have exactly one label defining the server mount point", block.Pos) - } - pattern := block.Labels[0] - block.Labels = nil - s, err := strategy.Create(ctx, block.Name, block, cache) + logger := logger.With("strategy", block.Name) + mlog := &loggingMux{logger: logger, mux: mux} + _, err := strategy.Create(ctx, block.Name, block, cache, mlog) if err != nil { return errors.Errorf("%s: %w", block.Pos, err) } - - logger.DebugContext(ctx, "Adding strategy", "strategy", s, "pattern", pattern) - - mux.Handle(pattern, http.StripPrefix(strings.TrimSuffix(pattern, "/"), s)) } return nil } + +func expandVars(ast *hcl.AST, vars map[string]string) { + _ = hcl.Visit(ast, func(node hcl.Node, next func() error) error { + attr, ok := node.(*hcl.Attribute) + if ok { + switch attr := attr.Value.(type) { + case *hcl.String: + attr.Str = os.Expand(attr.Str, func(s string) string { return vars[s] }) + case *hcl.Heredoc: + attr.Doc = os.Expand(attr.Doc, func(s string) string { return vars[s] }) + } + } + return next() + }) +} diff --git a/internal/httputil/error.go b/internal/httputil/error.go new file mode 100644 index 0000000..2187e5b --- /dev/null +++ b/internal/httputil/error.go @@ -0,0 +1,34 @@ +// Package httputil contains utilities for HTTP clients and servers. +package httputil + +import ( + "fmt" + "net/http" + + "github.com/alecthomas/errors" + + "github.com/block/sfptc/internal/logging" +) + +// ErrorResponse creates an error response with the given code and format, and also logs a message. +func ErrorResponse(w http.ResponseWriter, r *http.Request, status int, msg string, args ...any) { + logger := logging.FromContext(r.Context()).With("url", r.URL, "status", status) + logger.ErrorContext(r.Context(), msg, args...) + http.Error(w, msg, status) +} + +type HTTPError struct { + status int + err error +} + +func (h HTTPError) Error() string { return fmt.Sprintf("%d: %s", h.status, h.err) } +func (h HTTPError) Unwrap() error { return h.err } +func (h HTTPError) StatusCode() int { return h.status } + +func Errorf(status int, format string, args ...any) error { + return HTTPError{ + status: status, + err: errors.Errorf(format, args...), + } +} diff --git a/internal/httputil/logging.go b/internal/httputil/logging.go new file mode 100644 index 0000000..8266175 --- /dev/null +++ b/internal/httputil/logging.go @@ -0,0 +1,16 @@ +package httputil + +import ( + "net/http" + + "github.com/block/sfptc/internal/logging" +) + +func LoggingMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + logger := logging.FromContext(r.Context()).With("url", r.RequestURI) + r = r.WithContext(logging.ContextWithLogger(r.Context(), logger)) + logger.Debug("Request received") + next.ServeHTTP(w, r) + }) +} diff --git a/internal/logging/logging.go b/internal/logging/logging.go index 385c00e..7ea5e3a 100644 --- a/internal/logging/logging.go +++ b/internal/logging/logging.go @@ -42,3 +42,8 @@ func FromContext(ctx context.Context) *slog.Logger { } return logger } + +// ContextWithLogger returns a new context with the given logger. +func ContextWithLogger(ctx context.Context, logger *slog.Logger) context.Context { + return context.WithValue(ctx, logKey{}, logger) +} diff --git a/internal/strategy/api.go b/internal/strategy/api.go index 870e341..599e6ba 100644 --- a/internal/strategy/api.go +++ b/internal/strategy/api.go @@ -14,32 +14,36 @@ import ( // ErrNotFound is returned when a strategy is not found. var ErrNotFound = errors.New("strategy not found") -var registry = map[string]func(ctx context.Context, config *hcl.Block, cache cache.Cache) (Strategy, error){} +type Mux interface { + Handle(pattern string, handler http.Handler) + HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) +} + +var registry = map[string]func(ctx context.Context, config *hcl.Block, cache cache.Cache, mux Mux) (Strategy, error){} -type Factory[Config any, S Strategy] func(ctx context.Context, config Config, cache cache.Cache) (S, error) +type Factory[Config any, S Strategy] func(ctx context.Context, config Config, cache cache.Cache, mux Mux) (S, error) // Register a new proxy strategy. func Register[Config any, S Strategy](id string, factory Factory[Config, S]) { - registry[id] = func(ctx context.Context, config *hcl.Block, cache cache.Cache) (Strategy, error) { + registry[id] = func(ctx context.Context, config *hcl.Block, cache cache.Cache, mux Mux) (Strategy, error) { var cfg Config if err := hcl.UnmarshalBlock(config, &cfg, hcl.AllowExtra(false)); err != nil { return nil, errors.WithStack(err) } - return factory(ctx, cfg, cache) + return factory(ctx, cfg, cache, mux) } } // Create a new proxy strategy. // // Will return "ErrNotFound" if the strategy is not found. -func Create(ctx context.Context, name string, config *hcl.Block, cache cache.Cache) (Strategy, error) { +func Create(ctx context.Context, name string, config *hcl.Block, cache cache.Cache, mux Mux) (Strategy, error) { if factory, ok := registry[name]; ok { - return errors.WithStack2(factory(ctx, config, cache)) + return errors.WithStack2(factory(ctx, config, cache, mux)) } return nil, errors.Errorf("%s: %w", name, ErrNotFound) } type Strategy interface { String() string - http.Handler } diff --git a/internal/strategy/default.go b/internal/strategy/apiv1.go similarity index 73% rename from internal/strategy/default.go rename to internal/strategy/apiv1.go index 9958945..9a9cc75 100644 --- a/internal/strategy/default.go +++ b/internal/strategy/apiv1.go @@ -16,41 +16,31 @@ import ( ) func init() { - Register("default", NewDefault) + Register("apiv1", NewAPIV1) } -type DefaultConfig struct{} +var _ Strategy = (*APIV1)(nil) -var _ Strategy = (*Default)(nil) - -// The Default strategy represents v1 of the proxy API. -type Default struct { +// The APIV1 strategy represents v1 of the proxy API. +type APIV1 struct { cache cache.Cache logger *slog.Logger - mux *http.ServeMux } -var _ http.Handler = (*Default)(nil) - -func NewDefault(ctx context.Context, _ DefaultConfig, cache cache.Cache) (*Default, error) { - s := &Default{ +func NewAPIV1(ctx context.Context, _ struct{}, cache cache.Cache, mux Mux) (*APIV1, error) { + s := &APIV1{ logger: logging.FromContext(ctx), cache: cache, - mux: http.NewServeMux(), } - s.mux.Handle("GET /{key}", http.HandlerFunc(s.getObject)) - s.mux.Handle("POST /{key}", http.HandlerFunc(s.putObject)) - s.mux.Handle("DELETE /{key}", http.HandlerFunc(s.deleteObject)) + mux.Handle("GET /api/v1/object/{key}", http.HandlerFunc(s.getObject)) + mux.Handle("POST /api/v1/object/{key}", http.HandlerFunc(s.putObject)) + mux.Handle("DELETE /api/v1/object/{key}", http.HandlerFunc(s.deleteObject)) return s, nil } -func (d *Default) String() string { return "default" } - -func (d *Default) ServeHTTP(w http.ResponseWriter, r *http.Request) { - d.mux.ServeHTTP(w, r) -} +func (d *APIV1) String() string { return "default" } -func (d *Default) getObject(w http.ResponseWriter, r *http.Request) { +func (d *APIV1) getObject(w http.ResponseWriter, r *http.Request) { key, err := cache.ParseKey(r.PathValue("key")) if err != nil { d.httpError(w, http.StatusBadRequest, err, "Invalid key") @@ -78,7 +68,7 @@ func (d *Default) getObject(w http.ResponseWriter, r *http.Request) { } } -func (d *Default) putObject(w http.ResponseWriter, r *http.Request) { +func (d *APIV1) putObject(w http.ResponseWriter, r *http.Request) { key, err := cache.ParseKey(r.PathValue("key")) if err != nil { d.httpError(w, http.StatusBadRequest, err, "Invalid key") @@ -115,7 +105,7 @@ func (d *Default) putObject(w http.ResponseWriter, r *http.Request) { } } -func (d *Default) deleteObject(w http.ResponseWriter, r *http.Request) { +func (d *APIV1) deleteObject(w http.ResponseWriter, r *http.Request) { key, err := cache.ParseKey(r.PathValue("key")) if err != nil { d.httpError(w, http.StatusBadRequest, err, "Invalid key") @@ -133,7 +123,7 @@ func (d *Default) deleteObject(w http.ResponseWriter, r *http.Request) { } } -func (d *Default) httpError(w http.ResponseWriter, code int, err error, message string, args ...any) { +func (d *APIV1) httpError(w http.ResponseWriter, code int, err error, message string, args ...any) { args = append(args, slog.String("error", err.Error())) d.logger.Error(message, args...) http.Error(w, message, code) diff --git a/internal/strategy/github_releases.go b/internal/strategy/github_releases.go new file mode 100644 index 0000000..9df235c --- /dev/null +++ b/internal/strategy/github_releases.go @@ -0,0 +1,204 @@ +package strategy + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "maps" + "net/http" + "os" + "slices" + + "github.com/alecthomas/errors" + + "github.com/block/sfptc/internal/cache" + "github.com/block/sfptc/internal/httputil" + "github.com/block/sfptc/internal/logging" +) + +func init() { + Register("github-releases", NewGitHubReleases) +} + +type GitHubReleasesConfig struct { + Token string `hcl:"token" help:"GitHub token for authentication."` + PrivateOrgs []string `hcl:"private-orgs" help:"List of private GitHub organisations."` +} + +// The GitHubReleases strategy fetches private (and public) release binaries from GitHub. +type GitHubReleases struct { + config GitHubReleasesConfig + cache cache.Cache +} + +// NewGitHubReleases creates a [Strategy] that fetches private (and public) release binaries from GitHub. +func NewGitHubReleases(ctx context.Context, config GitHubReleasesConfig, cache cache.Cache, mux Mux) (*GitHubReleases, error) { + s := &GitHubReleases{ + config: config, + cache: cache, + } + logger := logging.FromContext(ctx) + if config.Token == "" { + logger.WarnContext(ctx, "No token configured for github-releases strategy") + } + // eg. https://github.com/alecthomas/chroma/releases/download/v2.21.1/chroma-2.21.1-darwin-amd64.tar.gz + mux.Handle("GET /github.com/{org}/{repo}/releases/download/{release}/{file}", http.HandlerFunc(s.fetch)) + return s, nil +} + +var _ Strategy = (*GitHubReleases)(nil) + +func (g *GitHubReleases) String() string { return "github-releases" } + +func (g *GitHubReleases) fetch(w http.ResponseWriter, r *http.Request) { + org := r.PathValue("org") + repo := r.PathValue("repo") + release := r.PathValue("release") + file := r.PathValue("file") + ghURL := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/%s", org, repo, release, file) + + logger := logging.FromContext(r.Context()).With("upstream", ghURL) + + key := cache.NewKey(ghURL) + + logger.Debug("Fetching GitHub release") + + // Check if the key exists in the cache + cr, headers, err := g.cache.Open(r.Context(), key) + if err == nil { + logger.Debug("Cache hit") + // Cache hit - stream directly from cache + defer cr.Close() + maps.Copy(w.Header(), headers) + if _, err := io.Copy(w, cr); err != nil { + httputil.ErrorResponse(w, r, http.StatusInternalServerError, "Failed to stream from cache", "error", err.Error()) + return + } + return + } + if !errors.Is(err, os.ErrNotExist) { + httputil.ErrorResponse(w, r, http.StatusInternalServerError, "Failed to open cache", "error", err.Error()) + return + } + + // Cache miss - fetch from GitHub and stream while caching + req, err := g.downloadRelease(r.Context(), org, repo, release, file) + if err != nil { + if herr, ok := errors.AsType[httputil.HTTPError](err); ok { + httputil.ErrorResponse(w, r, herr.StatusCode(), herr.Error(), "upstream", ghURL) + } else { + httputil.ErrorResponse(w, r, http.StatusInternalServerError, "Failed to create download request", "error", err.Error()) + } + return + } + + response, err := cache.FetchDirect(http.DefaultClient, req, g.cache, key) + if err != nil { + if herr, ok := errors.AsType[httputil.HTTPError](err); ok { + httputil.ErrorResponse(w, r, herr.StatusCode(), herr.Error()) + } else { + httputil.ErrorResponse(w, r, http.StatusInternalServerError, err.Error()) + } + return + } + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + httputil.ErrorResponse(w, r, response.StatusCode, response.Status) + return + } + maps.Copy(w.Header(), response.Header) + if _, err := io.Copy(w, response.Body); err != nil { + httputil.ErrorResponse(w, r, http.StatusInternalServerError, "Failed to stream response", "error", err.Error()) + return + } +} + +// newGitHubRequest creates a new HTTP request with GitHub API headers and authentication. +func (g *GitHubReleases) newGitHubRequest(ctx context.Context, url, accept string) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, errors.Wrap(err, "create request") + } + req.Header.Set("Accept", accept) + req.Header.Set("X-Github-Api-Version", "2022-11-28") + if g.config.Token != "" { + req.Header.Set("Authorization", "Bearer "+g.config.Token) + } + return req, nil +} + +// downloadRelease creates an HTTP request to download a GitHub release asset. +// For private orgs, it uses the GitHub API to find and download the asset. +// For public orgs, it constructs a direct download URL. +func (g *GitHubReleases) downloadRelease(ctx context.Context, org, repo, release, file string) (*http.Request, error) { + isPrivate := slices.Contains(g.config.PrivateOrgs, org) + + logger := logging.FromContext(ctx).With( + slog.String("org", org), + slog.String("repo", repo), + slog.String("release", release), + slog.String("file", file)) + + realURL := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/%s", org, repo, release, file) + if !isPrivate { + // Public release - use direct download URL + logger.DebugContext(ctx, "Using public download URL") + req, err := http.NewRequestWithContext(ctx, http.MethodGet, realURL, nil) + if err != nil { + return nil, httputil.Errorf(http.StatusInternalServerError, "create download request") + } + return req, nil + } + + // Use GitHub API to get release info and find the asset + logger.DebugContext(ctx, "Using GitHub API for private release") + apiURL := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/tags/%s", org, repo, release) + req, err := g.newGitHubRequest(ctx, apiURL, "application/vnd.github+json") + if err != nil { + return nil, httputil.Errorf(http.StatusInternalServerError, "create API request") + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, httputil.Errorf(http.StatusBadGateway, "fetch release info failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, httputil.Errorf(resp.StatusCode, "GitHub API returned %d", resp.StatusCode) + } + + var releaseInfo struct { + Assets []struct { + Name string `json:"name"` + URL string `json:"url"` + } `json:"assets"` + } + if err := json.NewDecoder(resp.Body).Decode(&releaseInfo); err != nil { + return nil, httputil.Errorf(http.StatusBadGateway, "decode release info failed: %w", err) + } + + // Find the matching asset + var assetURL string + for _, asset := range releaseInfo.Assets { + if asset.Name == file { + assetURL = asset.URL + break + } + } + if assetURL == "" { + logger.ErrorContext(ctx, "Asset not found in release", slog.Int("assets_count", len(releaseInfo.Assets))) + return nil, httputil.Errorf(http.StatusNotFound, "asset %s not found in release %s", file, release) + } + + logger.DebugContext(ctx, "Found asset in release", slog.String("asset_url", assetURL)) + + // Create request for the asset download + req, err = g.newGitHubRequest(ctx, assetURL, "application/octet-stream") + if err != nil { + return nil, httputil.Errorf(http.StatusInternalServerError, "create asset request failed: %w", err) + } + return req, nil +} diff --git a/internal/strategy/host.go b/internal/strategy/host.go index fc17f11..efff241 100644 --- a/internal/strategy/host.go +++ b/internal/strategy/host.go @@ -12,6 +12,7 @@ import ( "github.com/alecthomas/errors" "github.com/block/sfptc/internal/cache" + "github.com/block/sfptc/internal/httputil" "github.com/block/sfptc/internal/logging" ) @@ -23,13 +24,13 @@ func init() { // // In HCL it looks something like this: // -// host "/github/" { +// host { // target = "https://github.com/" // } // -// In this example, the strategy will be mounted under "/github". +// In this example, the strategy will be mounted under "/github.com". type HostConfig struct { - Target string `hcl:"target" help:"The target URL to proxy requests to."` + Target string `hcl:"target,label" help:"The target URL to proxy requests to."` } // The Host [Strategy] forwards all GET requests to the specified host, caching the response payloads. @@ -38,48 +39,66 @@ type Host struct { cache cache.Cache client *http.Client logger *slog.Logger + prefix string } var _ Strategy = (*Host)(nil) -func NewHost(ctx context.Context, config HostConfig, cache cache.Cache) (*Host, error) { +func NewHost(ctx context.Context, config HostConfig, cache cache.Cache, mux Mux) (*Host, error) { u, err := url.Parse(config.Target) if err != nil { return nil, fmt.Errorf("invalid target URL: %w", err) } - return &Host{ + prefix := "/" + u.Host + u.EscapedPath() + h := &Host{ target: u, cache: cache, client: &http.Client{}, logger: logging.FromContext(ctx), - }, nil + prefix: prefix, + } + mux.HandleFunc("GET "+prefix+"/", h.serveHTTP) + return h, nil } func (d *Host) String() string { return "host:" + d.target.Host + d.target.Path } -func (d *Host) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (d *Host) serveHTTP(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } - targetURL := *d.target - targetURL.Path = r.URL.Path + // Strip the prefix from the request path + path := r.URL.Path + if len(path) >= len(d.prefix) { + path = path[len(d.prefix):] + } + if path == "" { + path = "/" + } + + targetURL, err := url.Parse(d.target.String()) + if err != nil { + httputil.ErrorResponse(w, r, http.StatusInternalServerError, "Failed to parse target URL", "error", err.Error(), "upstream", d.target.String()) + return + } + targetURL.Path = path targetURL.RawQuery = r.URL.RawQuery fullURL := targetURL.String() req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, fullURL, nil) if err != nil { - d.httpError(w, http.StatusInternalServerError, err, "Failed to create request", slog.String("url", fullURL)) + httputil.ErrorResponse(w, r, http.StatusInternalServerError, "Failed to create request", "error", err.Error(), "upstream", fullURL) return } resp, err := cache.Fetch(d.client, req, d.cache) if err != nil { - if httpErr, ok := errors.AsType[cache.HTTPError](err); ok { - d.httpError(w, httpErr.StatusCode(), httpErr, httpErr.Error(), slog.String("url", fullURL)) + if httpErr, ok := errors.AsType[httputil.HTTPError](err); ok { + httputil.ErrorResponse(w, r, httpErr.StatusCode(), httpErr.Error(), "error", httpErr.Error(), "upstream", fullURL) } else { - d.httpError(w, http.StatusInternalServerError, err, "Failed to fetch", slog.String("url", fullURL)) + httputil.ErrorResponse(w, r, http.StatusInternalServerError, "Failed to fetch", "error", err.Error(), "upstream", fullURL) } return } @@ -88,19 +107,13 @@ func (d *Host) ServeHTTP(w http.ResponseWriter, r *http.Request) { if resp.StatusCode != http.StatusOK { w.WriteHeader(resp.StatusCode) if _, err := io.Copy(w, resp.Body); err != nil { - d.logger.Error("Failed to copy error response", slog.String("error", err.Error()), slog.String("url", fullURL)) + d.logger.Error("Failed to copy error response", "error", err.Error(), "upstream", fullURL) } return } maps.Copy(w.Header(), resp.Header) if _, err := io.Copy(w, resp.Body); err != nil { - d.logger.Error("Failed to copy response", slog.String("error", err.Error()), slog.String("url", fullURL)) + d.logger.Error("Failed to copy response", "error", err.Error(), "upstream", fullURL) } } - -func (d *Host) httpError(w http.ResponseWriter, code int, err error, message string, args ...any) { - args = append(args, slog.String("error", err.Error())) - d.logger.Error(message, args...) - http.Error(w, message, code) -} diff --git a/internal/strategy/host_test.go b/internal/strategy/host_test.go index 45a95a3..1ec7eea 100644 --- a/internal/strategy/host_test.go +++ b/internal/strategy/host_test.go @@ -5,6 +5,7 @@ import ( "log/slog" "net/http" "net/http/httptest" + "net/url" "testing" "time" @@ -29,20 +30,25 @@ func TestHostCaching(t *testing.T) { assert.NoError(t, err) defer memCache.Close() - host, err := strategy.NewHost(ctx, strategy.HostConfig{Target: backend.URL}, memCache) + mux := http.NewServeMux() + _, err = strategy.NewHost(ctx, strategy.HostConfig{Target: backend.URL}, memCache, mux) assert.NoError(t, err) - req1 := httptest.NewRequest(http.MethodGet, "/test", nil) + // Request path must include the host prefix from the target URL + u, _ := url.Parse(backend.URL) + reqPath := "/" + u.Host + "/test" + + req1 := httptest.NewRequest(http.MethodGet, reqPath, nil) w1 := httptest.NewRecorder() - host.ServeHTTP(w1, req1) + mux.ServeHTTP(w1, req1) assert.Equal(t, http.StatusOK, w1.Code) assert.Equal(t, "response", w1.Body.String()) assert.Equal(t, 1, callCount) - req2 := httptest.NewRequest(http.MethodGet, "/test", nil) + req2 := httptest.NewRequest(http.MethodGet, reqPath, nil) w2 := httptest.NewRecorder() - host.ServeHTTP(w2, req2) + mux.ServeHTTP(w2, req2) assert.Equal(t, http.StatusOK, w2.Code) assert.Equal(t, "response", w2.Body.String()) @@ -61,12 +67,17 @@ func TestHostNonOKStatus(t *testing.T) { assert.NoError(t, err) defer memCache.Close() - host, err := strategy.NewHost(ctx, strategy.HostConfig{Target: backend.URL}, memCache) + mux := http.NewServeMux() + _, err = strategy.NewHost(ctx, strategy.HostConfig{Target: backend.URL}, memCache, mux) assert.NoError(t, err) - req := httptest.NewRequest(http.MethodGet, "/missing", nil) + // Request path must include the host prefix from the target URL + u, _ := url.Parse(backend.URL) + reqPath := "/" + u.Host + "/missing" + + req := httptest.NewRequest(http.MethodGet, reqPath, nil) w := httptest.NewRecorder() - host.ServeHTTP(w, req) + mux.ServeHTTP(w, req) assert.Equal(t, http.StatusNotFound, w.Code) assert.Equal(t, "not found", w.Body.String()) @@ -82,7 +93,8 @@ func TestHostInvalidTargetURL(t *testing.T) { assert.NoError(t, err) defer memCache.Close() - _, err = strategy.NewHost(ctx, strategy.HostConfig{Target: "://invalid"}, memCache) + mux := http.NewServeMux() + _, err = strategy.NewHost(ctx, strategy.HostConfig{Target: "://invalid"}, memCache, mux) assert.Error(t, err) } @@ -92,7 +104,8 @@ func TestHostString(t *testing.T) { assert.NoError(t, err) defer memCache.Close() - host, err := strategy.NewHost(ctx, strategy.HostConfig{Target: "https://example.com/prefix"}, memCache) + mux := http.NewServeMux() + host, err := strategy.NewHost(ctx, strategy.HostConfig{Target: "https://example.com/prefix"}, memCache, mux) assert.NoError(t, err) assert.Equal(t, "host:example.com/prefix", host.String()) diff --git a/sfptc.hcl b/sfptc.hcl index e585e9c..8047789 100644 --- a/sfptc.hcl +++ b/sfptc.hcl @@ -5,8 +5,11 @@ # mitm = ["artifactory.square.com"] # } -host "/github/" { - target = "https://github.com/" +host "https://w3.org" {} + +github-releases { + token = "${GITHUB_TOKEN}" + private-orgs = ["alecthomas"] } disk {