diff --git a/middleware/middleware.go b/middleware/middleware.go new file mode 100644 index 0000000..98669b8 --- /dev/null +++ b/middleware/middleware.go @@ -0,0 +1,141 @@ +// Package middleware provides HTTP middleware for transparent L402 payment handling. +// +// The middleware intercepts outgoing HTTP requests and automatically handles L402 +// payment challenges (HTTP 402 responses). When a proxied request returns 402, +// the middleware pays the Lightning invoice using the configured wallet, then +// retries the request with the L402 token. +package middleware + +import ( + "io" + "net/http" + "net/http/httputil" + "net/url" + + "github.com/sulusolutions/gol402/client" + "github.com/sulusolutions/gol402/tokenstore" + "github.com/sulusolutions/gol402/wallet" +) + +// Config holds configuration for the L402 middleware. +type Config struct { + // Wallet handles Lightning invoice payments. + Wallet wallet.Wallet + // Store persists L402 tokens for reuse across requests. + Store tokenstore.Store +} + +// L402 returns HTTP middleware that transparently handles L402 payment challenges. +// It wraps the next handler, intercepting responses. If the upstream returns +// HTTP 402 with a WWW-Authenticate L402 challenge, the middleware pays the invoice +// via the configured wallet and retries the request with the L402 token. +// +// The middleware follows the standard Go pattern: it takes an http.Handler and +// returns a new http.Handler. +func L402(cfg Config) func(http.Handler) http.Handler { + store := cfg.Store + if store == nil { + store = tokenstore.NewNoopStore() + } + l402Client := client.New(cfg.Wallet, store) + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Extract the target URL from the request. + // In proxy mode, the full URL is in r.URL or X-Target-URL header. + targetURL := r.Header.Get("X-Target-URL") + if targetURL == "" { + // Not a proxy request — pass through to next handler. + next.ServeHTTP(w, r) + return + } + + parsed, err := url.Parse(targetURL) + if err != nil { + http.Error(w, "invalid X-Target-URL: "+err.Error(), http.StatusBadRequest) + return + } + + // Build the upstream request. + upstreamReq, err := http.NewRequestWithContext(r.Context(), r.Method, parsed.String(), r.Body) + if err != nil { + http.Error(w, "failed to build upstream request: "+err.Error(), http.StatusInternalServerError) + return + } + + // Copy relevant headers from the original request. + copyHeaders(upstreamReq.Header, r.Header) + upstreamReq.Header.Del("X-Target-URL") // Don't forward the routing header. + + // Use the L402 client to make the request. It handles 402 challenges automatically. + resp, err := l402Client.Do(upstreamReq) + if err != nil { + http.Error(w, "upstream request failed: "+err.Error(), http.StatusBadGateway) + return + } + defer resp.Body.Close() + + // Copy the upstream response back to the client. + copyHeaders(w.Header(), resp.Header) + w.WriteHeader(resp.StatusCode) + io.Copy(w, resp.Body) + }) + } +} + +// ReverseProxy returns a reverse proxy handler that transparently handles L402 +// payment challenges when proxying requests to the given target URL. +func ReverseProxy(target *url.URL, cfg Config) http.Handler { + store := cfg.Store + if store == nil { + store = tokenstore.NewNoopStore() + } + l402Client := client.New(cfg.Wallet, store) + + proxy := httputil.NewSingleHostReverseProxy(target) + + // Wrap the proxy transport to use the L402 client. + proxy.Transport = &l402Transport{ + client: l402Client, + targetURL: target, + } + + return proxy +} + +// l402Transport implements http.RoundTripper, using the L402 client +// to handle payment challenges during proxied requests. +type l402Transport struct { + client *client.Client + targetURL *url.URL +} + +func (t *l402Transport) RoundTrip(req *http.Request) (*http.Response, error) { + // The reverse proxy passes requests with RequestURI set, but + // http.Client.Do() rejects such requests. Clear it before delegating. + req.RequestURI = "" + return t.client.Do(req) +} + +// copyHeaders copies HTTP headers from src to dst, skipping hop-by-hop headers. +func copyHeaders(dst, src http.Header) { + hopByHop := map[string]bool{ + "Connection": true, + "Keep-Alive": true, + "Proxy-Authenticate": true, + "Proxy-Authorization": true, + "Te": true, + "Trailers": true, + "Transfer-Encoding": true, + "Upgrade": true, + } + + for key, values := range src { + if hopByHop[key] { + continue + } + for _, v := range values { + dst.Add(key, v) + } + } +} diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go new file mode 100644 index 0000000..680dcad --- /dev/null +++ b/middleware/middleware_test.go @@ -0,0 +1,307 @@ +package middleware + +import ( + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/sulusolutions/gol402/tokenstore" + "github.com/sulusolutions/gol402/wallet" +) + +func TestL402Middleware_PassThrough(t *testing.T) { + // Requests without X-Target-URL should pass through to the next handler. + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("inner handler")) + }) + + handler := L402(Config{ + Wallet: wallet.NewMockWallet(nil), + Store: tokenstore.NewNoopStore(), + })(inner) + + req := httptest.NewRequest("GET", "/test", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rec.Code) + } + if rec.Body.String() != "inner handler" { + t.Errorf("expected inner handler response, got %q", rec.Body.String()) + } +} + +func TestL402Middleware_ProxyNon402(t *testing.T) { + // Upstream returns 200 directly — no payment needed. + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Custom", "value") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status":"ok"}`)) + })) + defer upstream.Close() + + handler := L402(Config{ + Wallet: wallet.NewMockWallet(nil), + Store: tokenstore.NewNoopStore(), + })(http.NotFoundHandler()) + + req := httptest.NewRequest("GET", "/proxy", nil) + req.Header.Set("X-Target-URL", upstream.URL+"/api") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rec.Code) + } + if rec.Body.String() != `{"status":"ok"}` { + t.Errorf("unexpected body: %s", rec.Body.String()) + } + if rec.Header().Get("X-Custom") != "value" { + t.Error("upstream headers not copied") + } +} + +func TestL402Middleware_Proxy402Success(t *testing.T) { + // Upstream returns 402, middleware pays and retries, gets 200. + callCount := 0 + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + if r.Header.Get("Authorization") != "" && strings.HasPrefix(r.Header.Get("Authorization"), "L402") { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"result":"paid content"}`)) + return + } + w.Header().Set("WWW-Authenticate", `L402 macaroon="testMac123", invoice="lnbc100test"`) + w.WriteHeader(http.StatusPaymentRequired) + })) + defer upstream.Close() + + handler := L402(Config{ + Wallet: wallet.NewMockWallet(nil), + Store: tokenstore.NewNoopStore(), + })(http.NotFoundHandler()) + + req := httptest.NewRequest("GET", "/proxy", nil) + req.Header.Set("X-Target-URL", upstream.URL+"/paid") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("expected 200 after payment, got %d", rec.Code) + } + if rec.Body.String() != `{"result":"paid content"}` { + t.Errorf("unexpected body: %s", rec.Body.String()) + } + if callCount != 2 { + t.Errorf("expected 2 upstream calls (initial + retry), got %d", callCount) + } +} + +func TestL402Middleware_Proxy402PaymentFailure(t *testing.T) { + // Upstream returns 402, wallet payment fails. + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("WWW-Authenticate", `L402 macaroon="testMac", invoice="lnbc100test"`) + w.WriteHeader(http.StatusPaymentRequired) + })) + defer upstream.Close() + + handler := L402(Config{ + Wallet: wallet.NewMockWallet(http.ErrAbortHandler), + Store: tokenstore.NewNoopStore(), + })(http.NotFoundHandler()) + + req := httptest.NewRequest("GET", "/proxy", nil) + req.Header.Set("X-Target-URL", upstream.URL+"/paid") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadGateway { + t.Errorf("expected 502 on payment failure, got %d", rec.Code) + } +} + +func TestL402Middleware_InvalidTargetURL(t *testing.T) { + handler := L402(Config{ + Wallet: wallet.NewMockWallet(nil), + Store: tokenstore.NewNoopStore(), + })(http.NotFoundHandler()) + + req := httptest.NewRequest("GET", "/proxy", nil) + req.Header.Set("X-Target-URL", "://invalid") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("expected 400 for invalid URL, got %d", rec.Code) + } +} + +func TestL402Middleware_HeaderCopy(t *testing.T) { + // Verify that client headers (except hop-by-hop and X-Target-URL) are forwarded. + var receivedHeaders http.Header + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.WriteHeader(http.StatusOK) + })) + defer upstream.Close() + + handler := L402(Config{ + Wallet: wallet.NewMockWallet(nil), + Store: tokenstore.NewNoopStore(), + })(http.NotFoundHandler()) + + req := httptest.NewRequest("POST", "/proxy", strings.NewReader("test body")) + req.Header.Set("X-Target-URL", upstream.URL) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Custom-Header", "custom-value") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rec.Code) + } + if receivedHeaders.Get("Content-Type") != "application/json" { + t.Error("Content-Type not forwarded") + } + if receivedHeaders.Get("X-Custom-Header") != "custom-value" { + t.Error("custom header not forwarded") + } + if receivedHeaders.Get("X-Target-URL") != "" { + t.Error("X-Target-URL should not be forwarded to upstream") + } +} + +func TestL402Middleware_TokenReuse(t *testing.T) { + // Verify that the store caches and reuses L402 tokens. + store := tokenstore.NewInMemoryStore() + + callCount := 0 + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + auth := r.Header.Get("Authorization") + if strings.HasPrefix(auth, "L402") { + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) + return + } + w.Header().Set("WWW-Authenticate", `L402 macaroon="mac123", invoice="inv123"`) + w.WriteHeader(http.StatusPaymentRequired) + })) + defer upstream.Close() + + handler := L402(Config{ + Wallet: wallet.NewMockWallet(nil), + Store: store, + })(http.NotFoundHandler()) + + // First request: should trigger 402 + payment + retry = 2 upstream calls. + req1 := httptest.NewRequest("GET", "/proxy", nil) + req1.Header.Set("X-Target-URL", upstream.URL+"/api") + rec1 := httptest.NewRecorder() + handler.ServeHTTP(rec1, req1) + + if rec1.Code != http.StatusOK { + t.Fatalf("first request: expected 200, got %d", rec1.Code) + } + if callCount != 2 { + t.Fatalf("first request: expected 2 calls, got %d", callCount) + } + + // Verify token was stored. + u, _ := url.Parse(upstream.URL + "/api") + token, ok := store.Get(u) + if !ok || token == "" { + t.Fatal("expected token to be stored after first payment") + } +} + +func TestReverseProxy_Non402(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("proxied")) + })) + defer upstream.Close() + + u, _ := url.Parse(upstream.URL) + proxy := ReverseProxy(u, Config{ + Wallet: wallet.NewMockWallet(nil), + Store: tokenstore.NewNoopStore(), + }) + proxySrv := httptest.NewServer(proxy) + defer proxySrv.Close() + + resp, err := http.Get(proxySrv.URL + "/test") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + if string(body) != "proxied" { + t.Errorf("expected 'proxied', got %q", string(body)) + } +} + +func TestReverseProxy_402Success(t *testing.T) { + callCount := 0 + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + if strings.HasPrefix(r.Header.Get("Authorization"), "L402") { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"paid":true}`)) + return + } + w.Header().Set("WWW-Authenticate", `L402 macaroon="mac", invoice="inv"`) + w.WriteHeader(http.StatusPaymentRequired) + })) + defer upstream.Close() + + u, _ := url.Parse(upstream.URL) + proxy := ReverseProxy(u, Config{ + Wallet: wallet.NewMockWallet(nil), + Store: tokenstore.NewNoopStore(), + }) + proxySrv := httptest.NewServer(proxy) + defer proxySrv.Close() + + resp, err := http.Get(proxySrv.URL + "/paid-api") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + if string(body) != `{"paid":true}` { + t.Errorf("unexpected body: %s", string(body)) + } + if callCount != 2 { + t.Errorf("expected 2 calls, got %d", callCount) + } +} + +func TestCopyHeaders_SkipsHopByHop(t *testing.T) { + src := make(http.Header) + src.Set("Content-Type", "application/json") + src.Set("Connection", "keep-alive") + src.Set("X-Custom", "value") + src.Set("Transfer-Encoding", "chunked") + + dst := make(http.Header) + copyHeaders(dst, src) + + if dst.Get("Content-Type") != "application/json" { + t.Error("Content-Type should be copied") + } + if dst.Get("X-Custom") != "value" { + t.Error("X-Custom should be copied") + } + if dst.Get("Connection") != "" { + t.Error("Connection should be skipped") + } + if dst.Get("Transfer-Encoding") != "" { + t.Error("Transfer-Encoding should be skipped") + } +}