From 2ac282ab9419cecfa92c60c373762d90d5600297 Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Wed, 3 Dec 2025 16:30:03 +0200 Subject: [PATCH 1/2] add refresh to oauthdevice.Client From 2f24f44607d309f20908c77abf81e68eecfb2c47 Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Wed, 3 Dec 2025 16:34:08 +0200 Subject: [PATCH 2/2] oauthdevice: add RefreshToken field and Refresh method --- internal/oauthdevice/device_flow.go | 53 ++++++++++++++++++++++++ internal/oauthdevice/device_flow_test.go | 41 ++++++++++++++++++ 2 files changed, 94 insertions(+) diff --git a/internal/oauthdevice/device_flow.go b/internal/oauthdevice/device_flow.go index c278dd4ba3..b13f60af70 100644 --- a/internal/oauthdevice/device_flow.go +++ b/internal/oauthdevice/device_flow.go @@ -68,6 +68,7 @@ type Client interface { Discover(ctx context.Context, endpoint string) (*OIDCConfiguration, error) Start(ctx context.Context, endpoint string, scopes []string) (*DeviceAuthResponse, error) Poll(ctx context.Context, endpoint, deviceCode string, interval time.Duration, expiresIn int) (*TokenResponse, error) + Refresh(ctx context.Context, endpoint, refreshToken string) (*TokenResponse, error) } type httpClient struct { @@ -307,3 +308,55 @@ func (c *httpClient) pollOnce(ctx context.Context, tokenEndpoint, deviceCode str return &tokenResp, nil } + +// Refresh exchanges a refresh token for a new access token. +func (c *httpClient) Refresh(ctx context.Context, endpoint, refreshToken string) (*TokenResponse, error) { + endpoint = strings.TrimRight(endpoint, "/") + + config, err := c.Discover(ctx, endpoint) + if err != nil { + return nil, errors.Wrap(err, "OIDC discovery failed") + } + + if config.TokenEndpoint == "" { + return nil, errors.New("token endpoint not found in OIDC configuration") + } + + data := url.Values{} + data.Set("client_id", c.clientID) + data.Set("grant_type", "refresh_token") + data.Set("refresh_token", refreshToken) + + req, err := http.NewRequestWithContext(ctx, "POST", config.TokenEndpoint, strings.NewReader(data.Encode())) + if err != nil { + return nil, errors.Wrap(err, "creating refresh token request") + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := c.client.Do(req) + if err != nil { + return nil, errors.Wrap(err, "refresh token request failed") + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "reading refresh token response") + } + + if resp.StatusCode != http.StatusOK { + var errResp ErrorResponse + if err := json.Unmarshal(body, &errResp); err == nil && errResp.Error != "" { + return nil, errors.Newf("refresh token failed: %s: %s", errResp.Error, errResp.ErrorDescription) + } + return nil, errors.Newf("refresh token failed with status %d: %s", resp.StatusCode, string(body)) + } + + var tokenResp TokenResponse + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, errors.Wrap(err, "parsing refresh token response") + } + + return &tokenResp, nil +} \ No newline at end of file diff --git a/internal/oauthdevice/device_flow_test.go b/internal/oauthdevice/device_flow_test.go index e60e1f9b1a..db7fedf38c 100644 --- a/internal/oauthdevice/device_flow_test.go +++ b/internal/oauthdevice/device_flow_test.go @@ -507,3 +507,44 @@ func TestPoll_ContextCancellation(t *testing.T) { t.Errorf("error = %v, want context.Canceled or wrapped context canceled error", err) } } + +func TestRefresh_Success(t *testing.T) { + server := newTestServer(t, testServerOptions{ + handlers: map[string]http.HandlerFunc{ + testTokenPath: func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + if got := r.FormValue("grant_type"); got != "refresh_token" { + t.Errorf("grant_type = %q, want %q", got, "refresh_token") + } + if got := r.FormValue("refresh_token"); got != "test-refresh-token" { + t.Errorf("refresh_token = %q, want %q", got, "test-refresh-token") + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(TokenResponse{ + AccessToken: "new-access-token", + RefreshToken: "new-refresh-token", + TokenType: "Bearer", + ExpiresIn: 3600, + }) + }, + }, + }) + defer server.Close() + + client := NewClient(DefaultClientID) + resp, err := client.Refresh(context.Background(), server.URL, "test-refresh-token") + if err != nil { + t.Fatalf("Refresh() error = %v", err) + } + + if resp.AccessToken != "new-access-token" { + t.Errorf("AccessToken = %q, want %q", resp.AccessToken, "new-access-token") + } + if resp.RefreshToken != "new-refresh-token" { + t.Errorf("RefreshToken = %q, want %q", resp.RefreshToken, "new-refresh-token") + } +}