diff --git a/handlers/handlers.go b/handlers/handlers.go index 2bb2d79..5f62e1e 100644 --- a/handlers/handlers.go +++ b/handlers/handlers.go @@ -2,6 +2,7 @@ package handlers import ( "encoding/json" + "errors" "fmt" "html/template" "net/http" @@ -20,6 +21,11 @@ const cerealNotesCookieName = "CerealNotesToken" const baseTemplateName = "base" const baseTemplateFile = "templates/base.tmpl" +var EmptyNoteContentError error = errors.New("Note content cannot be empty or just whitespace") +var NotYourNoteError error = errors.New("You are not the other of this note and therer for cannot preform this action") +var NoChangeError error = errors.New("The action you are trying to prefrom doesn't change anything") +var InvalidMethodError error = errors.New("This endpoint does not except that http method") + // JwtTokenClaim contains all claims required for authentication, including the standard JWT claims. type JwtTokenClaim struct { models.UserId `json:"userId"` @@ -31,9 +37,72 @@ type Environment struct { TokenSigningKey []byte } +type AuthenticatedRequestHandlerType func( + *Environment, + http.ResponseWriter, + *http.Request, + models.UserId, +) (error, int) + +type UnauthenticatedEndpointHandlerType func( + *Environment, + http.ResponseWriter, + *http.Request, +) (error, int) + +// Wrappers +func AuthenticateOrRedirect( + env *Environment, + authenticatedHandlerFunc AuthenticatedRequestHandlerType, + redirectPath string, +) http.HandlerFunc { + return func(responseWriter http.ResponseWriter, request *http.Request) { + if userId, err := getUserIdFromJwtToken(env, request); err != nil { + switch request.Method { + // If not logged in, redirect to login page + case http.MethodGet: + http.Redirect( + responseWriter, + request, + redirectPath, + http.StatusTemporaryRedirect) + return + default: + respondWithMethodNotAllowed(responseWriter, http.MethodGet) + } + } else { + if err, errCode := authenticatedHandlerFunc(env, responseWriter, request, userId); err != nil { + http.Error(responseWriter, err.Error(), errCode) + return + } + } + } +} + +func AuthenticateOrReturnUnauthorized( + env *Environment, + authenticatedHandlerFunc AuthenticatedRequestHandlerType, +) http.HandlerFunc { + return func(responseWriter http.ResponseWriter, request *http.Request) { + + if userId, err := getUserIdFromJwtToken(env, request); err != nil { + responseWriter.Header().Set("WWW-Authenticate", `Bearer realm="`+request.URL.Path+`"`) + http.Error(responseWriter, err.Error(), http.StatusUnauthorized) + } else { + if err, errCode := authenticatedHandlerFunc(env, responseWriter, request, userId); err != nil { + http.Error(responseWriter, err.Error(), errCode) + return + } + } + } +} + func WrapUnauthenticatedEndpoint(env *Environment, handler UnauthenticatedEndpointHandlerType) http.HandlerFunc { return func(responseWriter http.ResponseWriter, request *http.Request) { - handler(env, responseWriter, request) + if err, errCode := handler(env, responseWriter, request); err != nil { + http.Error(responseWriter, err.Error(), errCode) + return + } } } @@ -45,7 +114,7 @@ func HandleLoginOrSignupPageRequest( env *Environment, responseWriter http.ResponseWriter, request *http.Request, -) { +) (error, int) { switch request.Method { case http.MethodGet: if _, err := getUserIdFromJwtToken(env, request); err == nil { @@ -54,27 +123,29 @@ func HandleLoginOrSignupPageRequest( request, paths.HomePage, http.StatusTemporaryRedirect) - return + return nil, 0 } parsedTemplate, err := template.ParseFiles(baseTemplateFile, "templates/login_or_signup.tmpl") if err != nil { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - return + return err, http.StatusInternalServerError } parsedTemplate.ExecuteTemplate(responseWriter, baseTemplateName, nil) + return nil, 0 + default: - respondWithMethodNotAllowed(responseWriter, http.MethodGet) + return respondWithMethodNotAllowed(responseWriter, http.MethodGet) } } +// API func HandleUserApiRequest( env *Environment, responseWriter http.ResponseWriter, request *http.Request, -) { +) (error, int) { type SignupForm struct { DisplayName string `json:"displayName"` EmailAddress string `json:"emailAddress"` @@ -86,8 +157,7 @@ func HandleUserApiRequest( signupForm := new(SignupForm) if err := json.NewDecoder(request.Body).Decode(signupForm); err != nil { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - return + return err, http.StatusInternalServerError } var statusCode int @@ -99,8 +169,7 @@ func HandleUserApiRequest( if err == models.EmailAddressAlreadyInUseError { statusCode = http.StatusConflict } else { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - return + return err, http.StatusInternalServerError } } else { statusCode = http.StatusCreated @@ -108,31 +177,32 @@ func HandleUserApiRequest( responseWriter.WriteHeader(statusCode) + return nil, 0 + case http.MethodGet: if _, err := getUserIdFromJwtToken(env, request); err != nil { - http.Error(responseWriter, err.Error(), http.StatusUnauthorized) - return + return err, http.StatusUnauthorized } usersById, err := env.Db.GetAllUsersById() if err != nil { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - return + return err, http.StatusInternalServerError } usersByIdJson, err := usersById.ToJson() if err != nil { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - return + return err, http.StatusInternalServerError } responseWriter.Header().Set("Content-Type", "application/json") responseWriter.WriteHeader(http.StatusOK) fmt.Fprint(responseWriter, string(usersByIdJson)) + return nil, 0 + default: - respondWithMethodNotAllowed(responseWriter, http.MethodPost, http.MethodGet) + return respondWithMethodNotAllowed(responseWriter, http.MethodPost, http.MethodGet) } } @@ -142,7 +212,7 @@ func HandleSessionApiRequest( env *Environment, responseWriter http.ResponseWriter, request *http.Request, -) { +) (error, int) { type LoginForm struct { EmailAddress string `json:"emailAddress"` Password string `json:"password"` @@ -153,8 +223,7 @@ func HandleSessionApiRequest( loginForm := new(LoginForm) if err := json.NewDecoder(request.Body).Decode(loginForm); err != nil { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - return + return err, http.StatusBadRequest } if err := env.Db.AuthenticateUserCredentials( @@ -165,22 +234,19 @@ func HandleSessionApiRequest( if err == models.CredentialsNotAuthorizedError { statusCode = http.StatusUnauthorized } - http.Error(responseWriter, err.Error(), statusCode) - return + return err, statusCode } // Set our cookie to have a valid JWT Token as the value { userId, err := env.Db.GetIdForUserWithEmailAddress(models.NewEmailAddress(loginForm.EmailAddress)) if err != nil { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - return + return err, http.StatusInternalServerError } token, err := CreateTokenAsString(env, userId, credentialTimeoutDuration) if err != nil { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - return + return err, http.StatusInternalServerError } expirationTime := time.Now().Add(credentialTimeoutDuration) @@ -198,6 +264,8 @@ func HandleSessionApiRequest( responseWriter.WriteHeader(http.StatusCreated) + return nil, 0 + case http.MethodDelete: // Cookie will overwrite existing cookie then delete itself cookie := http.Cookie{ @@ -212,8 +280,10 @@ func HandleSessionApiRequest( responseWriter.WriteHeader(http.StatusOK) fmt.Fprint(responseWriter, "user successfully logged out") + return nil, 0 + default: - respondWithMethodNotAllowed( + return respondWithMethodNotAllowed( responseWriter, http.MethodPost, http.MethodDelete) @@ -225,17 +295,18 @@ func HandlePublicationApiRequest( responseWriter http.ResponseWriter, request *http.Request, userId models.UserId, -) { +) (error, int) { switch request.Method { case http.MethodPost: if err := env.Db.PublishNotes(userId); err != nil { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - return + return err, http.StatusInternalServerError } responseWriter.WriteHeader(http.StatusCreated) + return nil, 0 + default: - respondWithMethodNotAllowed(responseWriter, http.MethodPost) + return respondWithMethodNotAllowed(responseWriter, http.MethodPost) } } @@ -244,20 +315,23 @@ func HandleNoteApiRequest( responseWriter http.ResponseWriter, request *http.Request, userId models.UserId, -) { +) (error, int) { + + type NoteForm struct { + Content string `json:"content"` + } switch request.Method { + case http.MethodGet: publishedNotes, err := env.Db.GetAllPublishedNotesVisibleBy(userId) if err != nil { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - return + return err, http.StatusInternalServerError } myUnpublishedNotes, err := env.Db.GetMyUnpublishedNotes(userId) if err != nil { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - return + return err, http.StatusInternalServerError } // fmt.Println("number of published notes") @@ -278,8 +352,7 @@ func HandleNoteApiRequest( notesInJson, err := allNotes.ToJson() if err != nil { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - return + return err, http.StatusInternalServerError } responseWriter.Header().Set("Content-Type", "application/json") @@ -287,21 +360,18 @@ func HandleNoteApiRequest( fmt.Fprint(responseWriter, string(notesInJson)) + return nil, 0 + case http.MethodPost: - type NoteForm struct { - Content string `json:"content"` - } noteForm := new(NoteForm) if err := json.NewDecoder(request.Body).Decode(noteForm); err != nil { - http.Error(responseWriter, err.Error(), http.StatusBadRequest) - return + return err, http.StatusBadRequest } if len(strings.TrimSpace(noteForm.Content)) == 0 { - http.Error(responseWriter, "Note content cannot be empty or just whitespace", http.StatusBadRequest) - return + return EmptyNoteContentError, http.StatusBadRequest } note := &models.Note{ @@ -312,8 +382,7 @@ func HandleNoteApiRequest( noteId, err := env.Db.StoreNewNote(note) if err != nil { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - return + return err, http.StatusInternalServerError } type NoteResponse struct { @@ -322,8 +391,7 @@ func HandleNoteApiRequest( noteString, err := json.Marshal(&NoteResponse{NoteId: int64(noteId)}) if err != nil { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - return + return err, http.StatusInternalServerError } responseWriter.Header().Set("Content-Type", "application/json") @@ -331,84 +399,77 @@ func HandleNoteApiRequest( fmt.Fprint(responseWriter, string(noteString)) + return nil, 0 + case http.MethodPut: - type NoteForm struct { - Id int64 `json:"id"` - Content string `json:"content"` - } + + id, err := strconv.ParseInt(request.URL.Query().Get("id"), 10, 64) + noteId := models.NoteId(id) noteForm := new(NoteForm) if err := json.NewDecoder(request.Body).Decode(noteForm); err != nil { - http.Error(responseWriter, err.Error(), http.StatusBadRequest) - return + return err, http.StatusBadRequest } - if noteForm.Id < 1 { - http.Error(responseWriter, "Invalid Note Id", http.StatusBadRequest) - return + if noteId < 1 { + return models.NoNoteFoundError, http.StatusBadRequest } - noteId := models.NoteId(noteForm.Id) note, err := env.Db.GetNoteById(noteId) if err != nil { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - return + return err, http.StatusInternalServerError } if note.AuthorId != userId { - http.Error(responseWriter, "You can only edit notes of which you are the author", http.StatusUnauthorized) - return + return NotYourNoteError, http.StatusUnauthorized } content := strings.TrimSpace(noteForm.Content) if len(content) == 0 { - http.Error(responseWriter, "Note content cannot be empty or just whitespace", http.StatusBadRequest) - return + return EmptyNoteContentError, http.StatusBadRequest } if content == note.Content { - http.Error(responseWriter, "Note content is the same as existing content", http.StatusBadRequest) - return + return NoChangeError, http.StatusBadRequest } if err := env.Db.UpdateNoteContent(noteId, content); err != nil { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - return + return err, http.StatusInternalServerError } responseWriter.WriteHeader(http.StatusOK) + + return nil, 0 + case http.MethodDelete: id, err := strconv.ParseInt(request.URL.Query().Get("id"), 10, 64) if err != nil { - http.Error(responseWriter, err.Error(), http.StatusBadRequest) - return + return err, http.StatusBadRequest } noteId := models.NoteId(id) noteMap, err := env.Db.GetUsersNotes(userId) if err != nil { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - return + return err, http.StatusInternalServerError } if _, ok := noteMap[noteId]; !ok { - errorString := "No note with that Id written by you was found" - http.Error(responseWriter, errorString, http.StatusBadRequest) - return + return models.NoNoteFoundError, http.StatusInternalServerError } err = env.Db.DeleteNoteById(noteId) if err != nil { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - return + return err, http.StatusInternalServerError } responseWriter.WriteHeader(http.StatusOK) + return nil, 0 + default: - respondWithMethodNotAllowed(responseWriter, http.MethodGet, http.MethodPost, http.MethodDelete) + return respondWithMethodNotAllowed(responseWriter, http.MethodGet, http.MethodPost, http.MethodDelete) } } @@ -417,7 +478,7 @@ func HandleCategoryApiRequest( responseWriter http.ResponseWriter, request *http.Request, userId models.UserId, -) { +) (error, int) { switch request.Method { case http.MethodGet: @@ -426,8 +487,7 @@ func HandleCategoryApiRequest( category, err := env.Db.GetNoteCategory(noteId) if err != nil { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - return + return err, http.StatusInternalServerError } type categoryObj struct { @@ -436,14 +496,16 @@ func HandleCategoryApiRequest( jsonValue, err := json.Marshal(&categoryObj{Category: category.String()}) if err != nil { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - return + return err, http.StatusInternalServerError } responseWriter.Header().Set("Content-Type", "application/json") responseWriter.WriteHeader(http.StatusOK) fmt.Fprint(responseWriter, string(jsonValue)) + + return nil, 0 + case http.MethodPost: type CategoryForm struct { @@ -454,125 +516,72 @@ func HandleCategoryApiRequest( noteForm := new(CategoryForm) if err := json.NewDecoder(request.Body).Decode(noteForm); err != nil { - http.Error(responseWriter, err.Error(), http.StatusBadRequest) - return + return err, http.StatusBadRequest } category, err := models.DeserializeCategory(strings.ToLower(noteForm.Category)) if err != nil { - http.Error(responseWriter, err.Error(), http.StatusBadRequest) - return + return err, http.StatusBadRequest } if err := env.Db.StoreNewNoteCategoryRelationship(models.NoteId(noteForm.NoteId), category); err != nil { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - return + return err, http.StatusInternalServerError } responseWriter.WriteHeader(http.StatusCreated) + return nil, 0 + case http.MethodPut: + id, err := strconv.ParseInt(request.URL.Query().Get("id"), 10, 64) + noteId := models.NoteId(id) + type CategoryForm struct { - NoteId int64 `json:"noteId"` Category string `json:"category"` } noteForm := new(CategoryForm) if err := json.NewDecoder(request.Body).Decode(noteForm); err != nil { - http.Error(responseWriter, err.Error(), http.StatusBadRequest) - return + return err, http.StatusBadRequest } category, err := models.DeserializeCategory(strings.ToLower(noteForm.Category)) if err != nil { - http.Error(responseWriter, err.Error(), http.StatusBadRequest) - return + return err, http.StatusBadRequest } - if err := env.Db.UpdateNoteCategory(models.NoteId(noteForm.NoteId), category); err != nil { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - return + if err := env.Db.UpdateNoteCategory(noteId, category); err != nil { + return err, http.StatusInternalServerError } responseWriter.WriteHeader(http.StatusOK) + return nil, 0 + case http.MethodDelete: id, err := strconv.ParseInt(request.URL.Query().Get("id"), 10, 64) if err != nil { - http.Error(responseWriter, err.Error(), http.StatusBadRequest) - return + return err, http.StatusBadRequest } noteId := models.NoteId(id) if err := env.Db.DeleteNoteCategory(noteId); err != nil { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - return + return err, http.StatusInternalServerError } responseWriter.WriteHeader(http.StatusOK) - default: - respondWithMethodNotAllowed(responseWriter, http.MethodPost, http.MethodPut, http.MethodDelete) - } - -} - -type AuthenticatedRequestHandlerType func( - *Environment, - http.ResponseWriter, - *http.Request, - models.UserId, -) - -type UnauthenticatedEndpointHandlerType func( - *Environment, - http.ResponseWriter, - *http.Request, -) + return nil, 0 -func AuthenticateOrRedirect( - env *Environment, - authenticatedHandlerFunc AuthenticatedRequestHandlerType, - redirectPath string, -) http.HandlerFunc { - return func(responseWriter http.ResponseWriter, request *http.Request) { - if userId, err := getUserIdFromJwtToken(env, request); err != nil { - switch request.Method { - // If not logged in, redirect to login page - case http.MethodGet: - http.Redirect( - responseWriter, - request, - redirectPath, - http.StatusTemporaryRedirect) - return - default: - respondWithMethodNotAllowed(responseWriter, http.MethodGet) - } - } else { - authenticatedHandlerFunc(env, responseWriter, request, userId) - } + default: + return respondWithMethodNotAllowed(responseWriter, http.MethodPost, http.MethodPut, http.MethodDelete) } -} - -func AuthenticateOrReturnUnauthorized( - env *Environment, - authenticatedHandlerFunc AuthenticatedRequestHandlerType, -) http.HandlerFunc { - return func(responseWriter http.ResponseWriter, request *http.Request) { - if userId, err := getUserIdFromJwtToken(env, request); err != nil { - responseWriter.Header().Set("WWW-Authenticate", `Bearer realm="`+request.URL.Path+`"`) - http.Error(responseWriter, err.Error(), http.StatusUnauthorized) - } else { - authenticatedHandlerFunc(env, responseWriter, request, userId) - } - } } func RedirectToPathHandler( @@ -588,7 +597,8 @@ func RedirectToPathHandler( http.StatusTemporaryRedirect) return default: - respondWithMethodNotAllowed(responseWriter, http.MethodGet) + err, errcode := respondWithMethodNotAllowed(responseWriter, http.MethodGet) + http.Error(responseWriter, err.Error(), errcode) } } } @@ -600,18 +610,20 @@ func HandleHomePageRequest( responseWriter http.ResponseWriter, request *http.Request, userId models.UserId, -) { +) (error, int) { switch request.Method { case http.MethodGet: parsedTemplate, err := template.ParseFiles(baseTemplateFile, "templates/home.tmpl") if err != nil { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - return + return err, http.StatusInternalServerError } parsedTemplate.ExecuteTemplate(responseWriter, baseTemplateName, userId) + + return nil, 0 + default: - respondWithMethodNotAllowed(responseWriter, http.MethodGet) + return respondWithMethodNotAllowed(responseWriter, http.MethodGet) } } @@ -620,19 +632,19 @@ func HandleNotesPageRequest( responseWriter http.ResponseWriter, request *http.Request, userId models.UserId, -) { +) (error, int) { switch request.Method { case http.MethodGet: parsedTemplate, err := template.ParseFiles(baseTemplateFile, "templates/notes.tmpl") if err != nil { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - return + return err, http.StatusInternalServerError } parsedTemplate.ExecuteTemplate(responseWriter, baseTemplateName, userId) + return nil, 0 default: - respondWithMethodNotAllowed(responseWriter, http.MethodGet) + return respondWithMethodNotAllowed(responseWriter, http.MethodGet) } } @@ -642,12 +654,11 @@ func respondWithMethodNotAllowed( responseWriter http.ResponseWriter, allowedMethod string, otherAllowedMethods ...string, -) { +) (error, int) { allowedMethods := append([]string{allowedMethod}, otherAllowedMethods...) allowedMethodsString := strings.Join(allowedMethods, ", ") responseWriter.Header().Set("Allow", allowedMethodsString) - statusCode := http.StatusMethodNotAllowed - http.Error(responseWriter, http.StatusText(statusCode), statusCode) + return InvalidMethodError, http.StatusMethodNotAllowed } diff --git a/integration_test.go b/integration_test.go index 8dcabba..fa89a08 100644 --- a/integration_test.go +++ b/integration_test.go @@ -202,7 +202,7 @@ func TestAuthenticatedFlow(t *testing.T) { // Update cateogry { questionCateogry := models.QUESTIONS - categoryForm := &CategoryForm{NoteId: noteIdAsInt, Category: questionCateogry.String()} + categoryForm := &CategoryForm{Category: questionCateogry.String()} jsonValue, _ := json.Marshal(categoryForm) mockDb.Func_UpdateNoteCategory = func(noteId models.NoteId, cat models.Category) error { @@ -213,7 +213,7 @@ func TestAuthenticatedFlow(t *testing.T) { return errors.New("Incorrect Data Arrived") } - resp, err := sendPutRequest(client, server.URL+paths.CategoryApi, "application/json", bytes.NewBuffer(jsonValue)) + resp, err := sendPutRequest(client, server.URL+paths.CategoryApi+"?id="+strconv.FormatInt(noteIdAsInt, 10), "application/json", bytes.NewBuffer(jsonValue)) ok(t, err) equals(t, http.StatusOK, resp.StatusCode) @@ -268,13 +268,12 @@ func TestAuthenticatedFlow(t *testing.T) { } noteForm := &NoteUpdateForm{ - NoteId: 3, Content: "anything else", } jsonValue, _ := json.Marshal(noteForm) - resp, err := sendPutRequest(client, server.URL+paths.NoteApi, "application/json", bytes.NewBuffer(jsonValue)) + resp, err := sendPutRequest(client, server.URL+paths.NoteApi+"?id="+strconv.FormatInt(3, 10), "application/json", bytes.NewBuffer(jsonValue)) ok(t, err) equals(t, http.StatusOK, resp.StatusCode)