diff --git a/access.go b/access.go index 5aa32dd..85f047e 100644 --- a/access.go +++ b/access.go @@ -162,36 +162,53 @@ func (s *Server) HandleAccessRequest(w *Response, r *http.Request) *AccessReques } func (s *Server) handleAuthorizationCodeRequest(w *Response, r *http.Request) *AccessRequest { - auth, err := CheckBasicAuth(r) + clientID := r.Form.Get("client_id") + if clientID == "" { + w.SetError(E_UNAUTHORIZED_CLIENT, "") + return nil + } + client, err := w.Storage.GetClient(clientID) if err != nil { w.SetError(E_SERVER_ERROR, "") w.InternalError = err return nil } + if client == nil { + w.SetError(E_UNAUTHORIZED_CLIENT, "") + return nil + } + var publicClient bool + if CheckClientSecret(client, ""){ + publicClient=true - var clientID string - var client Client - if auth == nil { - clientID = r.Form.Get("client_id") - if clientID == "" { - w.SetError(E_UNAUTHORIZED_CLIENT, "") - return nil - } - client = getClientWithoutSecret(clientID, w.Storage, w) - } else { + } else{ // get client authentication auth := GetClientAuth(w, r, s.Config.AllowClientSecretInParams) if auth == nil { return nil } - client = getClient(auth, w.Storage, w) + if !CheckClientSecret(client,auth.Password){ + w.SetError(E_UNAUTHORIZED_CLIENT, "") + return nil + } + } + + + var codeVerifier string + // Optional PKCE support (https://tools.ietf.org/html/rfc7636) + if codeVerifier = r.Form.Get("code_verifier"); len(codeVerifier) == 0 { + if s.Config.RequirePKCEForPublicClients && publicClient { + // https://tools.ietf.org/html/rfc7636#section-4.4.1 + w.SetError(E_INVALID_REQUEST, "code_verifier (rfc7636) required for public clients") + return nil + } } // generate access token ret := &AccessRequest{ Type: AUTHORIZATION_CODE, Code: r.Form.Get("code"), - CodeVerifier: r.Form.Get("code_verifier"), + CodeVerifier: codeVerifier, RedirectUri: r.Form.Get("redirect_uri"), GenerateRefresh: true, Expiration: s.Config.AccessExpiration, @@ -250,7 +267,7 @@ func (s *Server) handleAuthorizationCodeRequest(w *Response, r *http.Request) *A } if ret.AuthorizeData.RedirectUri != ret.RedirectUri { w.SetError(E_INVALID_REQUEST, "") - w.InternalError = errors.New("Redirect uri is different") + w.InternalError = errors.New("redirect uri is different") return nil } diff --git a/access_test.go b/access_test.go index e073ae5..0c7f02f 100644 --- a/access_test.go +++ b/access_test.go @@ -24,6 +24,7 @@ func TestAccessAuthorizationCode(t *testing.T) { req.Form.Set("grant_type", string(AUTHORIZATION_CODE)) req.Form.Set("code", "9999") req.Form.Set("state", "a") + req.Form.Set("client_id", "1234") req.PostForm = make(url.Values) if ar := server.HandleAccessRequest(resp, req); ar != nil { @@ -54,6 +55,44 @@ func TestAccessAuthorizationCode(t *testing.T) { } } +func TestAccessAuthorizationCodePublicClientWithoutPKCE(t *testing.T) { + sconfig := NewServerConfig() + sconfig.AllowedAccessTypes = AllowedAccessType{AUTHORIZATION_CODE} + sconfig.RequirePKCEForPublicClients = true + server := NewServer(sconfig, NewTestingStorage()) + server.AccessTokenGen = &TestingAccessTokenGen{} + resp := server.NewResponse() + + req, err := http.NewRequest("POST", "http://localhost:14000/appauth", nil) + if err != nil { + t.Fatal(err) + } + + req.Form = make(url.Values) + req.Form.Set("grant_type", string(AUTHORIZATION_CODE)) + req.Form.Set("code", "9999") + req.Form.Set("state", "a") + req.Form.Set("client_id", "public-client") + req.PostForm = make(url.Values) + + if ar := server.HandleAccessRequest(resp, req); ar != nil { + ar.Authorized = true + server.FinishAccessRequest(resp, req, ar) + } + + if !resp.IsError { + t.Fatalf("Should be an error") + } + + if resp.ErrorId!="invalid_request" { + t.Fatalf("Unexpected error id: %s", resp.ErrorId) + } + + if resp.Type != DATA { + t.Fatalf("Response should be data") + } +} + func TestAccessRefreshToken(t *testing.T) { sconfig := NewServerConfig() sconfig.AllowedAccessTypes = AllowedAccessType{REFRESH_TOKEN} @@ -631,6 +670,7 @@ func TestAccessAuthorizationCodePKCE(t *testing.T) { req.Form.Set("code", "pkce-code") req.Form.Set("state", "a") req.Form.Set("code_verifier", test.Verifier) + req.Form.Set("client_id", "public-client") req.PostForm = make(url.Values) if ar := server.HandleAccessRequest(resp, req); ar != nil {