diff --git a/helix/client.go b/helix/client.go index b6b9be2..e9b64a3 100644 --- a/helix/client.go +++ b/helix/client.go @@ -26,6 +26,24 @@ type TokenProvider interface { GetToken() *Token } +// tokenContextKey is the context key for per-request token overrides. +type tokenContextKey struct{} + +// WithToken returns a context that overrides the client-level token for a +// single request. This is useful when making concurrent requests that each +// require a different user token. +func WithToken(ctx context.Context, token *Token) context.Context { + return context.WithValue(ctx, tokenContextKey{}, token) +} + +// tokenFromContext retrieves a per-request token override from context. +func tokenFromContext(ctx context.Context) *Token { + if token, ok := ctx.Value(tokenContextKey{}).(*Token); ok { + return token + } + return nil +} + // Client is a Twitch Helix API client. type Client struct { clientID string @@ -382,9 +400,11 @@ func (c *Client) doOnceWithResponse(ctx context.Context, req *Request, result in } } - // Set authorization + // Set authorization (per-request context override takes precedence) var token *Token - if c.authClient != nil { + if ctxToken := tokenFromContext(ctx); ctxToken != nil { + token = ctxToken + } else if c.authClient != nil { token = c.authClient.GetToken() } else if c.tokenProvider != nil { token = c.tokenProvider.GetToken() diff --git a/helix/client_test.go b/helix/client_test.go index 97fd2bc..7c62928 100644 --- a/helix/client_test.go +++ b/helix/client_test.go @@ -1223,3 +1223,121 @@ func TestClient_ResponseBodyReadError(t *testing.T) { } } } + +func TestWithToken_OverridesClientToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got != "Bearer per-request-token" { + t.Errorf("expected per-request token, got %s", got) + } + _ = json.NewEncoder(w).Encode(Response[User]{Data: []User{}}) + })) + defer server.Close() + + authClient := NewAuthClient(AuthConfig{ClientID: "test"}) + authClient.SetToken(&Token{AccessToken: "client-level-token"}) + client := NewClient("test-client-id", authClient, WithBaseURL(server.URL)) + + ctx := WithToken(context.Background(), &Token{AccessToken: "per-request-token"}) + + req := &Request{ + Method: "GET", + Endpoint: "/users", + } + + var result Response[User] + if err := client.Do(ctx, req, &result); err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestWithToken_FallsBackToClientToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got != "Bearer client-level-token" { + t.Errorf("expected client-level token, got %s", got) + } + _ = json.NewEncoder(w).Encode(Response[User]{Data: []User{}}) + })) + defer server.Close() + + authClient := NewAuthClient(AuthConfig{ClientID: "test"}) + authClient.SetToken(&Token{AccessToken: "client-level-token"}) + client := NewClient("test-client-id", authClient, WithBaseURL(server.URL)) + + req := &Request{ + Method: "GET", + Endpoint: "/users", + } + + var result Response[User] + if err := client.Do(context.Background(), req, &result); err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestWithToken_NilTokenFallsBack(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got != "Bearer client-level-token" { + t.Errorf("expected client-level token, got %s", got) + } + _ = json.NewEncoder(w).Encode(Response[User]{Data: []User{}}) + })) + defer server.Close() + + authClient := NewAuthClient(AuthConfig{ClientID: "test"}) + authClient.SetToken(&Token{AccessToken: "client-level-token"}) + client := NewClient("test-client-id", authClient, WithBaseURL(server.URL)) + + // WithToken with nil should fall back to client token + ctx := WithToken(context.Background(), nil) + + req := &Request{ + Method: "GET", + Endpoint: "/users", + } + + var result Response[User] + if err := client.Do(ctx, req, &result); err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestWithToken_ConcurrentDifferentTokens(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Echo the token back in a custom header so we can verify + w.Header().Set("X-Received-Token", r.Header.Get("Authorization")) + _ = json.NewEncoder(w).Encode(Response[User]{Data: []User{}}) + })) + defer server.Close() + + authClient := NewAuthClient(AuthConfig{ClientID: "test"}) + authClient.SetToken(&Token{AccessToken: "client-level-token"}) + client := NewClient("test-client-id", authClient, WithBaseURL(server.URL)) + + tokens := []string{"token-a", "token-b", "token-c"} + errs := make(chan error, len(tokens)) + + for _, tok := range tokens { + go func(token string) { + ctx := WithToken(context.Background(), &Token{AccessToken: token}) + req := &Request{ + Method: "GET", + Endpoint: "/users", + } + var result Response[User] + errs <- client.Do(ctx, req, &result) + }(tok) + } + + for range tokens { + if err := <-errs; err != nil { + t.Errorf("unexpected error: %v", err) + } + } +} + +func TestTokenFromContext_NoToken(t *testing.T) { + token := tokenFromContext(context.Background()) + if token != nil { + t.Error("expected nil token from empty context") + } +}