Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,8 @@ linters:
linters: [revive, staticcheck]
- text: 'shadow: declaration of "err" shadows declaration'
linters: ["govet"]
- text: 'shadow: declaration of "logger" shadows declaration'
linters: ["govet"]
- path: '_test\.go'
linters:
- bodyclose
Expand Down
28 changes: 24 additions & 4 deletions cmd/sfptcd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,42 +3,62 @@ package main
import (
"context"
"log/slog"
"net"
"net/http"
"os"
"strings"
"time"

"github.com/alecthomas/kong"

"github.com/block/sfptc/internal/config"
"github.com/block/sfptc/internal/httputil"
"github.com/block/sfptc/internal/logging"
)

var cli struct {
Config *os.File `hcl:"-" help:"Configuration file path." placeholder:"PATH" required:""`
Config *os.File `hcl:"-" help:"Configuration file path." placeholder:"PATH" required:"" default:"sfptc.hcl"`
Bind string `hcl:"bind" default:"127.0.0.1:8080" help:"Bind address for the server."`
LoggingConfig logging.Config `embed:"" prefix:"log-"`
}

func main() {
kctx := kong.Parse(&cli)
kctx := kong.Parse(&cli, kong.DefaultEnvars("SFPTC"))

ctx := context.Background()
logger, ctx := logging.Configure(ctx, cli.LoggingConfig)

mux := http.NewServeMux()

err := config.Load(ctx, cli.Config, mux)
err := config.Load(ctx, cli.Config, mux, parseEnvars())
kctx.FatalIfErrorf(err)

logger.InfoContext(ctx, "Starting sfptcd", slog.String("bind", cli.Bind))

server := &http.Server{
Addr: cli.Bind,
Handler: mux,
Handler: httputil.LoggingMiddleware(mux),
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
ReadHeaderTimeout: 10 * time.Second,
BaseContext: func(net.Listener) context.Context {
return ctx
},
ConnContext: func(ctx context.Context, c net.Conn) context.Context {
return logging.ContextWithLogger(ctx, logger.With("client", c.RemoteAddr().String()))
},
}

err = server.ListenAndServe()
kctx.FatalIfErrorf(err)
}

func parseEnvars() map[string]string {
envars := map[string]string{}
for _, env := range os.Environ() {
if key, value, ok := strings.Cut(env, "="); ok {
envars[key] = value
}
}
return envars
}
32 changes: 12 additions & 20 deletions internal/cache/http.go
Original file line number Diff line number Diff line change
@@ -1,31 +1,16 @@
package cache

import (
"fmt"
"io"
"maps"
"net/http"
"net/textproto"
"os"

"github.com/alecthomas/errors"
)

type HTTPError struct {
status int
err error
}

func (h HTTPError) Error() string { return fmt.Sprintf("%d: %s", h.status, h.err) }
func (h HTTPError) Unwrap() error { return h.err }
func (h HTTPError) StatusCode() int { return h.status }

func HTTPErrorf(status int, format string, args ...any) error {
return HTTPError{
status: status,
err: errors.Errorf(format, args...),
}
}
"github.com/block/sfptc/internal/httputil"
)

// Fetch retrieves a response from cache or fetches from the request URL and caches it.
// The response is streamed without buffering. Returns HTTPError for semantic errors.
Expand All @@ -49,12 +34,19 @@ func Fetch(client *http.Client, r *http.Request, c Cache) (*http.Response, error
}, nil
}
if !errors.Is(err, os.ErrNotExist) {
return nil, HTTPErrorf(http.StatusInternalServerError, "failed to open cache: %w", err)
return nil, httputil.Errorf(http.StatusInternalServerError, "failed to open cache: %w", err)
}

return FetchDirect(client, r, c, key)
}

// FetchDirect fetches and caches the given URL without checking the cache first.
// The response is streamed without buffering. Returns HTTPError for semantic errors.
// The caller must close the response body.
func FetchDirect(client *http.Client, r *http.Request, c Cache, key Key) (*http.Response, error) {
resp, err := client.Do(r) //nolint:bodyclose // Body is returned to caller
if err != nil {
return nil, HTTPErrorf(http.StatusBadGateway, "failed to fetch: %w", err)
return nil, httputil.Errorf(http.StatusBadGateway, "failed to fetch: %w", err)
}

if resp.StatusCode != http.StatusOK {
Expand All @@ -65,7 +57,7 @@ func Fetch(client *http.Client, r *http.Request, c Cache) (*http.Response, error
cw, err := c.Create(r.Context(), key, responseHeaders, 0)
if err != nil {
_ = resp.Body.Close()
return nil, HTTPErrorf(http.StatusInternalServerError, "failed to create cache entry: %w", err)
return nil, httputil.Errorf(http.StatusInternalServerError, "failed to create cache entry: %w", err)
}

originalBody := resp.Body
Expand Down
2 changes: 1 addition & 1 deletion internal/cache/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ var _ Cache = (*Remote)(nil)
// NewRemote creates a new remote cache client.
func NewRemote(baseURL string) *Remote {
return &Remote{
baseURL: baseURL,
baseURL: baseURL + "/api/v1/object",
client: &http.Client{},
}
}
Expand Down
6 changes: 4 additions & 2 deletions internal/cache/remote_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cache_test

import (
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"time"
Expand All @@ -24,9 +25,10 @@ func TestRemoteClient(t *testing.T) {
assert.NoError(t, err)
t.Cleanup(func() { memCache.Close() })

server, err := strategy.NewDefault(ctx, strategy.DefaultConfig{}, memCache)
mux := http.NewServeMux()
_, err = strategy.NewAPIV1(ctx, struct{}{}, memCache, mux)
assert.NoError(t, err)
ts := httptest.NewServer(server)
ts := httptest.NewServer(mux)
t.Cleanup(ts.Close)

client := cache.NewRemote(ts.URL)
Expand Down
56 changes: 42 additions & 14 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ package config
import (
"context"
"io"
"log/slog"
"net/http"
"strings"
"os"

"github.com/alecthomas/errors"
"github.com/alecthomas/hcl/v2"
Expand All @@ -15,17 +16,36 @@ import (
"github.com/block/sfptc/internal/strategy"
)

type loggingMux struct {
logger *slog.Logger
mux *http.ServeMux
}

func (l *loggingMux) Handle(pattern string, handler http.Handler) {
l.logger.Debug("Registered strategy handler", "pattern", pattern)
l.mux.Handle(pattern, handler)
}

func (l *loggingMux) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) {
l.logger.Debug("Registered strategy handler", "pattern", pattern)
l.mux.HandleFunc(pattern, handler)
}

var _ strategy.Mux = (*loggingMux)(nil)

// Load HCL configuration and uses that to construct the cache backend, and proxy strategies.
func Load(ctx context.Context, r io.Reader, mux *http.ServeMux) error {
func Load(ctx context.Context, r io.Reader, mux *http.ServeMux, vars map[string]string) error {
logger := logging.FromContext(ctx)
ast, err := hcl.Parse(r)
if err != nil {
return errors.WithStack(err)
}

expandVars(ast, vars)

strategyCandidates := []*hcl.Block{
// Always enable the default strategy
{Name: "default", Labels: []string{"/api/v1/"}},
// Always enable the default API strategy
{Name: "apiv1"},
}

// First pass, instantiate caches
Expand Down Expand Up @@ -56,19 +76,27 @@ func Load(ctx context.Context, r io.Reader, mux *http.ServeMux) error {

// Second pass, instantiate strategies and bind them to the mux.
for _, block := range strategyCandidates {
if len(block.Labels) != 1 {
return errors.Errorf("%s: block must have exactly one label defining the server mount point", block.Pos)
}
pattern := block.Labels[0]
block.Labels = nil
s, err := strategy.Create(ctx, block.Name, block, cache)
logger := logger.With("strategy", block.Name)
mlog := &loggingMux{logger: logger, mux: mux}
_, err := strategy.Create(ctx, block.Name, block, cache, mlog)
if err != nil {
return errors.Errorf("%s: %w", block.Pos, err)
}

logger.DebugContext(ctx, "Adding strategy", "strategy", s, "pattern", pattern)

mux.Handle(pattern, http.StripPrefix(strings.TrimSuffix(pattern, "/"), s))
}
return nil
}

func expandVars(ast *hcl.AST, vars map[string]string) {
_ = hcl.Visit(ast, func(node hcl.Node, next func() error) error {
attr, ok := node.(*hcl.Attribute)
if ok {
switch attr := attr.Value.(type) {
case *hcl.String:
attr.Str = os.Expand(attr.Str, func(s string) string { return vars[s] })
case *hcl.Heredoc:
attr.Doc = os.Expand(attr.Doc, func(s string) string { return vars[s] })
}
}
return next()
})
}
34 changes: 34 additions & 0 deletions internal/httputil/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Package httputil contains utilities for HTTP clients and servers.
package httputil

import (
"fmt"
"net/http"

"github.com/alecthomas/errors"

"github.com/block/sfptc/internal/logging"
)

// ErrorResponse creates an error response with the given code and format, and also logs a message.
func ErrorResponse(w http.ResponseWriter, r *http.Request, status int, msg string, args ...any) {
logger := logging.FromContext(r.Context()).With("url", r.URL, "status", status)
logger.ErrorContext(r.Context(), msg, args...)
http.Error(w, msg, status)
}

type HTTPError struct {
status int
err error
}

func (h HTTPError) Error() string { return fmt.Sprintf("%d: %s", h.status, h.err) }
func (h HTTPError) Unwrap() error { return h.err }
func (h HTTPError) StatusCode() int { return h.status }

func Errorf(status int, format string, args ...any) error {
return HTTPError{
status: status,
err: errors.Errorf(format, args...),
}
}
16 changes: 16 additions & 0 deletions internal/httputil/logging.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package httputil

import (
"net/http"

"github.com/block/sfptc/internal/logging"
)

func LoggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logger := logging.FromContext(r.Context()).With("url", r.RequestURI)
r = r.WithContext(logging.ContextWithLogger(r.Context(), logger))
logger.Debug("Request received")
next.ServeHTTP(w, r)
})
}
5 changes: 5 additions & 0 deletions internal/logging/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,8 @@ func FromContext(ctx context.Context) *slog.Logger {
}
return logger
}

// ContextWithLogger returns a new context with the given logger.
func ContextWithLogger(ctx context.Context, logger *slog.Logger) context.Context {
return context.WithValue(ctx, logKey{}, logger)
}
18 changes: 11 additions & 7 deletions internal/strategy/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,36 @@ import (
// ErrNotFound is returned when a strategy is not found.
var ErrNotFound = errors.New("strategy not found")

var registry = map[string]func(ctx context.Context, config *hcl.Block, cache cache.Cache) (Strategy, error){}
type Mux interface {
Handle(pattern string, handler http.Handler)
HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request))
}

var registry = map[string]func(ctx context.Context, config *hcl.Block, cache cache.Cache, mux Mux) (Strategy, error){}

type Factory[Config any, S Strategy] func(ctx context.Context, config Config, cache cache.Cache) (S, error)
type Factory[Config any, S Strategy] func(ctx context.Context, config Config, cache cache.Cache, mux Mux) (S, error)

// Register a new proxy strategy.
func Register[Config any, S Strategy](id string, factory Factory[Config, S]) {
registry[id] = func(ctx context.Context, config *hcl.Block, cache cache.Cache) (Strategy, error) {
registry[id] = func(ctx context.Context, config *hcl.Block, cache cache.Cache, mux Mux) (Strategy, error) {
var cfg Config
if err := hcl.UnmarshalBlock(config, &cfg, hcl.AllowExtra(false)); err != nil {
return nil, errors.WithStack(err)
}
return factory(ctx, cfg, cache)
return factory(ctx, cfg, cache, mux)
}
}

// Create a new proxy strategy.
//
// Will return "ErrNotFound" if the strategy is not found.
func Create(ctx context.Context, name string, config *hcl.Block, cache cache.Cache) (Strategy, error) {
func Create(ctx context.Context, name string, config *hcl.Block, cache cache.Cache, mux Mux) (Strategy, error) {
if factory, ok := registry[name]; ok {
return errors.WithStack2(factory(ctx, config, cache))
return errors.WithStack2(factory(ctx, config, cache, mux))
}
return nil, errors.Errorf("%s: %w", name, ErrNotFound)
}

type Strategy interface {
String() string
http.Handler
}
Loading