Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions helix/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,24 @@ type TokenProvider interface {
GetToken() *Token
}

// tokenContextKey is the context key for per-request token overrides.
type tokenContextKey struct{}

// WithToken returns a context that overrides the client-level token for a
// single request. This is useful when making concurrent requests that each
// require a different user token.
func WithToken(ctx context.Context, token *Token) context.Context {
return context.WithValue(ctx, tokenContextKey{}, token)
}

// tokenFromContext retrieves a per-request token override from context.
func tokenFromContext(ctx context.Context) *Token {
if token, ok := ctx.Value(tokenContextKey{}).(*Token); ok {
return token
}
return nil
}

// Client is a Twitch Helix API client.
type Client struct {
clientID string
Expand Down Expand Up @@ -382,9 +400,11 @@ func (c *Client) doOnceWithResponse(ctx context.Context, req *Request, result in
}
}

// Set authorization
// Set authorization (per-request context override takes precedence)
var token *Token
if c.authClient != nil {
if ctxToken := tokenFromContext(ctx); ctxToken != nil {
token = ctxToken
} else if c.authClient != nil {
token = c.authClient.GetToken()
} else if c.tokenProvider != nil {
token = c.tokenProvider.GetToken()
Expand Down
118 changes: 118 additions & 0 deletions helix/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1223,3 +1223,121 @@ func TestClient_ResponseBodyReadError(t *testing.T) {
}
}
}

func TestWithToken_OverridesClientToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("Authorization"); got != "Bearer per-request-token" {
t.Errorf("expected per-request token, got %s", got)
}
_ = json.NewEncoder(w).Encode(Response[User]{Data: []User{}})
}))
defer server.Close()

authClient := NewAuthClient(AuthConfig{ClientID: "test"})
authClient.SetToken(&Token{AccessToken: "client-level-token"})
client := NewClient("test-client-id", authClient, WithBaseURL(server.URL))

ctx := WithToken(context.Background(), &Token{AccessToken: "per-request-token"})

req := &Request{
Method: "GET",
Endpoint: "/users",
}

var result Response[User]
if err := client.Do(ctx, req, &result); err != nil {
t.Fatalf("unexpected error: %v", err)
}
}

func TestWithToken_FallsBackToClientToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("Authorization"); got != "Bearer client-level-token" {
t.Errorf("expected client-level token, got %s", got)
}
_ = json.NewEncoder(w).Encode(Response[User]{Data: []User{}})
}))
defer server.Close()

authClient := NewAuthClient(AuthConfig{ClientID: "test"})
authClient.SetToken(&Token{AccessToken: "client-level-token"})
client := NewClient("test-client-id", authClient, WithBaseURL(server.URL))

req := &Request{
Method: "GET",
Endpoint: "/users",
}

var result Response[User]
if err := client.Do(context.Background(), req, &result); err != nil {
t.Fatalf("unexpected error: %v", err)
}
}

func TestWithToken_NilTokenFallsBack(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("Authorization"); got != "Bearer client-level-token" {
t.Errorf("expected client-level token, got %s", got)
}
_ = json.NewEncoder(w).Encode(Response[User]{Data: []User{}})
}))
defer server.Close()

authClient := NewAuthClient(AuthConfig{ClientID: "test"})
authClient.SetToken(&Token{AccessToken: "client-level-token"})
client := NewClient("test-client-id", authClient, WithBaseURL(server.URL))

// WithToken with nil should fall back to client token
ctx := WithToken(context.Background(), nil)

req := &Request{
Method: "GET",
Endpoint: "/users",
}

var result Response[User]
if err := client.Do(ctx, req, &result); err != nil {
t.Fatalf("unexpected error: %v", err)
}
}

func TestWithToken_ConcurrentDifferentTokens(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Echo the token back in a custom header so we can verify
w.Header().Set("X-Received-Token", r.Header.Get("Authorization"))
_ = json.NewEncoder(w).Encode(Response[User]{Data: []User{}})
}))
defer server.Close()

authClient := NewAuthClient(AuthConfig{ClientID: "test"})
authClient.SetToken(&Token{AccessToken: "client-level-token"})
client := NewClient("test-client-id", authClient, WithBaseURL(server.URL))

tokens := []string{"token-a", "token-b", "token-c"}
errs := make(chan error, len(tokens))

for _, tok := range tokens {
go func(token string) {
ctx := WithToken(context.Background(), &Token{AccessToken: token})
req := &Request{
Method: "GET",
Endpoint: "/users",
}
var result Response[User]
errs <- client.Do(ctx, req, &result)
}(tok)
}

for range tokens {
if err := <-errs; err != nil {
t.Errorf("unexpected error: %v", err)
}
}
}

func TestTokenFromContext_NoToken(t *testing.T) {
token := tokenFromContext(context.Background())
if token != nil {
t.Error("expected nil token from empty context")
}
}