diff --git a/models/databasehelper.go b/models/databasehelper.go new file mode 100644 index 0000000..c709357 --- /dev/null +++ b/models/databasehelper.go @@ -0,0 +1,76 @@ +package models + +import ( + // "database/sql" + "errors" + "github.com/lib/pq" +) + +// UniqueConstraintError is returned when a uniqueness constraint is violated during an insert. +var UniqueConstraintError = errors.New("postgres: unique constraint violation") + +// QueryResultContainedMultipleRowsError is returned when a query unexpectedly returns more than one row. +var QueryResultContainedMultipleRowsError = errors.New("query result unexpectedly contained multiple rows") + +// QueryResultContainedNoRowsError is returned when a query unexpectedly returns no rows. +var QueryResultContainedNoRowsError = errors.New("query result unexpectedly contained no rows") + +func convertPostgresError(err error) error { + const uniqueConstraintErrorCode = "23505" + + if postgresErr, ok := err.(*pq.Error); ok { + if postgresErr.Code == uniqueConstraintErrorCode { + return UniqueConstraintError + } + } + + return err +} + +func (db *DB) execOneResult(sqlQuery string, object interface{}, args ...interface{}) error { + + rows, err := db.Query(sqlQuery, args...) + if err != nil { + return convertPostgresError(err) + } + defer rows.Close() + + foundResult := false + for rows.Next() { + + if foundResult { + return QueryResultContainedMultipleRowsError + } + + if err := rows.Scan(object); err != nil { + return convertPostgresError(err) + } + + foundResult = true + } + + if !foundResult { + return QueryResultContainedNoRowsError + } + + if err := rows.Err(); err != nil { + return convertPostgresError(err) + } + + return nil +} + +func (db *DB) execNoResults(sqlQuery string, args ...interface{}) (int64, error) { + + res, err := db.Exec(sqlQuery, args...) + if err != nil { + return 0, convertPostgresError(err) + } + + numAffected, err := res.RowsAffected() + if err != nil { + return 0, convertPostgresError(err) + } + + return numAffected, nil +} diff --git a/models/datastore.go b/models/datastore.go index d76600e..5af8a58 100644 --- a/models/datastore.go +++ b/models/datastore.go @@ -2,20 +2,8 @@ package models import ( "database/sql" - "errors" - - "github.com/lib/pq" ) -// UniqueConstraintError is returned when a uniqueness constraint is violated during an insert. -var UniqueConstraintError = errors.New("postgres: unique constraint violation") - -// QueryResultContainedMultipleRowsError is returned when a query unexpectedly returns more than one row. -var QueryResultContainedMultipleRowsError = errors.New("query result unexpectedly contained multiple rows") - -// QueryResultContainedNoRowsError is returned when a query unexpectedly returns no rows. -var QueryResultContainedNoRowsError = errors.New("query result unexpectedly contained no rows") - // ConnectToDatabase also pings the database to ensure a working connection. func ConnectToDatabase(databaseUrl string) (*DB, error) { tempDb, err := sql.Open("postgres", databaseUrl) @@ -41,15 +29,3 @@ type Datastore interface { type DB struct { *sql.DB } - -func convertPostgresError(err error) error { - const uniqueConstraintErrorCode = "23505" - - if postgresErr, ok := err.(*pq.Error); ok { - if postgresErr.Code == uniqueConstraintErrorCode { - return UniqueConstraintError - } - } - - return err -} diff --git a/models/note.go b/models/note.go index ae1f516..7b3b47e 100644 --- a/models/note.go +++ b/models/note.go @@ -27,31 +27,9 @@ func (db *DB) StoreNewNote( VALUES ($1, $2, $3) RETURNING id` - rows, err := db.Query(sqlQuery, authorId, content, creationTime) - if err != nil { - return 0, convertPostgresError(err) - } - defer rows.Close() - var noteId int64 = 0 - for rows.Next() { - - if noteId != 0 { - return 0, QueryResultContainedMultipleRowsError - } - - if err := rows.Scan(¬eId); err != nil { - return 0, convertPostgresError(err) - } + if err := db.execOneResult(sqlQuery, ¬eId, authorId, content, creationTime); err != nil { + return 0, err } - - if noteId == 0 { - return 0, QueryResultContainedNoRowsError - } - - if err := rows.Err(); err != nil { - return 0, convertPostgresError(err) - } - return NoteId(noteId), nil } diff --git a/models/note_category.go b/models/note_category.go index e7992c9..4470480 100644 --- a/models/note_category.go +++ b/models/note_category.go @@ -21,6 +21,7 @@ var categoryStrings = [...]string{ } var CannotDeserializeNoteCategoryStringError = errors.New("String does not correspond to a Note Category") +var NoteAlreadyContainsCategoryError = errors.New("NoteId already has a category stored for it") func DeserializeNoteCategory(input string) (NoteCategory, error) { for i := 0; i < len(categoryStrings); i++ { @@ -48,14 +49,11 @@ func (db *DB) StoreNewNoteCategoryRelationship( INSERT INTO note_to_category_relationship (note_id, category) VALUES ($1, $2)` - rows, err := db.Query(sqlQuery, int64(noteId), category.String()) - if err != nil { - return convertPostgresError(err) - } - defer rows.Close() - - if err := rows.Err(); err != nil { - return convertPostgresError(err) + if _, err := db.execNoResults(sqlQuery, int64(noteId), category.String()); err != nil { + if err == UniqueConstraintError { + return NoteAlreadyContainsCategoryError + } + return err } return nil diff --git a/models/user.go b/models/user.go index b90198f..338d979 100644 --- a/models/user.go +++ b/models/user.go @@ -51,14 +51,12 @@ func (db *DB) StoreNewUser( INSERT INTO app_user (display_name, email_address, password, creation_time) VALUES ($1, $2, $3, $4)` - rows, err := db.Query(sqlQuery, displayName, emailAddress.String(), hashedPassword, creationTime) - if err != nil { - return convertPostgresError(err) - } - defer rows.Close() + if _, err := db.execNoResults(sqlQuery, displayName, emailAddress.String(), hashedPassword, creationTime); err != nil { + if err == UniqueConstraintError { + return EmailAddressAlreadyInUseError + } - if err := rows.Err(); err != nil { - return convertPostgresError(err) + return err } return nil @@ -69,25 +67,10 @@ func (db *DB) AuthenticateUserCredentials(emailAddress *EmailAddress, password s SELECT password FROM app_user WHERE email_address = $1` - rows, err := db.Query(sqlQuery, emailAddress.String()) - if err != nil { - return convertPostgresError(err) - } - defer rows.Close() - var storedHashedPassword []byte - for rows.Next() { - if storedHashedPassword != nil { - return QueryResultContainedMultipleRowsError - } - if err := rows.Scan(&storedHashedPassword); err != nil { - return err - } - } - - if storedHashedPassword == nil { - return QueryResultContainedNoRowsError + if err := db.execOneResult(sqlQuery, &storedHashedPassword, emailAddress.String()); err != nil { + return err } if err := bcrypt.CompareHashAndPassword( @@ -109,25 +92,12 @@ func (db *DB) GetIdForUserWithEmailAddress(emailAddress *EmailAddress) (UserId, SELECT id FROM app_user WHERE email_address = $1` - rows, err := db.Query(sqlQuery, emailAddress.String()) - if err != nil { - return 0, convertPostgresError(err) - } - defer rows.Close() - var userId int64 - for rows.Next() { - if userId != 0 { - return 0, QueryResultContainedMultipleRowsError + if err := db.execOneResult(sqlQuery, &userId, emailAddress.String()); err != nil { + if err == QueryResultContainedNoRowsError { + return 0, CredentialsNotAuthorizedError } - - if err := rows.Scan(&userId); err != nil { - return 0, err - } - } - - if userId == 0 { - return 0, QueryResultContainedNoRowsError + return 0, err } return UserId(userId), nil