diff --git a/errors/errors.go b/errors/errors.go index 984bc18..a8a6c91 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -1,41 +1,26 @@ package errors -type BadRequest struct { +type BaseError struct { Message string } -func (e *BadRequest) Error() string { +func (e *BaseError) Error() string { return e.Message } -type NotFound struct { - Message string -} - -func (e *NotFound) Error() string { - return e.Message -} - -type ServiceUnavailable struct { - Message string -} - -func (e *ServiceUnavailable) Error() string { - return e.Message -} - -type Forbidden struct { - Message string -} - -func (e *Forbidden) Error() string { - return e.Message -} - -type Unauthorized struct { - Message string -} - -func (e *Unauthorized) Error() string { - return e.Message -} +type BadRequest struct{ BaseError } +type NotFound struct{ BaseError } +type ServiceUnavailable struct{ BaseError } +type Forbidden struct{ BaseError } +type Unauthorized struct{ BaseError } +type MethodNotAllowed struct{ BaseError } +type Conflict struct{ BaseError } +type Gone struct{ BaseError } +type UnsupportedMediaType struct{ BaseError } +type UnprocessableEntity struct{ BaseError } +type TooManyRequests struct{ BaseError } +type InternalServerError struct{ BaseError } +type BadGateway struct{ BaseError } +type GatewayTimeout struct{ BaseError } +type RequestTimeout struct{ BaseError } +type NotImplemented struct{ BaseError } diff --git a/middlewares/error_handler.go b/middlewares/error_handler.go index 71b7735..a97d1b2 100644 --- a/middlewares/error_handler.go +++ b/middlewares/error_handler.go @@ -20,67 +20,67 @@ func (e *ErrorHandler) Wrap(handler func(w http.ResponseWriter, r *http.Request) var serviceUnavailable *serviceErrors.ServiceUnavailable var forbiddenError *serviceErrors.Forbidden var unauthorizedError *serviceErrors.Unauthorized + var methodNotAllowedError *serviceErrors.MethodNotAllowed + var conflictError *serviceErrors.Conflict + var goneError *serviceErrors.Gone + var unsupportedMediaTypeError *serviceErrors.UnsupportedMediaType + var unprocessableEntityError *serviceErrors.UnprocessableEntity + var tooManyRequestsError *serviceErrors.TooManyRequests + var internalServerError *serviceErrors.InternalServerError + var badGatewayError *serviceErrors.BadGateway + var gatewayTimeoutError *serviceErrors.GatewayTimeout + var requestTimeoutError *serviceErrors.RequestTimeout + var notImplementedError *serviceErrors.NotImplemented err := handler(w, r) - - if (errors.As(err, ¬FoundError)) || (errors.Is(err, storage.ErrNotFound)) { - render.Status(r, http.StatusNotFound) - response := types.ErrorResponse{ - Status: http.StatusText(http.StatusNotFound), - Error: err.Error(), - } - render.JSON(w, r, response) - return - } - - if errors.As(err, &badRequestError) { - render.Status(r, http.StatusBadRequest) - response := types.ErrorResponse{ - Status: http.StatusText(http.StatusBadRequest), - Error: err.Error(), - } - render.JSON(w, r, response) - return - } - - if errors.As(err, &serviceUnavailable) { - render.Status(r, http.StatusServiceUnavailable) - response := types.ErrorResponse{ - Status: http.StatusText(http.StatusServiceUnavailable), - Error: err.Error(), - } - render.JSON(w, r, response) + if err == nil { return } - if errors.As(err, &forbiddenError) { - render.Status(r, http.StatusForbidden) - response := types.ErrorResponse{ - Status: http.StatusText(http.StatusForbidden), - Error: err.Error(), - } - render.JSON(w, r, response) - return - } + status := http.StatusInternalServerError + message := err.Error() - if errors.As(err, &unauthorizedError) { - render.Status(r, http.StatusUnauthorized) - response := types.ErrorResponse{ - Status: http.StatusText(http.StatusUnauthorized), - Error: err.Error(), - } - render.JSON(w, r, response) - return + switch { + case errors.As(err, ¬FoundError) || errors.Is(err, storage.ErrNotFound): + status = http.StatusNotFound + case errors.As(err, &badRequestError): + status = http.StatusBadRequest + case errors.As(err, &serviceUnavailable): + status = http.StatusServiceUnavailable + case errors.As(err, &forbiddenError): + status = http.StatusForbidden + case errors.As(err, &unauthorizedError): + status = http.StatusUnauthorized + case errors.As(err, &methodNotAllowedError): + status = http.StatusMethodNotAllowed + case errors.As(err, &conflictError): + status = http.StatusConflict + case errors.As(err, &goneError): + status = http.StatusGone + case errors.As(err, &unsupportedMediaTypeError): + status = http.StatusUnsupportedMediaType + case errors.As(err, &unprocessableEntityError): + status = http.StatusUnprocessableEntity + case errors.As(err, &tooManyRequestsError): + status = http.StatusTooManyRequests + case errors.As(err, &internalServerError): + status = http.StatusInternalServerError + case errors.As(err, &badGatewayError): + status = http.StatusBadGateway + case errors.As(err, &gatewayTimeoutError): + status = http.StatusGatewayTimeout + case errors.As(err, &requestTimeoutError): + status = http.StatusRequestTimeout + case errors.As(err, ¬ImplementedError): + status = http.StatusNotImplemented + default: + message = "encountered an unexpected server error: " + err.Error() } - if err != nil { - render.Status(r, http.StatusInternalServerError) - response := types.ErrorResponse{ - Status: http.StatusText(http.StatusInternalServerError), - Error: "encountered an unexpected server error: " + err.Error(), - } - render.JSON(w, r, response) - return - } + render.Status(r, status) + render.JSON(w, r, types.ErrorResponse{ + Status: http.StatusText(status), + Error: message, + }) } }