From d411fb3fa3254c7e32ad75de089e88e4921aafc0 Mon Sep 17 00:00:00 2001 From: Przemek Malolepszy Date: Wed, 1 Oct 2025 13:33:10 +0200 Subject: [PATCH 1/5] feat: add WithPrompt option for AuthCodeURL --- apps/confidential/confidential.go | 28 +++++++++++++- apps/confidential/confidential_test.go | 53 ++++++++++++++++++++++++++ apps/public/public.go | 38 +++++++++++++++--- apps/public/public_test.go | 46 ++++++++++++++++++++++ 4 files changed, 158 insertions(+), 7 deletions(-) diff --git a/apps/confidential/confidential.go b/apps/confidential/confidential.go index 549d68ab..1c080e86 100644 --- a/apps/confidential/confidential.go +++ b/apps/confidential/confidential.go @@ -359,7 +359,7 @@ func New(authority, clientID string, cred Credential, options ...Option) (Client // authCodeURLOptions contains options for AuthCodeURL type authCodeURLOptions struct { - claims, loginHint, tenantID, domainHint string + claims, loginHint, tenantID, domainHint, prompt string } // AuthCodeURLOption is implemented by options for AuthCodeURL @@ -369,7 +369,7 @@ type AuthCodeURLOption interface { // AuthCodeURL creates a URL used to acquire an authorization code. Users need to call CreateAuthorizationCodeURLParameters and pass it in. // -// Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID] +// Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID], [WithPrompt] func (cca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, scopes []string, opts ...AuthCodeURLOption) (string, error) { o := authCodeURLOptions{} if err := options.ApplyOptions(&o, opts); err != nil { @@ -382,6 +382,7 @@ func (cca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, ap.Claims = o.claims ap.LoginHint = o.loginHint ap.DomainHint = o.domainHint + ap.Prompt = o.prompt return cca.base.AuthCodeURL(ctx, clientID, redirectURI, scopes, ap) } @@ -431,6 +432,29 @@ func WithDomainHint(domain string) interface { } } +// WithPrompt adds prompt query parameter in the auth url. +func WithPrompt(prompt string) interface { + AuthCodeURLOption + options.CallOption +} { + return struct { + AuthCodeURLOption + options.CallOption + }{ + CallOption: options.NewCallOption( + func(a any) error { + switch t := a.(type) { + case *authCodeURLOptions: + t.prompt = prompt + default: + return fmt.Errorf("unexpected options type %T", a) + } + return nil + }, + ), + } +} + // WithClaims sets additional claims to request for the token, such as those required by conditional access policies. // Use this option when Azure AD returned a claims challenge for a prior request. The argument must be decoded. // This option is valid for any token acquisition method. diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 165a662f..f2e9a3fa 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -1774,6 +1774,59 @@ func TestWithDomainHint(t *testing.T) { } } +func TestWithPrompt(t *testing.T) { + prompt := "login" + cred, err := NewCredFromSecret(fakeSecret) + if err != nil { + t.Fatal(err) + } + client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{})) + if err != nil { + t.Fatal(err) + } + if err != nil { + t.Fatal(err) + } + client.base.Token.AccessTokens = &fake.AccessTokens{} + client.base.Token.Authority = &fake.Authority{} + client.base.Token.Resolver = &fake.ResolveEndpoints{} + for _, expectPrompt := range []bool{true, false} { + t.Run(fmt.Sprint(expectPrompt), func(t *testing.T) { + validate := func(v url.Values) error { + if !v.Has("prompt") { + if !expectPrompt { + return nil + } + return errors.New("expected a prompt") + } else if !expectPrompt { + return fmt.Errorf("expected no prompt, got %v", v["prompt"][0]) + } + + if actual := v["prompt"]; len(actual) != 1 || actual[0] != prompt { + err = fmt.Errorf(`unexpected prompt "%v"`, actual[0]) + } + return err + } + var urlOpts []AuthCodeURLOption + if expectPrompt { + urlOpts = append(urlOpts, WithPrompt(prompt)) + } + u, err := client.AuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, urlOpts...) + print("actual URL: " + u) + if err == nil { + var parsed *url.URL + parsed, err = url.Parse(u) + if err == nil { + err = validate(parsed.Query()) + } + } + if err != nil { + t.Fatal(err) + } + }) + } +} + func TestWithAuthenticationScheme(t *testing.T) { ctx := context.Background() authScheme := mock.NewTestAuthnScheme() diff --git a/apps/public/public.go b/apps/public/public.go index 797c086c..4e64f92d 100644 --- a/apps/public/public.go +++ b/apps/public/public.go @@ -149,7 +149,7 @@ func New(clientID string, options ...Option) (Client, error) { // authCodeURLOptions contains options for AuthCodeURL type authCodeURLOptions struct { - claims, loginHint, tenantID, domainHint string + claims, loginHint, tenantID, domainHint, prompt string } // AuthCodeURLOption is implemented by options for AuthCodeURL @@ -159,7 +159,7 @@ type AuthCodeURLOption interface { // AuthCodeURL creates a URL used to acquire an authorization code. // -// Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID] +// Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID], [WithPrompt] func (pca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, scopes []string, opts ...AuthCodeURLOption) (string, error) { o := authCodeURLOptions{} if err := options.ApplyOptions(&o, opts); err != nil { @@ -172,6 +172,7 @@ func (pca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, ap.Claims = o.claims ap.LoginHint = o.loginHint ap.DomainHint = o.domainHint + ap.Prompt = o.prompt return pca.base.AuthCodeURL(ctx, clientID, redirectURI, scopes, ap) } @@ -526,9 +527,9 @@ func (pca Client) RemoveAccount(ctx context.Context, account Account) error { // interactiveAuthOptions contains the optional parameters used to acquire an access token for interactive auth code flow. type interactiveAuthOptions struct { - claims, domainHint, loginHint, redirectURI, tenantID string - openURL func(url string) error - authnScheme AuthenticationScheme + claims, domainHint, loginHint, redirectURI, tenantID, prompt string + openURL func(url string) error + authnScheme AuthenticationScheme } // AcquireInteractiveOption is implemented by options for AcquireTokenInteractive @@ -590,6 +591,33 @@ func WithDomainHint(domain string) interface { } } +// WithPrompt adds the IdP prompt query parameter in the auth url. +func WithPrompt(prompt string) interface { + AcquireInteractiveOption + AuthCodeURLOption + options.CallOption +} { + return struct { + AcquireInteractiveOption + AuthCodeURLOption + options.CallOption + }{ + CallOption: options.NewCallOption( + func(a any) error { + switch t := a.(type) { + case *authCodeURLOptions: + t.prompt = prompt + case *interactiveAuthOptions: + t.prompt = prompt + default: + return fmt.Errorf("unexpected options type %T", a) + } + return nil + }, + ), + } +} + // WithRedirectURI sets a port for the local server used in interactive authentication, for // example http://localhost:port. All URI components other than the port are ignored. func WithRedirectURI(redirectURI string) interface { diff --git a/apps/public/public_test.go b/apps/public/public_test.go index fa019ca5..5b009cb1 100644 --- a/apps/public/public_test.go +++ b/apps/public/public_test.go @@ -935,6 +935,52 @@ func TestWithDomainHint(t *testing.T) { } } +func TestWithPrompt(t *testing.T) { + prompt := "login" + client, err := New("client-id") + if err != nil { + t.Fatal(err) + } + client.base.Token.AccessTokens = &fake.AccessTokens{} + client.base.Token.Authority = &fake.Authority{} + client.base.Token.Resolver = &fake.ResolveEndpoints{} + for _, expectPrompt := range []bool{true, false} { + t.Run(fmt.Sprint(expectPrompt), func(t *testing.T) { + validate := func(v url.Values) error { + if !v.Has("prompt") { + if !expectPrompt { + return nil + } + return errors.New("expected a prompt") + } else if !expectPrompt { + return fmt.Errorf("expected no prompt, got %v", v["prompt"][0]) + } + + if actual := v["prompt"]; len(actual) != 1 || actual[0] != prompt { + err = fmt.Errorf(`unexpected prompt "%v"`, actual[0]) + } + return err + } + var urlOpts []AuthCodeURLOption + if expectPrompt { + urlOpts = append(urlOpts, WithPrompt(prompt)) + } + u, err := client.AuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, urlOpts...) + print("actual URL: " + u) + if err == nil { + var parsed *url.URL + parsed, err = url.Parse(u) + if err == nil { + err = validate(parsed.Query()) + } + } + if err != nil { + t.Fatal(err) + } + }) + } +} + func TestWithAuthenticationScheme(t *testing.T) { clientInfo := base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`)) lmo, tenant := "login.microsoftonline.com", "tenant" From e52ff6b7f0f4cccc3fa41140b2e6e686ff244564 Mon Sep 17 00:00:00 2001 From: Przemek Malolepszy Date: Wed, 1 Oct 2025 15:43:14 +0200 Subject: [PATCH 2/5] chore: clean print statements, and use enum for prompt --- apps/confidential/confidential.go | 4 ++-- apps/confidential/confidential_test.go | 6 +++--- apps/internal/shared/shared.go | 26 ++++++++++++++++++++++++++ apps/public/public.go | 8 ++------ apps/public/public_test.go | 6 +++--- 5 files changed, 36 insertions(+), 14 deletions(-) diff --git a/apps/confidential/confidential.go b/apps/confidential/confidential.go index 1c080e86..e90aa5c4 100644 --- a/apps/confidential/confidential.go +++ b/apps/confidential/confidential.go @@ -433,7 +433,7 @@ func WithDomainHint(domain string) interface { } // WithPrompt adds prompt query parameter in the auth url. -func WithPrompt(prompt string) interface { +func WithPrompt(prompt shared.Prompt) interface { AuthCodeURLOption options.CallOption } { @@ -445,7 +445,7 @@ func WithPrompt(prompt string) interface { func(a any) error { switch t := a.(type) { case *authCodeURLOptions: - t.prompt = prompt + t.prompt = prompt.String() default: return fmt.Errorf("unexpected options type %T", a) } diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index f2e9a3fa..23e68afe 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -33,6 +33,7 @@ import ( "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared" ) // errorClient is an HTTP client for tests that should fail when confidential.Client sends a request @@ -1775,7 +1776,7 @@ func TestWithDomainHint(t *testing.T) { } func TestWithPrompt(t *testing.T) { - prompt := "login" + prompt := shared.PromptLogin cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) @@ -1802,7 +1803,7 @@ func TestWithPrompt(t *testing.T) { return fmt.Errorf("expected no prompt, got %v", v["prompt"][0]) } - if actual := v["prompt"]; len(actual) != 1 || actual[0] != prompt { + if actual := v["prompt"]; len(actual) != 1 || actual[0] != prompt.String() { err = fmt.Errorf(`unexpected prompt "%v"`, actual[0]) } return err @@ -1812,7 +1813,6 @@ func TestWithPrompt(t *testing.T) { urlOpts = append(urlOpts, WithPrompt(prompt)) } u, err := client.AuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, urlOpts...) - print("actual URL: " + u) if err == nil { var parsed *url.URL parsed, err = url.Parse(u) diff --git a/apps/internal/shared/shared.go b/apps/internal/shared/shared.go index d8ab7135..77376d6f 100644 --- a/apps/internal/shared/shared.go +++ b/apps/internal/shared/shared.go @@ -70,3 +70,29 @@ func (acc Account) IsZero() bool { // DefaultClient is our default shared HTTP client. var DefaultClient = &http.Client{} + +type Prompt int64 + +const ( + PromptNone Prompt = iota + PromptLogin + PromptSelectAccount + PromptConsent + PromptCreate +) + +func (p Prompt) String() string { + switch p { + case PromptNone: + return "none" + case PromptLogin: + return "login" + case PromptSelectAccount: + return "select_account" + case PromptConsent: + return "consent" + case PromptCreate: + return "create" + } + return "" +} diff --git a/apps/public/public.go b/apps/public/public.go index 4e64f92d..8ce892c1 100644 --- a/apps/public/public.go +++ b/apps/public/public.go @@ -592,13 +592,11 @@ func WithDomainHint(domain string) interface { } // WithPrompt adds the IdP prompt query parameter in the auth url. -func WithPrompt(prompt string) interface { - AcquireInteractiveOption +func WithPrompt(prompt shared.Prompt) interface { AuthCodeURLOption options.CallOption } { return struct { - AcquireInteractiveOption AuthCodeURLOption options.CallOption }{ @@ -606,9 +604,7 @@ func WithPrompt(prompt string) interface { func(a any) error { switch t := a.(type) { case *authCodeURLOptions: - t.prompt = prompt - case *interactiveAuthOptions: - t.prompt = prompt + t.prompt = prompt.String() default: return fmt.Errorf("unexpected options type %T", a) } diff --git a/apps/public/public_test.go b/apps/public/public_test.go index 5b009cb1..f4ec6350 100644 --- a/apps/public/public_test.go +++ b/apps/public/public_test.go @@ -22,6 +22,7 @@ import ( "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared" "github.com/kylelemons/godebug/pretty" ) @@ -936,7 +937,7 @@ func TestWithDomainHint(t *testing.T) { } func TestWithPrompt(t *testing.T) { - prompt := "login" + prompt := shared.PromptSelectAccount client, err := New("client-id") if err != nil { t.Fatal(err) @@ -956,7 +957,7 @@ func TestWithPrompt(t *testing.T) { return fmt.Errorf("expected no prompt, got %v", v["prompt"][0]) } - if actual := v["prompt"]; len(actual) != 1 || actual[0] != prompt { + if actual := v["prompt"]; len(actual) != 1 || actual[0] != prompt.String() { err = fmt.Errorf(`unexpected prompt "%v"`, actual[0]) } return err @@ -966,7 +967,6 @@ func TestWithPrompt(t *testing.T) { urlOpts = append(urlOpts, WithPrompt(prompt)) } u, err := client.AuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, urlOpts...) - print("actual URL: " + u) if err == nil { var parsed *url.URL parsed, err = url.Parse(u) From 4e11b8e2011a583531d56e18a070cdbbe0dd9691 Mon Sep 17 00:00:00 2001 From: Przemek Malolepszy Date: Thu, 2 Oct 2025 11:03:39 +0200 Subject: [PATCH 3/5] chore: add WithPrompt to interactive flow --- apps/public/public.go | 6 +++++- apps/public/public_test.go | 37 ++++++++++++++++++++++++++++++++----- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/apps/public/public.go b/apps/public/public.go index 8ce892c1..3c8d1aed 100644 --- a/apps/public/public.go +++ b/apps/public/public.go @@ -593,10 +593,12 @@ func WithDomainHint(domain string) interface { // WithPrompt adds the IdP prompt query parameter in the auth url. func WithPrompt(prompt shared.Prompt) interface { + AcquireInteractiveOption AuthCodeURLOption options.CallOption } { return struct { + AcquireInteractiveOption AuthCodeURLOption options.CallOption }{ @@ -605,6 +607,8 @@ func WithPrompt(prompt shared.Prompt) interface { switch t := a.(type) { case *authCodeURLOptions: t.prompt = prompt.String() + case *interactiveAuthOptions: + t.prompt = prompt.String() default: return fmt.Errorf("unexpected options type %T", a) } @@ -698,7 +702,7 @@ func (pca Client) AcquireTokenInteractive(ctx context.Context, scopes []string, authParams.LoginHint = o.loginHint authParams.DomainHint = o.domainHint authParams.State = uuid.New().String() - authParams.Prompt = "select_account" + authParams.Prompt = o.prompt if o.authnScheme != nil { authParams.AuthnScheme = o.authnScheme } diff --git a/apps/public/public_test.go b/apps/public/public_test.go index f4ec6350..23a4b184 100644 --- a/apps/public/public_test.go +++ b/apps/public/public_test.go @@ -44,16 +44,16 @@ func fakeBrowserOpenURL(authURL string) error { if m := q.Get("code_challenge_method"); m != "S256" { return fmt.Errorf("unexpected code_challenge_method '%s'", m) } - if q.Get("prompt") == "" { - return errors.New("missing query param 'prompt") - } + // if q.Get("prompt") == "" { + // return errors.New("missing query param 'prompt") + // } state := q.Get("state") if state == "" { return errors.New("missing query param 'state'") } redirect := q.Get("redirect_uri") if redirect == "" { - return errors.New("missing query param 'redirect_uri'") + return errors.New(" 'redirect_uri'") } // now send the info to our local redirect server resp, err := http.DefaultClient.Get(redirect + fmt.Sprintf("/?state=%s&code=fake_auth_code", state)) @@ -937,7 +937,7 @@ func TestWithDomainHint(t *testing.T) { } func TestWithPrompt(t *testing.T) { - prompt := shared.PromptSelectAccount + prompt := shared.PromptCreate client, err := New("client-id") if err != nil { t.Fatal(err) @@ -947,6 +947,7 @@ func TestWithPrompt(t *testing.T) { client.base.Token.Resolver = &fake.ResolveEndpoints{} for _, expectPrompt := range []bool{true, false} { t.Run(fmt.Sprint(expectPrompt), func(t *testing.T) { + called := false validate := func(v url.Values) error { if !v.Has("prompt") { if !expectPrompt { @@ -962,10 +963,36 @@ func TestWithPrompt(t *testing.T) { } return err } + browserOpenURL := func(authURL string) error { + called = true + parsed, err := url.Parse(authURL) + if err != nil { + return err + } + query, err := url.ParseQuery(parsed.RawQuery) + if err != nil { + return err + } + if err = validate(query); err != nil { + t.Fatal(err) + return err + } + // this helper validates the other params and completes the redirect + return fakeBrowserOpenURL(authURL) + } + acquireOpts := []AcquireInteractiveOption{WithOpenURL(browserOpenURL)} var urlOpts []AuthCodeURLOption if expectPrompt { + acquireOpts = append(acquireOpts, WithPrompt(prompt)) urlOpts = append(urlOpts, WithPrompt(prompt)) } + _, err = client.AcquireTokenInteractive(context.Background(), tokenScope, acquireOpts...) + if err != nil { + t.Fatal(err) + } + if !called { + t.Fatal("browserOpenURL wasn't called") + } u, err := client.AuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, urlOpts...) if err == nil { var parsed *url.URL From 9a8a04431ea2e575cb7fa24503754bec3f33821a Mon Sep 17 00:00:00 2001 From: Przemek Malolepszy <39582596+szogoon@users.noreply.github.com> Date: Fri, 3 Oct 2025 18:43:26 +0200 Subject: [PATCH 4/5] chore: revert to original error message in fakeBrowserOpenURL Co-authored-by: Nilesh Choudhary <107404295+4gust@users.noreply.github.com> --- apps/public/public_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/public/public_test.go b/apps/public/public_test.go index 23a4b184..f6e3eccf 100644 --- a/apps/public/public_test.go +++ b/apps/public/public_test.go @@ -53,7 +53,7 @@ func fakeBrowserOpenURL(authURL string) error { } redirect := q.Get("redirect_uri") if redirect == "" { - return errors.New(" 'redirect_uri'") + return errors.New("missing redirect param 'redirect_uri'") } // now send the info to our local redirect server resp, err := http.DefaultClient.Get(redirect + fmt.Sprintf("/?state=%s&code=fake_auth_code", state)) From e1d9ce115e3c5f300f43f7fc2e3f7464961c6685 Mon Sep 17 00:00:00 2001 From: Przemek Malolepszy Date: Mon, 27 Oct 2025 15:55:57 +0100 Subject: [PATCH 5/5] chore: set select_account as default prompt for AcquireTokenInteractive --- apps/public/public.go | 6 +++++- apps/public/public_test.go | 6 +++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/apps/public/public.go b/apps/public/public.go index 3c8d1aed..97c61545 100644 --- a/apps/public/public.go +++ b/apps/public/public.go @@ -702,7 +702,11 @@ func (pca Client) AcquireTokenInteractive(ctx context.Context, scopes []string, authParams.LoginHint = o.loginHint authParams.DomainHint = o.domainHint authParams.State = uuid.New().String() - authParams.Prompt = o.prompt + if o.prompt != "" { + authParams.Prompt = o.prompt + } else { + authParams.Prompt = shared.PromptSelectAccount.String() + } if o.authnScheme != nil { authParams.AuthnScheme = o.authnScheme } diff --git a/apps/public/public_test.go b/apps/public/public_test.go index f6e3eccf..f024a529 100644 --- a/apps/public/public_test.go +++ b/apps/public/public_test.go @@ -44,9 +44,9 @@ func fakeBrowserOpenURL(authURL string) error { if m := q.Get("code_challenge_method"); m != "S256" { return fmt.Errorf("unexpected code_challenge_method '%s'", m) } - // if q.Get("prompt") == "" { - // return errors.New("missing query param 'prompt") - // } + if q.Get("prompt") == "" { + return errors.New("missing query param 'prompt") + } state := q.Get("state") if state == "" { return errors.New("missing query param 'state'")