diff --git a/grpc/src/lib.rs b/grpc/src/lib.rs index 0ee9831f8..32c94434c 100644 --- a/grpc/src/lib.rs +++ b/grpc/src/lib.rs @@ -35,6 +35,8 @@ pub mod client; pub mod credentials; pub mod inmemory; mod macros; +mod status; +pub use status::{ServerStatus, Status, StatusCode}; pub mod rt; pub mod server; pub mod service; diff --git a/grpc/src/status.rs b/grpc/src/status.rs new file mode 100644 index 000000000..9f40d333c --- /dev/null +++ b/grpc/src/status.rs @@ -0,0 +1,53 @@ +mod server_status; +mod status_code; + +pub use server_status::ServerStatus; +pub use status_code::StatusCode; + +/// Represents a gRPC status. +#[derive(Debug, Clone)] +pub struct Status { + code: StatusCode, + message: String, +} + +impl Status { + /// Create a new `Status` with the given code and message. + pub fn new(code: StatusCode, message: impl Into) -> Self { + Status { + code, + message: message.into(), + } + } + + /// Get the `StatusCode` of this `Status`. + pub fn code(&self) -> StatusCode { + self.code + } + + /// Get the message of this `Status`. + pub fn message(&self) -> &str { + &self.message + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_status_new() { + let status = Status::new(StatusCode::Ok, "ok"); + assert_eq!(status.code(), StatusCode::Ok); + assert_eq!(status.message(), "ok"); + } + + #[test] + fn test_status_debug() { + let status = Status::new(StatusCode::Ok, "ok"); + let debug = format!("{:?}", status); + assert!(debug.contains("Status")); + assert!(debug.contains("Ok")); + assert!(debug.contains("ok")); + } +} diff --git a/grpc/src/status/server_status.rs b/grpc/src/status/server_status.rs new file mode 100644 index 000000000..2c29d709d --- /dev/null +++ b/grpc/src/status/server_status.rs @@ -0,0 +1,66 @@ +use super::status_code::StatusCode; +use super::Status; + +/// Represents a gRPC status on the server. +/// +/// This is a separate type from `Status` to prevent accidental conversion and +/// leaking of sensitive information from the server to the client. +#[derive(Debug, Clone)] +pub struct ServerStatus(Status); + +impl std::ops::Deref for ServerStatus { + type Target = Status; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl ServerStatus { + /// Create a new `ServerStatus` with the given code and message. + pub fn new(code: StatusCode, message: impl Into) -> Self { + ServerStatus(Status::new(code, message)) + } + + /// Create a new `ServerStatus` from a `Status`. + pub fn from_status(status: Status) -> Self { + ServerStatus(status) + } + + /// Converts the `ServerStatus` to a `Status` for client responses. + pub(crate) fn into_status(self) -> Status { + self.0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_server_status_new() { + let status = ServerStatus::new(StatusCode::Ok, "ok"); + assert_eq!(status.code(), StatusCode::Ok); + assert_eq!(status.message(), "ok"); + } + + #[test] + fn test_server_status_deref() { + let status = ServerStatus::new(StatusCode::Ok, "ok"); + assert_eq!(status.code(), StatusCode::Ok); + } + + #[test] + fn test_server_status_from_status() { + let status = Status::new(StatusCode::Ok, "ok"); + let server_status = ServerStatus::from_status(status); + assert_eq!(server_status.code(), StatusCode::Ok); + } + + #[test] + fn test_server_status_into_status() { + let server_status = ServerStatus::new(StatusCode::Ok, "ok"); + let status = server_status.into_status(); + assert_eq!(status.code(), StatusCode::Ok); + } +} diff --git a/grpc/src/status/status_code.rs b/grpc/src/status/status_code.rs new file mode 100644 index 000000000..2f5bb1d59 --- /dev/null +++ b/grpc/src/status/status_code.rs @@ -0,0 +1,22 @@ +/// Represents a gRPC status code. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[repr(i32)] +pub enum StatusCode { + Ok = 0, + Cancelled = 1, + Unknown = 2, + InvalidArgument = 3, + DeadlineExceeded = 4, + NotFound = 5, + AlreadyExists = 6, + PermissionDenied = 7, + ResourceExhausted = 8, + FailedPrecondition = 9, + Aborted = 10, + OutOfRange = 11, + Unimplemented = 12, + Internal = 13, + Unavailable = 14, + DataLoss = 15, + Unauthenticated = 16, +}