From 8b5d63865f42aff8ff3f99466965672b3b6775e6 Mon Sep 17 00:00:00 2001 From: Alberto Adami Date: Thu, 5 Mar 2026 20:57:36 +0100 Subject: [PATCH] implement change password endpoint --- internal/dto/change_password_request.go | 6 ++ internal/errors/errors.go | 1 + internal/handlers/user_handler.go | 34 +++++++ internal/handlers/user_handler_test.go | 90 ++++++++++++++++++- internal/repositories/user_repository.go | 21 ++++- internal/repositories/user_repository_test.go | 44 +++++++-- internal/routes/routes.go | 1 + internal/services/auth_service_test.go | 39 +++++++- internal/services/user_service.go | 21 ++++- 9 files changed, 244 insertions(+), 13 deletions(-) create mode 100644 internal/dto/change_password_request.go diff --git a/internal/dto/change_password_request.go b/internal/dto/change_password_request.go new file mode 100644 index 0000000..084801b --- /dev/null +++ b/internal/dto/change_password_request.go @@ -0,0 +1,6 @@ +package dto + +type ChangePasswordRequest struct { + CurrentPassword string `json:"current_password" binding:"required,min=8,max=72` + NewPassword string `json:"new_password" binding:"required,min=8,max=72"` +} diff --git a/internal/errors/errors.go b/internal/errors/errors.go index c211adb..05b9eb6 100644 --- a/internal/errors/errors.go +++ b/internal/errors/errors.go @@ -9,4 +9,5 @@ var ( CredentialsInvalid = errors.New("invalid credentials") ErrInvalidToken = errors.New("invalid token") ErrGeneratingToken = errors.New("error generating token") + ErrInvalidPassword = errors.New("invalid password") ) diff --git a/internal/handlers/user_handler.go b/internal/handlers/user_handler.go index 3360793..3103472 100644 --- a/internal/handlers/user_handler.go +++ b/internal/handlers/user_handler.go @@ -78,3 +78,37 @@ func (u *UserHandler) GetCurrentUser(c *gin.Context) { Email: user.Email, }) } + +func (u *UserHandler) ChangePassword(c *gin.Context) { + var request dto.ChangePasswordRequest + + if err := c.ShouldBindJSON(&request); err != nil { + c.JSON(400, gin.H{"error": err.Error()}) + return + } + + userId, ok := getUserIdFromContext(c) + if !ok { + c.Status(http.StatusNotFound) + return + } + + user, err := u.userService.GetUserById(userId) + if err != nil || user == nil { + c.Status(http.StatusInternalServerError) + return + } + + err = u.userService.ChangePassword(user, request.CurrentPassword, request.NewPassword) + if err != nil { + switch err { + case errors.CredentialsInvalid: + c.JSON(http.StatusUnauthorized, dto.NewErrorResponse(err.Error(), "Current password is incorrect")) + default: + c.Status(http.StatusInternalServerError) + } + return + } + c.Status(http.StatusNoContent) + +} diff --git a/internal/handlers/user_handler_test.go b/internal/handlers/user_handler_test.go index 01844e4..9f43bdf 100644 --- a/internal/handlers/user_handler_test.go +++ b/internal/handlers/user_handler_test.go @@ -1,6 +1,7 @@ package handlers import ( + "bytes" "net/http" "net/http/httptest" "strings" @@ -8,6 +9,8 @@ import ( "encoding/json" + "github.com/albertoadami/nestled/internal/crypto" + "github.com/albertoadami/nestled/internal/dto" "github.com/albertoadami/nestled/internal/errors" "github.com/albertoadami/nestled/internal/model" @@ -20,8 +23,9 @@ import ( ) type mockUserService struct { - createUserFn func(req *dto.CreateUserRequest) (uuid.UUID, error) - getByIdFn func(id uuid.UUID) (*model.User, error) + createUserFn func(req *dto.CreateUserRequest) (uuid.UUID, error) + getByIdFn func(id uuid.UUID) (*model.User, error) + returnUpdateError bool } func (m *mockUserService) CreateUser(req *dto.CreateUserRequest) (uuid.UUID, error) { @@ -32,6 +36,14 @@ func (m *mockUserService) GetUserById(id uuid.UUID) (*model.User, error) { return m.getByIdFn(id) } +func (m *mockUserService) ChangePassword(user *model.User, currentPassword string, newPassword string) error { + if m.returnUpdateError { + return errors.CredentialsInvalid + } else { + return nil + } +} + func setupUserRouter(mockService *mockUserService) *gin.Engine { gin.SetMode(gin.TestMode) router := gin.New() @@ -46,6 +58,7 @@ func setUpUserProfileRouter(mockService *mockUserService, userId uuid.UUID) *gin handler := NewUserHandler(mockService, zap.NewNop()) // apply mock authentication as middleware before the handler router.GET("/api/v1/users/me", testhelpers.MockAuthentication(userId), handler.GetCurrentUser) + router.PATCH("/api/v1/users/me/password", testhelpers.MockAuthentication(userId), handler.ChangePassword) return router } @@ -155,3 +168,76 @@ func TestUserProfileSuccessfully(t *testing.T) { assert.NoError(t, err) assert.JSONEq(t, string(expectedJSON), w.Body.String()) } + +func TestChangePasswordCorrectly(t *testing.T) { + + userId := uuid.New() + passwordHash, _ := crypto.HashPassword("oldpassword") + + mockService := &mockUserService{ + getByIdFn: func(id uuid.UUID) (*model.User, error) { + if id == userId { + return &model.User{ + Id: userId, + Username: "test", + Email: "test@test.it", + FirstName: "Test", + LastName: "User", + PasswordHash: passwordHash, + }, nil + } + return nil, nil + }, + } + router := setUpUserProfileRouter(mockService, userId) + + passwordRequest := &dto.ChangePasswordRequest{ + CurrentPassword: "oldpassword", + NewPassword: "newpassword123", + } + body, _ := json.Marshal(passwordRequest) + + req, _ := http.NewRequest("PATCH", "/api/v1/users/me/password", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusNoContent, w.Code) + +} + +func TestChangePasswordInvalidCurrentPassword(t *testing.T) { + + userId := uuid.New() + mockService := &mockUserService{ + getByIdFn: func(id uuid.UUID) (*model.User, error) { + if id == userId { + return &model.User{ + Id: userId, + Username: "test", + Email: "test@test.it", + FirstName: "Test", + LastName: "User", + PasswordHash: "blablah", + }, nil + } + return nil, nil + }, + returnUpdateError: true, + } + router := setUpUserProfileRouter(mockService, userId) + + passwordRequest := &dto.ChangePasswordRequest{ + CurrentPassword: "wrong_password", + NewPassword: "newpassword123", + } + body, _ := json.Marshal(passwordRequest) + + req, _ := http.NewRequest("PATCH", "/api/v1/users/me/password", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusUnauthorized, w.Code) + +} diff --git a/internal/repositories/user_repository.go b/internal/repositories/user_repository.go index eea0f24..8797ebc 100644 --- a/internal/repositories/user_repository.go +++ b/internal/repositories/user_repository.go @@ -12,7 +12,8 @@ import ( ) type UserRepository interface { - CreateUser(user *model.User) (uuid.UUID, error) + Create(user *model.User) (uuid.UUID, error) + Update(user *model.User) error GetUserByUsername(username string) (*model.User, error) GetUserById(id uuid.UUID) (*model.User, error) } @@ -25,7 +26,7 @@ func NewUserRepository(db *sqlx.DB) UserRepository { return &userRepository{db: db} } -func (r *userRepository) CreateUser(user *model.User) (uuid.UUID, error) { +func (r *userRepository) Create(user *model.User) (uuid.UUID, error) { query := `INSERT INTO users (id, username, first_name, last_name, email, password_hash, status) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING id` @@ -85,3 +86,19 @@ func (r *userRepository) GetUserById(id uuid.UUID) (*model.User, error) { } return &user, nil } + +func (r *userRepository) Update(user *model.User) error { + query := `UPDATE users + SET username = $1, first_name = $2, last_name = $3, email = $4, password_hash = $5, status = $6, updated_at = NOW() + WHERE id = $7` + _, err := r.db.Exec(query, + user.Username, + user.FirstName, + user.LastName, + user.Email, + user.PasswordHash, + user.Status, + user.Id, + ) + return err +} diff --git a/internal/repositories/user_repository_test.go b/internal/repositories/user_repository_test.go index fa3fb9d..9252014 100644 --- a/internal/repositories/user_repository_test.go +++ b/internal/repositories/user_repository_test.go @@ -41,7 +41,7 @@ func TestCreateUserSucessfully(t *testing.T) { user := createTestUser() - id, err := userRepo.CreateUser(user) + id, err := userRepo.Create(user) if err != nil { t.Fatalf("expected no error, got %v", err) } @@ -62,14 +62,14 @@ func TestCreateUserFailedDueToDuplicateUsername(t *testing.T) { userRepo := NewUserRepository(db) user := createTestUser() - _, err := userRepo.CreateUser(user) + _, err := userRepo.Create(user) if err != nil { t.Fatalf("expected no error, got %v", err) } user.Email = "test-duplicated@test.it" user.Id = uuid.New() - _, err = userRepo.CreateUser(user) + _, err = userRepo.Create(user) if err == nil { t.Fatal("expected error, got nil") } @@ -89,14 +89,14 @@ func TestCreateUserFailedDueToDuplicateEmail(t *testing.T) { userRepo := NewUserRepository(db) user := createTestUser() - _, err := userRepo.CreateUser(user) + _, err := userRepo.Create(user) if err != nil { t.Fatalf("expected no error, got %v", err) } user.Username = "johndoe-duplicated" user.Id = uuid.New() - _, err = userRepo.CreateUser(user) + _, err = userRepo.Create(user) if err == nil { t.Fatal("expected error, got nil") } @@ -114,7 +114,7 @@ func TestGetUserByUsernameSucessfully(t *testing.T) { userRepo := NewUserRepository(db) user := createTestUser() - _, err := userRepo.CreateUser(user) + _, err := userRepo.Create(user) if err != nil { t.Fatalf("expected no error, got %v", err) } @@ -151,7 +151,7 @@ func TestGetUserByIdSucessfully(t *testing.T) { userRepo := NewUserRepository(db) user := createTestUser() - _, err := userRepo.CreateUser(user) + _, err := userRepo.Create(user) if err != nil { t.Fatalf("expected no error, got %v", err) } @@ -179,3 +179,33 @@ func TestGetUserByIdFailedDueToNonExistingUser(t *testing.T) { assert.Nil(t, err, "expected err to be nil") assert.Nil(t, result, "expected result to be nil") } + +func TestUpdateUserSuccessfully(t *testing.T) { + + db, terminate := testhelpers.SetupPostgres(t) + defer terminate() + truncateUsers(t, db) + + userRepo := NewUserRepository(db) + user := createTestUser() + _, err := userRepo.Create(user) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + user.FirstName = "UpdatedFirstName" + err = userRepo.Update(user) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + retrievedUser, err := userRepo.GetUserById(user.Id) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if retrievedUser.FirstName != "UpdatedFirstName" { + t.Fatalf("expected first name %v, got %v", "UpdatedFirstName", retrievedUser.FirstName) + } + +} diff --git a/internal/routes/routes.go b/internal/routes/routes.go index 29bfb29..9f7e7be 100644 --- a/internal/routes/routes.go +++ b/internal/routes/routes.go @@ -19,4 +19,5 @@ func SetupRoutes(r *gin.Engine, userHandler *handlers.UserHandler, healthHandler apiGroup.POST("/register", userHandler.RegisterUser) apiGroup.POST("/auth/token", authHandler.GenerateToken) protected.GET("/users/me", userHandler.GetCurrentUser) + protected.PATCH("/users/me/password", userHandler.ChangePassword) } diff --git a/internal/services/auth_service_test.go b/internal/services/auth_service_test.go index cf8a56a..4a1e63c 100644 --- a/internal/services/auth_service_test.go +++ b/internal/services/auth_service_test.go @@ -18,7 +18,7 @@ type mockUserRepo struct { getFnById func(id uuid.UUID) (*model.User, error) } -func (m *mockUserRepo) CreateUser(user *model.User) (uuid.UUID, error) { +func (m *mockUserRepo) Create(user *model.User) (uuid.UUID, error) { return uuid.Nil, nil } @@ -30,6 +30,10 @@ func (m *mockUserRepo) GetUserByUsername(username string) (*model.User, error) { return m.getFn(username) } +func (m *mockUserRepo) Update(user *model.User) error { + return nil +} + var tokenManager = auth.NewTokenManager(config.JWTConfig{Secret: "secret", Expiration: 1}) func TestGenerateToken_UserNotFound(t *testing.T) { @@ -81,3 +85,36 @@ func TestGenerateToken_Success(t *testing.T) { assert.NotNil(t, token) assert.NotEmpty(t, token.Value) } + +func TestUpdatePasswordInvalidCurrentPassword(t *testing.T) { + hash, _ := crypto.HashPassword("current") + user := &model.User{Id: uuid.New(), PasswordHash: hash} + + mockRepo := &mockUserRepo{ + getFnById: func(userId uuid.UUID) (*model.User, error) { + return user, nil + }, + } + service := NewUserService(mockRepo) + + err := service.ChangePassword(user, "wrong", "newpass") + assert.ErrorIs(t, err, errors.ErrInvalidPassword) +} + +func TestUpdatePasswordSuccess(t *testing.T) { + hash, _ := crypto.HashPassword("current") + user := &model.User{Id: uuid.New(), PasswordHash: hash} + + mockRepo := &mockUserRepo{ + getFnById: func(userId uuid.UUID) (*model.User, error) { + return user, nil + }, + } + service := NewUserService(mockRepo) + + err := service.ChangePassword(user, "current", "newpass") + assert.NoError(t, err) + + // Verify that the password hash has been updated + assert.NotEqual(t, hash, user.PasswordHash) +} diff --git a/internal/services/user_service.go b/internal/services/user_service.go index 4d08619..0fae8e8 100644 --- a/internal/services/user_service.go +++ b/internal/services/user_service.go @@ -3,6 +3,7 @@ package services import ( "github.com/albertoadami/nestled/internal/crypto" "github.com/albertoadami/nestled/internal/dto" + "github.com/albertoadami/nestled/internal/errors" "github.com/albertoadami/nestled/internal/model" "github.com/albertoadami/nestled/internal/repositories" "github.com/google/uuid" @@ -11,6 +12,7 @@ import ( type UserService interface { CreateUser(request *dto.CreateUserRequest) (uuid.UUID, error) GetUserById(id uuid.UUID) (*model.User, error) + ChangePassword(user *model.User, currentPassword string, newPassword string) error } type userService struct { @@ -40,10 +42,27 @@ func (s *userService) CreateUser(request *dto.CreateUserRequest) (uuid.UUID, err Status: model.UserStatusPending, } - return s.userRepository.CreateUser(user) + return s.userRepository.Create(user) } func (s *userService) GetUserById(id uuid.UUID) (*model.User, error) { return s.userRepository.GetUserById(id) } + +func (s *userService) ChangePassword(user *model.User, currentPassword string, newPassword string) error { + if !crypto.CheckPassword(currentPassword, user.PasswordHash) { + return errors.ErrInvalidPassword + } + + // generate the new hash for the new password + hashedPassword, err := crypto.HashPassword(newPassword) + if err != nil { + return err + } + + user.PasswordHash = hashedPassword + + return s.userRepository.Update(user) + +}