diff --git a/cmd/grounds/commands/login.go b/cmd/grounds/commands/login.go index 4504556..dab4b15 100644 --- a/cmd/grounds/commands/login.go +++ b/cmd/grounds/commands/login.go @@ -45,7 +45,7 @@ func NewLoginCommand() *cobra.Command { fmt.Fprintln(cmd.OutOrStdout(), " Verification code:", dc.UserCode) _ = browser.OpenURL(dc.VerificationURIComplete) - tok, err := device.PollToken(ctx, dc.DeviceCode, dc.Interval, dc.ExpiresIn) + tok, err := device.PollToken(ctx, dc.DeviceCode, dc.CodeVerifier, dc.Interval, dc.ExpiresIn) if err != nil { return err } diff --git a/internal/auth/device.go b/internal/auth/device.go index 2a56f2f..3673fed 100644 --- a/internal/auth/device.go +++ b/internal/auth/device.go @@ -2,6 +2,9 @@ package auth import ( "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" "encoding/json" "fmt" "io" @@ -18,6 +21,13 @@ type DeviceCodeResponse struct { VerificationURIComplete string `json:"verification_uri_complete"` ExpiresIn int `json:"expires_in"` Interval int `json:"interval"` + // CodeVerifier is the PKCE secret generated client-side by + // StartDevice. The caller must pass it back to PollToken so the + // token endpoint can validate the challenge that was bound to the + // device_code at request time. Not part of the OIDC response — + // stitched into the struct here so the device-flow API stays + // stateless across the two RPCs. + CodeVerifier string `json:"-"` } type TokenResponse struct { @@ -36,9 +46,15 @@ type DeviceClient struct { } func (d *DeviceClient) StartDevice(ctx context.Context) (*DeviceCodeResponse, error) { + verifier, challenge, err := newPKCE() + if err != nil { + return nil, fmt.Errorf("pkce: %w", err) + } body := url.Values{ - "client_id": {d.ClientID}, - "scope": {"openid profile email"}, + "client_id": {d.ClientID}, + "scope": {"openid profile email"}, + "code_challenge": {challenge}, + "code_challenge_method": {"S256"}, } req, _ := http.NewRequestWithContext(ctx, "POST", d.Issuer+"/protocol/openid-connect/auth/device", @@ -60,12 +76,31 @@ func (d *DeviceClient) StartDevice(ctx context.Context) (*DeviceCodeResponse, er if out.Interval == 0 { out.Interval = 5 } + out.CodeVerifier = verifier return out, nil } +// newPKCE generates an RFC 7636 PKCE pair: a 43-character URL-safe +// random verifier and its SHA-256 challenge. Keycloak's device-flow +// implementation requires both since the recent enforcement of +// `code_challenge_method` on the device endpoint. +func newPKCE() (verifier, challenge string, err error) { + buf := make([]byte, 32) + if _, err = rand.Read(buf); err != nil { + return "", "", err + } + verifier = base64.RawURLEncoding.EncodeToString(buf) + sum := sha256.Sum256([]byte(verifier)) + challenge = base64.RawURLEncoding.EncodeToString(sum[:]) + return verifier, challenge, nil +} + // PollToken loops until the user authorises in the browser, the device -// code expires, or ctx is cancelled. Returns the token response on success. -func (d *DeviceClient) PollToken(ctx context.Context, deviceCode string, interval, expiresIn int) (*TokenResponse, error) { +// code expires, or ctx is cancelled. `codeVerifier` is the PKCE secret +// returned by StartDevice — passing it here lets Keycloak validate the +// challenge that was bound to the device_code. Returns the token +// response on success. +func (d *DeviceClient) PollToken(ctx context.Context, deviceCode, codeVerifier string, interval, expiresIn int) (*TokenResponse, error) { deadline := time.Now().Add(time.Duration(expiresIn) * time.Second) tick := time.NewTicker(time.Duration(interval) * time.Second) defer tick.Stop() @@ -81,9 +116,10 @@ func (d *DeviceClient) PollToken(ctx context.Context, deviceCode string, interva } body := url.Values{ - "grant_type": {"urn:ietf:params:oauth:grant-type:device_code"}, - "device_code": {deviceCode}, - "client_id": {d.ClientID}, + "grant_type": {"urn:ietf:params:oauth:grant-type:device_code"}, + "device_code": {deviceCode}, + "client_id": {d.ClientID}, + "code_verifier": {codeVerifier}, } req, _ := http.NewRequestWithContext(ctx, "POST", d.Issuer+"/protocol/openid-connect/token", diff --git a/internal/auth/device_test.go b/internal/auth/device_test.go index 676cb27..b82822b 100644 --- a/internal/auth/device_test.go +++ b/internal/auth/device_test.go @@ -11,10 +11,14 @@ import ( ) func TestStartDevice(t *testing.T) { + var gotChallenge, gotMethod string srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !strings.HasSuffix(r.URL.Path, "/auth/device") { t.Fatalf("path = %s", r.URL.Path) } + r.ParseForm() + gotChallenge = r.Form.Get("code_challenge") + gotMethod = r.Form.Get("code_challenge_method") json.NewEncoder(w).Encode(DeviceCodeResponse{ DeviceCode: "dc", UserCode: "ABCD-EFGH", @@ -33,12 +37,26 @@ func TestStartDevice(t *testing.T) { if res.UserCode != "ABCD-EFGH" { t.Errorf("UserCode = %q", res.UserCode) } + // PKCE — Keycloak's device endpoint requires this since recent + // versions; without it the request fails 400 invalid_request. + if gotMethod != "S256" { + t.Errorf("code_challenge_method = %q, want S256", gotMethod) + } + if gotChallenge == "" { + t.Error("code_challenge missing on /auth/device request") + } + if res.CodeVerifier == "" { + t.Error("CodeVerifier should be populated for PollToken") + } } func TestPollToken_Success(t *testing.T) { call := 0 + var gotVerifier string srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { call++ + r.ParseForm() + gotVerifier = r.Form.Get("code_verifier") if call == 1 { // First poll: authorization_pending w.WriteHeader(http.StatusBadRequest) @@ -51,13 +69,16 @@ func TestPollToken_Success(t *testing.T) { c := &DeviceClient{Issuer: srv.URL, ClientID: "grounds-cli", HTTP: srv.Client()} ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - tok, err := c.PollToken(ctx, "dc", 1, 60) + tok, err := c.PollToken(ctx, "dc", "verifier-123", 1, 60) if err != nil { t.Fatalf("err: %v", err) } if tok.AccessToken != "at" { t.Errorf("AccessToken = %q", tok.AccessToken) } + if gotVerifier != "verifier-123" { + t.Errorf("code_verifier = %q, want verifier-123", gotVerifier) + } } func TestRefresh(t *testing.T) {