Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions internal/dto/change_password_request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package dto

type ChangePasswordRequest struct {
CurrentPassword string `json:"current_password" binding:"required,min=8,max=72`
NewPassword string `json:"new_password" binding:"required,min=8,max=72"`
}
1 change: 1 addition & 0 deletions internal/errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ var (
CredentialsInvalid = errors.New("invalid credentials")
ErrInvalidToken = errors.New("invalid token")
ErrGeneratingToken = errors.New("error generating token")
ErrInvalidPassword = errors.New("invalid password")
)
34 changes: 34 additions & 0 deletions internal/handlers/user_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,37 @@ func (u *UserHandler) GetCurrentUser(c *gin.Context) {
Email: user.Email,
})
}

func (u *UserHandler) ChangePassword(c *gin.Context) {
var request dto.ChangePasswordRequest

if err := c.ShouldBindJSON(&request); err != nil {
c.JSON(400, gin.H{"error": err.Error()})
return
}

userId, ok := getUserIdFromContext(c)
if !ok {
c.Status(http.StatusNotFound)
return
}

user, err := u.userService.GetUserById(userId)
if err != nil || user == nil {
c.Status(http.StatusInternalServerError)
return
}

err = u.userService.ChangePassword(user, request.CurrentPassword, request.NewPassword)
if err != nil {
switch err {
case errors.CredentialsInvalid:
c.JSON(http.StatusUnauthorized, dto.NewErrorResponse(err.Error(), "Current password is incorrect"))
default:
c.Status(http.StatusInternalServerError)
}
return
}
c.Status(http.StatusNoContent)

}
90 changes: 88 additions & 2 deletions internal/handlers/user_handler_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package handlers

import (
"bytes"
"net/http"
"net/http/httptest"
"strings"
"testing"

"encoding/json"

"github.com/albertoadami/nestled/internal/crypto"

"github.com/albertoadami/nestled/internal/dto"
"github.com/albertoadami/nestled/internal/errors"
"github.com/albertoadami/nestled/internal/model"
Expand All @@ -20,8 +23,9 @@ import (
)

type mockUserService struct {
createUserFn func(req *dto.CreateUserRequest) (uuid.UUID, error)
getByIdFn func(id uuid.UUID) (*model.User, error)
createUserFn func(req *dto.CreateUserRequest) (uuid.UUID, error)
getByIdFn func(id uuid.UUID) (*model.User, error)
returnUpdateError bool
}

func (m *mockUserService) CreateUser(req *dto.CreateUserRequest) (uuid.UUID, error) {
Expand All @@ -32,6 +36,14 @@ func (m *mockUserService) GetUserById(id uuid.UUID) (*model.User, error) {
return m.getByIdFn(id)
}

func (m *mockUserService) ChangePassword(user *model.User, currentPassword string, newPassword string) error {
if m.returnUpdateError {
return errors.CredentialsInvalid
} else {
return nil
}
}

func setupUserRouter(mockService *mockUserService) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
Expand All @@ -46,6 +58,7 @@ func setUpUserProfileRouter(mockService *mockUserService, userId uuid.UUID) *gin
handler := NewUserHandler(mockService, zap.NewNop())
// apply mock authentication as middleware before the handler
router.GET("/api/v1/users/me", testhelpers.MockAuthentication(userId), handler.GetCurrentUser)
router.PATCH("/api/v1/users/me/password", testhelpers.MockAuthentication(userId), handler.ChangePassword)
return router
}

Expand Down Expand Up @@ -155,3 +168,76 @@ func TestUserProfileSuccessfully(t *testing.T) {
assert.NoError(t, err)
assert.JSONEq(t, string(expectedJSON), w.Body.String())
}

func TestChangePasswordCorrectly(t *testing.T) {

userId := uuid.New()
passwordHash, _ := crypto.HashPassword("oldpassword")

mockService := &mockUserService{
getByIdFn: func(id uuid.UUID) (*model.User, error) {
if id == userId {
return &model.User{
Id: userId,
Username: "test",
Email: "test@test.it",
FirstName: "Test",
LastName: "User",
PasswordHash: passwordHash,
}, nil
}
return nil, nil
},
}
router := setUpUserProfileRouter(mockService, userId)

passwordRequest := &dto.ChangePasswordRequest{
CurrentPassword: "oldpassword",
NewPassword: "newpassword123",
}
body, _ := json.Marshal(passwordRequest)

req, _ := http.NewRequest("PATCH", "/api/v1/users/me/password", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)

assert.Equal(t, http.StatusNoContent, w.Code)

}

func TestChangePasswordInvalidCurrentPassword(t *testing.T) {

userId := uuid.New()
mockService := &mockUserService{
getByIdFn: func(id uuid.UUID) (*model.User, error) {
if id == userId {
return &model.User{
Id: userId,
Username: "test",
Email: "test@test.it",
FirstName: "Test",
LastName: "User",
PasswordHash: "blablah",
}, nil
}
return nil, nil
},
returnUpdateError: true,
}
router := setUpUserProfileRouter(mockService, userId)

passwordRequest := &dto.ChangePasswordRequest{
CurrentPassword: "wrong_password",
NewPassword: "newpassword123",
}
body, _ := json.Marshal(passwordRequest)

req, _ := http.NewRequest("PATCH", "/api/v1/users/me/password", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)

assert.Equal(t, http.StatusUnauthorized, w.Code)

}
21 changes: 19 additions & 2 deletions internal/repositories/user_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ import (
)

type UserRepository interface {
CreateUser(user *model.User) (uuid.UUID, error)
Create(user *model.User) (uuid.UUID, error)
Update(user *model.User) error
GetUserByUsername(username string) (*model.User, error)
GetUserById(id uuid.UUID) (*model.User, error)
}
Expand All @@ -25,7 +26,7 @@ func NewUserRepository(db *sqlx.DB) UserRepository {
return &userRepository{db: db}
}

func (r *userRepository) CreateUser(user *model.User) (uuid.UUID, error) {
func (r *userRepository) Create(user *model.User) (uuid.UUID, error) {
query := `INSERT INTO users (id, username, first_name, last_name, email, password_hash, status)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id`
Expand Down Expand Up @@ -85,3 +86,19 @@ func (r *userRepository) GetUserById(id uuid.UUID) (*model.User, error) {
}
return &user, nil
}

func (r *userRepository) Update(user *model.User) error {
query := `UPDATE users
SET username = $1, first_name = $2, last_name = $3, email = $4, password_hash = $5, status = $6, updated_at = NOW()
WHERE id = $7`
_, err := r.db.Exec(query,
user.Username,
user.FirstName,
user.LastName,
user.Email,
user.PasswordHash,
user.Status,
user.Id,
)
return err
}
44 changes: 37 additions & 7 deletions internal/repositories/user_repository_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestCreateUserSucessfully(t *testing.T) {

user := createTestUser()

id, err := userRepo.CreateUser(user)
id, err := userRepo.Create(user)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
Expand All @@ -62,14 +62,14 @@ func TestCreateUserFailedDueToDuplicateUsername(t *testing.T) {
userRepo := NewUserRepository(db)

user := createTestUser()
_, err := userRepo.CreateUser(user)
_, err := userRepo.Create(user)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}

user.Email = "test-duplicated@test.it"
user.Id = uuid.New()
_, err = userRepo.CreateUser(user)
_, err = userRepo.Create(user)
if err == nil {
t.Fatal("expected error, got nil")
}
Expand All @@ -89,14 +89,14 @@ func TestCreateUserFailedDueToDuplicateEmail(t *testing.T) {
userRepo := NewUserRepository(db)

user := createTestUser()
_, err := userRepo.CreateUser(user)
_, err := userRepo.Create(user)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}

user.Username = "johndoe-duplicated"
user.Id = uuid.New()
_, err = userRepo.CreateUser(user)
_, err = userRepo.Create(user)
if err == nil {
t.Fatal("expected error, got nil")
}
Expand All @@ -114,7 +114,7 @@ func TestGetUserByUsernameSucessfully(t *testing.T) {

userRepo := NewUserRepository(db)
user := createTestUser()
_, err := userRepo.CreateUser(user)
_, err := userRepo.Create(user)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
Expand Down Expand Up @@ -151,7 +151,7 @@ func TestGetUserByIdSucessfully(t *testing.T) {

userRepo := NewUserRepository(db)
user := createTestUser()
_, err := userRepo.CreateUser(user)
_, err := userRepo.Create(user)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
Expand Down Expand Up @@ -179,3 +179,33 @@ func TestGetUserByIdFailedDueToNonExistingUser(t *testing.T) {
assert.Nil(t, err, "expected err to be nil")
assert.Nil(t, result, "expected result to be nil")
}

func TestUpdateUserSuccessfully(t *testing.T) {

db, terminate := testhelpers.SetupPostgres(t)
defer terminate()
truncateUsers(t, db)

userRepo := NewUserRepository(db)
user := createTestUser()
_, err := userRepo.Create(user)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}

user.FirstName = "UpdatedFirstName"
err = userRepo.Update(user)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}

retrievedUser, err := userRepo.GetUserById(user.Id)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}

if retrievedUser.FirstName != "UpdatedFirstName" {
t.Fatalf("expected first name %v, got %v", "UpdatedFirstName", retrievedUser.FirstName)
}

}
1 change: 1 addition & 0 deletions internal/routes/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ func SetupRoutes(r *gin.Engine, userHandler *handlers.UserHandler, healthHandler
apiGroup.POST("/register", userHandler.RegisterUser)
apiGroup.POST("/auth/token", authHandler.GenerateToken)
protected.GET("/users/me", userHandler.GetCurrentUser)
protected.PATCH("/users/me/password", userHandler.ChangePassword)
}
39 changes: 38 additions & 1 deletion internal/services/auth_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type mockUserRepo struct {
getFnById func(id uuid.UUID) (*model.User, error)
}

func (m *mockUserRepo) CreateUser(user *model.User) (uuid.UUID, error) {
func (m *mockUserRepo) Create(user *model.User) (uuid.UUID, error) {
return uuid.Nil, nil
}

Expand All @@ -30,6 +30,10 @@ func (m *mockUserRepo) GetUserByUsername(username string) (*model.User, error) {
return m.getFn(username)
}

func (m *mockUserRepo) Update(user *model.User) error {
return nil
}

var tokenManager = auth.NewTokenManager(config.JWTConfig{Secret: "secret", Expiration: 1})

func TestGenerateToken_UserNotFound(t *testing.T) {
Expand Down Expand Up @@ -81,3 +85,36 @@ func TestGenerateToken_Success(t *testing.T) {
assert.NotNil(t, token)
assert.NotEmpty(t, token.Value)
}

func TestUpdatePasswordInvalidCurrentPassword(t *testing.T) {
hash, _ := crypto.HashPassword("current")
user := &model.User{Id: uuid.New(), PasswordHash: hash}

mockRepo := &mockUserRepo{
getFnById: func(userId uuid.UUID) (*model.User, error) {
return user, nil
},
}
service := NewUserService(mockRepo)

err := service.ChangePassword(user, "wrong", "newpass")
assert.ErrorIs(t, err, errors.ErrInvalidPassword)
}

func TestUpdatePasswordSuccess(t *testing.T) {
hash, _ := crypto.HashPassword("current")
user := &model.User{Id: uuid.New(), PasswordHash: hash}

mockRepo := &mockUserRepo{
getFnById: func(userId uuid.UUID) (*model.User, error) {
return user, nil
},
}
service := NewUserService(mockRepo)

err := service.ChangePassword(user, "current", "newpass")
assert.NoError(t, err)

// Verify that the password hash has been updated
assert.NotEqual(t, hash, user.PasswordHash)
}
Loading
Loading