Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/grounds/commands/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func NewLoginCommand() *cobra.Command {
fmt.Fprintln(cmd.OutOrStdout(), " Verification code:", dc.UserCode)
_ = browser.OpenURL(dc.VerificationURIComplete)

tok, err := device.PollToken(ctx, dc.DeviceCode, dc.Interval, dc.ExpiresIn)
tok, err := device.PollToken(ctx, dc.DeviceCode, dc.CodeVerifier, dc.Interval, dc.ExpiresIn)
if err != nil {
return err
}
Expand Down
50 changes: 43 additions & 7 deletions internal/auth/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package auth

import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
Expand All @@ -18,6 +21,13 @@ type DeviceCodeResponse struct {
VerificationURIComplete string `json:"verification_uri_complete"`
ExpiresIn int `json:"expires_in"`
Interval int `json:"interval"`
// CodeVerifier is the PKCE secret generated client-side by
// StartDevice. The caller must pass it back to PollToken so the
// token endpoint can validate the challenge that was bound to the
// device_code at request time. Not part of the OIDC response —
// stitched into the struct here so the device-flow API stays
// stateless across the two RPCs.
CodeVerifier string `json:"-"`
}

type TokenResponse struct {
Expand All @@ -36,9 +46,15 @@ type DeviceClient struct {
}

func (d *DeviceClient) StartDevice(ctx context.Context) (*DeviceCodeResponse, error) {
verifier, challenge, err := newPKCE()
if err != nil {
return nil, fmt.Errorf("pkce: %w", err)
}
body := url.Values{
"client_id": {d.ClientID},
"scope": {"openid profile email"},
"client_id": {d.ClientID},
"scope": {"openid profile email"},
"code_challenge": {challenge},
"code_challenge_method": {"S256"},
}
req, _ := http.NewRequestWithContext(ctx, "POST",
d.Issuer+"/protocol/openid-connect/auth/device",
Expand All @@ -60,12 +76,31 @@ func (d *DeviceClient) StartDevice(ctx context.Context) (*DeviceCodeResponse, er
if out.Interval == 0 {
out.Interval = 5
}
out.CodeVerifier = verifier
return out, nil
}

// newPKCE generates an RFC 7636 PKCE pair: a 43-character URL-safe
// random verifier and its SHA-256 challenge. Keycloak's device-flow
// implementation requires both since the recent enforcement of
// `code_challenge_method` on the device endpoint.
func newPKCE() (verifier, challenge string, err error) {
buf := make([]byte, 32)
if _, err = rand.Read(buf); err != nil {
return "", "", err
}
verifier = base64.RawURLEncoding.EncodeToString(buf)
sum := sha256.Sum256([]byte(verifier))
challenge = base64.RawURLEncoding.EncodeToString(sum[:])
return verifier, challenge, nil
}

// PollToken loops until the user authorises in the browser, the device
// code expires, or ctx is cancelled. Returns the token response on success.
func (d *DeviceClient) PollToken(ctx context.Context, deviceCode string, interval, expiresIn int) (*TokenResponse, error) {
// code expires, or ctx is cancelled. `codeVerifier` is the PKCE secret
// returned by StartDevice — passing it here lets Keycloak validate the
// challenge that was bound to the device_code. Returns the token
// response on success.
func (d *DeviceClient) PollToken(ctx context.Context, deviceCode, codeVerifier string, interval, expiresIn int) (*TokenResponse, error) {
deadline := time.Now().Add(time.Duration(expiresIn) * time.Second)
tick := time.NewTicker(time.Duration(interval) * time.Second)
defer tick.Stop()
Expand All @@ -81,9 +116,10 @@ func (d *DeviceClient) PollToken(ctx context.Context, deviceCode string, interva
}

body := url.Values{
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
"device_code": {deviceCode},
"client_id": {d.ClientID},
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
"device_code": {deviceCode},
"client_id": {d.ClientID},
"code_verifier": {codeVerifier},
}
req, _ := http.NewRequestWithContext(ctx, "POST",
d.Issuer+"/protocol/openid-connect/token",
Expand Down
23 changes: 22 additions & 1 deletion internal/auth/device_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@ import (
)

func TestStartDevice(t *testing.T) {
var gotChallenge, gotMethod string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.HasSuffix(r.URL.Path, "/auth/device") {
t.Fatalf("path = %s", r.URL.Path)
}
r.ParseForm()
gotChallenge = r.Form.Get("code_challenge")
gotMethod = r.Form.Get("code_challenge_method")
json.NewEncoder(w).Encode(DeviceCodeResponse{
DeviceCode: "dc",
UserCode: "ABCD-EFGH",
Expand All @@ -33,12 +37,26 @@ func TestStartDevice(t *testing.T) {
if res.UserCode != "ABCD-EFGH" {
t.Errorf("UserCode = %q", res.UserCode)
}
// PKCE — Keycloak's device endpoint requires this since recent
// versions; without it the request fails 400 invalid_request.
if gotMethod != "S256" {
t.Errorf("code_challenge_method = %q, want S256", gotMethod)
}
if gotChallenge == "" {
t.Error("code_challenge missing on /auth/device request")
}
if res.CodeVerifier == "" {
t.Error("CodeVerifier should be populated for PollToken")
}
}

func TestPollToken_Success(t *testing.T) {
call := 0
var gotVerifier string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
call++
r.ParseForm()
gotVerifier = r.Form.Get("code_verifier")
if call == 1 {
// First poll: authorization_pending
w.WriteHeader(http.StatusBadRequest)
Expand All @@ -51,13 +69,16 @@ func TestPollToken_Success(t *testing.T) {
c := &DeviceClient{Issuer: srv.URL, ClientID: "grounds-cli", HTTP: srv.Client()}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
tok, err := c.PollToken(ctx, "dc", 1, 60)
tok, err := c.PollToken(ctx, "dc", "verifier-123", 1, 60)
if err != nil {
t.Fatalf("err: %v", err)
}
if tok.AccessToken != "at" {
t.Errorf("AccessToken = %q", tok.AccessToken)
}
if gotVerifier != "verifier-123" {
t.Errorf("code_verifier = %q, want verifier-123", gotVerifier)
}
}

func TestRefresh(t *testing.T) {
Expand Down
Loading