diff --git a/pkg/auth/iam/iam.go b/pkg/auth/iam/iam.go index 624bebd..071ff48 100644 --- a/pkg/auth/iam/iam.go +++ b/pkg/auth/iam/iam.go @@ -147,85 +147,89 @@ func (filter *Filter) AuthAllowEmptySubdomain(opts ...FilterOption) restful.Filt func (filter *Filter) authFunc(allowEmptySubdomain bool, opts ...FilterOption) restful.FilterFunction { return func(req *restful.Request, resp *restful.Response, chain *restful.FilterChain) { - token, tokenFrom, err := parseAccessToken(req) - if err != nil { - logrus.Warn("unauthorized access: ", err) - logIfErr(resp.WriteHeaderAndJson(http.StatusUnauthorized, ErrorResponse{ - ErrorCode: UnauthorizedAccess, - ErrorMessage: ErrorCodeMapping[UnauthorizedAccess], - }, restful.MIME_JSON)) + filter.authFuncImpl(req, resp, chain, allowEmptySubdomain, opts...) + } +} - return - } +func (filter *Filter) authFuncImpl(req *restful.Request, resp *restful.Response, chain *restful.FilterChain, allowEmptySubdomain bool, opts ...FilterOption) { + token, tokenFrom, err := parseAccessToken(req) + if err != nil { + logrus.Warn("unauthorized access: ", err) + logIfErr(resp.WriteHeaderAndJson(http.StatusUnauthorized, ErrorResponse{ + ErrorCode: UnauthorizedAccess, + ErrorMessage: ErrorCodeMapping[UnauthorizedAccess], + }, restful.MIME_JSON)) - claims, err := filter.iamClient.ValidateAndParseClaims(token) - if err != nil { - logrus.Warn("unauthorized access: ", err) - if err.Error() == ErrorCodeMapping[TokenIsExpired] { - logIfErr(resp.WriteHeaderAndJson(http.StatusUnauthorized, ErrorResponse{ - ErrorCode: TokenIsExpired, - ErrorMessage: ErrorCodeMapping[TokenIsExpired], - }, restful.MIME_JSON)) - return - } + return + } + + claims, err := filter.iamClient.ValidateAndParseClaims(token) + if err != nil { + logrus.Warn("unauthorized access: ", err) + if err.Error() == ErrorCodeMapping[TokenIsExpired] { logIfErr(resp.WriteHeaderAndJson(http.StatusUnauthorized, ErrorResponse{ - ErrorCode: UnauthorizedAccess, - ErrorMessage: ErrorCodeMapping[UnauthorizedAccess], + ErrorCode: TokenIsExpired, + ErrorMessage: ErrorCodeMapping[TokenIsExpired], }, restful.MIME_JSON)) return } + logIfErr(resp.WriteHeaderAndJson(http.StatusUnauthorized, ErrorResponse{ + ErrorCode: UnauthorizedAccess, + ErrorMessage: ErrorCodeMapping[UnauthorizedAccess], + }, restful.MIME_JSON)) + return + } - req.SetAttribute(ClaimsAttribute, claims) + req.SetAttribute(ClaimsAttribute, claims) - if tokenFrom == tokenFromCookie { - valid := filter.validateRefererHeader(req, claims, allowEmptySubdomain) - if !valid { - logIfErr(resp.WriteHeaderAndJson(http.StatusUnauthorized, ErrorResponse{ - ErrorCode: InvalidRefererHeader, - ErrorMessage: ErrorCodeMapping[InvalidRefererHeader], - }, restful.MIME_JSON)) + if tokenFrom == tokenFromCookie { + valid := filter.validateRefererHeader(req, claims, allowEmptySubdomain) + if !valid { + logIfErr(resp.WriteHeaderAndJson(http.StatusUnauthorized, ErrorResponse{ + ErrorCode: InvalidRefererHeader, + ErrorMessage: ErrorCodeMapping[InvalidRefererHeader], + }, restful.MIME_JSON)) - return - } + return } + } - if filter.options.SubdomainValidationEnabled && !allowEmptySubdomain { - if valid := validateSubdomainAgainstNamespace(getHost(req.Request), claims.Namespace, filter.options.SubdomainValidationExcludedNamespaces); !valid { - logIfErr(resp.WriteHeaderAndJson(http.StatusNotFound, ErrorResponse{ - ErrorCode: SubdomainMismatch, - ErrorMessage: "data not found: " + ErrorCodeMapping[SubdomainMismatch], - }, restful.MIME_JSON)) + if filter.options.SubdomainValidationEnabled && !allowEmptySubdomain { + if valid := validateSubdomainAgainstNamespace(getHost(req.Request), claims.Namespace, filter.options.SubdomainValidationExcludedNamespaces); !valid { + logIfErr(resp.WriteHeaderAndJson(http.StatusNotFound, ErrorResponse{ + ErrorCode: SubdomainMismatch, + ErrorMessage: "data not found: " + ErrorCodeMapping[SubdomainMismatch], + }, restful.MIME_JSON)) - return - } + return } + } - for _, opt := range opts { - if err = opt(req, filter.iamClient, claims); err != nil { - if svcErr, ok := err.(restful.ServiceError); ok { - logrus.Warn(svcErr.Message) - - var respErr ErrorResponse + for _, opt := range opts { + if err = opt(req, filter.iamClient, claims); err != nil { + if svcErr, ok := err.(restful.ServiceError); ok { + logrus.Warn(svcErr.Message) - err = json.Unmarshal([]byte(svcErr.Message), &respErr) - if err == nil { - logIfErr(resp.WriteHeaderAndJson(svcErr.Code, respErr, restful.MIME_JSON)) - } else { - logIfErr(resp.WriteErrorString(svcErr.Code, svcErr.Message)) - } + var respErr ErrorResponse - return + err = json.Unmarshal([]byte(svcErr.Message), &respErr) + if err == nil { + logIfErr(resp.WriteHeaderAndJson(svcErr.Code, respErr, restful.MIME_JSON)) + } else { + logIfErr(resp.WriteErrorString(svcErr.Code, svcErr.Message)) } - logrus.Warn(err) - logIfErr(resp.WriteErrorString(http.StatusUnauthorized, err.Error())) - return } - } - chain.ProcessFilter(req, resp) + logrus.Warn(err) + logIfErr(resp.WriteErrorString(http.StatusUnauthorized, err.Error())) + + return + } } + + chain.ProcessFilter(req, resp) } // PublicAuth returns a filter that allow unauthenticate request and request with valid access token in auth header or cookie @@ -241,41 +245,45 @@ func (filter *Filter) authFunc(allowEmptySubdomain bool, opts ...FilterOption) r // ) func (filter *Filter) PublicAuth(opts ...FilterOption) restful.FilterFunction { return func(req *restful.Request, resp *restful.Response, chain *restful.FilterChain) { - token, tokenFrom, err := parseAccessToken(req) - if err != nil { + filter.publicAuth(req, resp, chain, opts...) + } +} + +func (filter *Filter) publicAuth(req *restful.Request, resp *restful.Response, chain *restful.FilterChain, opts ...FilterOption) { + token, tokenFrom, err := parseAccessToken(req) + if err != nil { + chain.ProcessFilter(req, resp) + return + } + + claims, err := filter.iamClient.ValidateAndParseClaims(token) + if err != nil { + logrus.Warn("unauthorized access for public endpoint: ", err) + chain.ProcessFilter(req, resp) + return + } + + req.SetAttribute(ClaimsAttribute, claims) + + if tokenFrom == tokenFromCookie { + valid := filter.validateRefererHeader(req, claims, false) + if !valid { + req.SetAttribute(ClaimsAttribute, nil) chain.ProcessFilter(req, resp) return } + } - claims, err := filter.iamClient.ValidateAndParseClaims(token) - if err != nil { - logrus.Warn("unauthorized access for public endpoint: ", err) + for _, opt := range opts { + if err = opt(req, filter.iamClient, claims); err != nil { + logrus.Warn(err) + req.SetAttribute(ClaimsAttribute, nil) chain.ProcessFilter(req, resp) return } - - req.SetAttribute(ClaimsAttribute, claims) - - if tokenFrom == tokenFromCookie { - valid := filter.validateRefererHeader(req, claims, false) - if !valid { - req.SetAttribute(ClaimsAttribute, nil) - chain.ProcessFilter(req, resp) - return - } - } - - for _, opt := range opts { - if err = opt(req, filter.iamClient, claims); err != nil { - logrus.Warn(err) - req.SetAttribute(ClaimsAttribute, nil) - chain.ProcessFilter(req, resp) - return - } - } - - chain.ProcessFilter(req, resp) } + + chain.ProcessFilter(req, resp) } // RetrieveJWTClaims is a convenience function to retrieve JWT claims @@ -301,32 +309,35 @@ func WithValidUser() FilterOption { // WithPermission filters request with valid permission only func WithPermission(permission *iam.Permission) FilterOption { return func(req *restful.Request, iamClient iam.Client, claims *iam.JWTClaims) error { - requiredPermissionResources := make(map[string]string) - requiredPermissionResources["{namespace}"] = req.PathParameter("namespace") - requiredPermissionResources["{userId}"] = req.PathParameter("userId") + return withPermission(req, iamClient, claims, permission) + } +} - valid, err := iamClient.ValidatePermission(claims, *permission, requiredPermissionResources) - if err != nil { - return respondError(http.StatusInternalServerError, InternalServerError, - "unable to validate permission: "+err.Error()) - } +func withPermission(req *restful.Request, iamClient iam.Client, claims *iam.JWTClaims, permission *iam.Permission) error { + requiredPermissionResources := make(map[string]string) + requiredPermissionResources["{namespace}"] = req.PathParameter("namespace") + requiredPermissionResources["{userId}"] = req.PathParameter("userId") - insufficientPermissionMessage := ErrorCodeMapping[InsufficientPermissions] - if DevStackTraceable { - action := ActionConverter(permission.Action) - insufficientPermissionMessage = fmt.Sprintf("%s. Required permission: %s [%s]", insufficientPermissionMessage, - permission.Resource, action) - } - if !valid { - return respondErrorWithRequiredPermission(http.StatusForbidden, InsufficientPermissions, - "access forbidden: "+insufficientPermissionMessage, Permission{ - Resource: permission.Resource, - Action: permission.Action, - }) - } + valid, err := iamClient.ValidatePermission(claims, *permission, requiredPermissionResources) + if err != nil { + return respondError(http.StatusInternalServerError, InternalServerError, + "unable to validate permission: "+err.Error()) + } - return nil + insufficientPermissionMessage := ErrorCodeMapping[InsufficientPermissions] + if DevStackTraceable { + action := ActionConverter(permission.Action) + insufficientPermissionMessage = fmt.Sprintf("%s. Required permission: %s [%s]", insufficientPermissionMessage, + permission.Resource, action) + } + if !valid { + return respondErrorWithRequiredPermission(http.StatusForbidden, InsufficientPermissions, + "access forbidden: "+insufficientPermissionMessage, Permission{ + Resource: permission.Resource, + Action: permission.Action, + }) } + return nil } // WithRole filters request with valid role only @@ -386,19 +397,22 @@ func WithValidAudience() FilterOption { // WithValidScope filters request from a user with verified scope func WithValidScope(scope string) FilterOption { return func(req *restful.Request, iamClient iam.Client, claims *iam.JWTClaims) error { - err := iamClient.ValidateScope(claims, scope) - insufficientScopeMessage := ErrorCodeMapping[InsufficientScope] - if DevStackTraceable { - insufficientScopeMessage = fmt.Sprintf("%s. Required scope: %s", insufficientScopeMessage, - scope) - } - if err != nil { - return respondError(http.StatusForbidden, InsufficientScope, - "access forbidden: "+insufficientScopeMessage) - } + return withValidScope(scope, iamClient, claims) + } +} - return nil +func withValidScope(scope string, iamClient iam.Client, claims *iam.JWTClaims) error { + err := iamClient.ValidateScope(claims, scope) + insufficientScopeMessage := ErrorCodeMapping[InsufficientScope] + if DevStackTraceable { + insufficientScopeMessage = fmt.Sprintf("%s. Required scope: %s", insufficientScopeMessage, + scope) + } + if err != nil { + return respondError(http.StatusForbidden, InsufficientScope, + "access forbidden: "+insufficientScopeMessage) } + return nil } func validateSubdomainAgainstNamespace(host string, namespace string, excludedNamespaces []string) bool {