Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions internal/oauthdevice/device_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ type Client interface {
Discover(ctx context.Context, endpoint string) (*OIDCConfiguration, error)
Start(ctx context.Context, endpoint string, scopes []string) (*DeviceAuthResponse, error)
Poll(ctx context.Context, endpoint, deviceCode string, interval time.Duration, expiresIn int) (*TokenResponse, error)
Refresh(ctx context.Context, endpoint, refreshToken string) (*TokenResponse, error)
}

type httpClient struct {
Expand Down Expand Up @@ -307,3 +308,55 @@ func (c *httpClient) pollOnce(ctx context.Context, tokenEndpoint, deviceCode str

return &tokenResp, nil
}

// Refresh exchanges a refresh token for a new access token.
func (c *httpClient) Refresh(ctx context.Context, endpoint, refreshToken string) (*TokenResponse, error) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this just magically get used to refresh in our client?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not yet! There is some code I still need to land which adds a transport which will do automatic refresh if you use api.Client

endpoint = strings.TrimRight(endpoint, "/")

config, err := c.Discover(ctx, endpoint)
if err != nil {
return nil, errors.Wrap(err, "OIDC discovery failed")
}

if config.TokenEndpoint == "" {
return nil, errors.New("token endpoint not found in OIDC configuration")
}

data := url.Values{}
data.Set("client_id", c.clientID)
data.Set("grant_type", "refresh_token")
data.Set("refresh_token", refreshToken)

req, err := http.NewRequestWithContext(ctx, "POST", config.TokenEndpoint, strings.NewReader(data.Encode()))
if err != nil {
return nil, errors.Wrap(err, "creating refresh token request")
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")

resp, err := c.client.Do(req)
if err != nil {
return nil, errors.Wrap(err, "refresh token request failed")
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errors.Wrap(err, "reading refresh token response")
}

if resp.StatusCode != http.StatusOK {
var errResp ErrorResponse
if err := json.Unmarshal(body, &errResp); err == nil && errResp.Error != "" {
return nil, errors.Newf("refresh token failed: %s: %s", errResp.Error, errResp.ErrorDescription)
}
return nil, errors.Newf("refresh token failed with status %d: %s", resp.StatusCode, string(body))
}

var tokenResp TokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, errors.Wrap(err, "parsing refresh token response")
}

return &tokenResp, nil
}
41 changes: 41 additions & 0 deletions internal/oauthdevice/device_flow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -507,3 +507,44 @@ func TestPoll_ContextCancellation(t *testing.T) {
t.Errorf("error = %v, want context.Canceled or wrapped context canceled error", err)
}
}

func TestRefresh_Success(t *testing.T) {
server := newTestServer(t, testServerOptions{
handlers: map[string]http.HandlerFunc{
testTokenPath: func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
if got := r.FormValue("grant_type"); got != "refresh_token" {
t.Errorf("grant_type = %q, want %q", got, "refresh_token")
}
if got := r.FormValue("refresh_token"); got != "test-refresh-token" {
t.Errorf("refresh_token = %q, want %q", got, "test-refresh-token")
}

w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
AccessToken: "new-access-token",
RefreshToken: "new-refresh-token",
TokenType: "Bearer",
ExpiresIn: 3600,
})
},
},
})
defer server.Close()

client := NewClient(DefaultClientID)
resp, err := client.Refresh(context.Background(), server.URL, "test-refresh-token")
if err != nil {
t.Fatalf("Refresh() error = %v", err)
}

if resp.AccessToken != "new-access-token" {
t.Errorf("AccessToken = %q, want %q", resp.AccessToken, "new-access-token")
}
if resp.RefreshToken != "new-refresh-token" {
t.Errorf("RefreshToken = %q, want %q", resp.RefreshToken, "new-refresh-token")
}
}
Loading