diff --git a/Cargo.lock b/Cargo.lock index 6608bea..5efbd6e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -95,6 +95,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" dependencies = [ "axum-core", + "base64 0.22.1", "bytes", "form_urlencoded", "futures-util", @@ -113,8 +114,10 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_urlencoded", + "sha1", "sync_wrapper", "tokio", + "tokio-tungstenite 0.28.0", "tower", "tower-layer", "tower-service", @@ -2482,7 +2485,7 @@ dependencies = [ "static_assertions", "time", "tokio", - "tokio-tungstenite", + "tokio-tungstenite 0.21.0", "tracing", "typemap_rev", "typesize", @@ -3170,10 +3173,22 @@ dependencies = [ "rustls-pki-types", "tokio", "tokio-rustls 0.25.0", - "tungstenite", + "tungstenite 0.21.0", "webpki-roots 0.26.11", ] +[[package]] +name = "tokio-tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite 0.28.0", +] + [[package]] name = "tokio-util" version = "0.7.18" @@ -3338,6 +3353,23 @@ dependencies = [ "utf-8", ] +[[package]] +name = "tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8628dcc84e5a09eb3d8423d6cb682965dea9133204e8fb3efee74c2a0c259442" +dependencies = [ + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "rand 0.9.2", + "sha1", + "thiserror 2.0.18", + "utf-8", +] + [[package]] name = "typemap_rev" version = "0.3.0" diff --git a/Cargo.toml b/Cargo.toml index 06794ee..1a41a9f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ edition = "2024" [dependencies] anyhow = "1.0.101" async-trait = "0.1" -axum = "0.8.8" +axum = { version = "0.8.8", features = ["ws"] } chrono = { version = "0.4", features = ["serde"] } dotenv = "0.15.0" poise = "0.6.1" diff --git a/README.md b/README.md index 2c2b82b..6325a05 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,20 @@ https://discord.com/oauth2/authorize?client_id=1367651749262921868&permissions=277025458240&integration_type=0&scope=bot +## Create your own + +You can easily create your own version of this bot. + +It needs the premissions in 0Auth2: + +- Bot +- Create commands +- Send messages +- Send messages in Threads +- Read Message History +- Add Reactions +- Use Appliation Commands + ## How to run Make sure that you have copied `.env.example` to `.env` and filled in the missing enviornment variables. diff --git a/justfile b/justfile index 1e4b886..3e0e737 100644 --- a/justfile +++ b/justfile @@ -1,4 +1,16 @@ set dotenv-load +dev: + cargo run + +up: + docker compose up redis postgres -d + +down: + docker compose down + +test: + cargo test -- --test-threads=1 + migrate: sqlx migrate run --source migrations --database-url "$DATABASE_URL" diff --git a/src/adapters/http/mod.rs b/src/adapters/http/mod.rs index 0a30446..32b8d95 100644 --- a/src/adapters/http/mod.rs +++ b/src/adapters/http/mod.rs @@ -1,7 +1,11 @@ use axum::{ + extract::{ + ws::{Message, WebSocket, WebSocketUpgrade}, + Path, State, + }, + response::Response, + routing::{any, get}, Json, Router, - extract::{Path, State}, - routing::get, }; use tower::ServiceBuilder; use tower_http::trace::TraceLayer; @@ -9,7 +13,7 @@ use tracing::info; use std::{io, sync::Arc}; -use crate::domain::{OrderRepository, QueueEntry, QueueRepository}; +use crate::domain::{OrderRepository, QueueEntry, QueueEvent, QueueRepository}; #[derive(Clone)] pub struct AppState { @@ -35,7 +39,9 @@ impl HttpAdapter { }); let app = Router::new() + .route("/{guild_id}/status", get(queue_status)) .route("/{guild_id}/queue", get(list_queue)) + .route("/{guild_id}/queue/ws", any(list_queue_ws)) .layer(ServiceBuilder::new().layer(TraceLayer::new_for_http())) .with_state(state); @@ -46,6 +52,12 @@ impl HttpAdapter { } } +async fn queue_status(State(state): State>, Path(guild_id): Path) -> String { + let is_open = state.queue.is_open(&guild_id); + let status = if is_open { "open" } else { "closed" }; + status.to_string() +} + async fn list_queue( State(state): State>, Path(guild_id): Path, @@ -53,3 +65,38 @@ async fn list_queue( let queue = state.queue.list(&guild_id).await; Json(queue) } + +async fn list_queue_ws( + State(state): State>, + Path(guild_id): Path, + ws: WebSocketUpgrade, +) -> Response { + ws.on_upgrade(move |socket| list_queue_ws_handler(state, guild_id, socket)) +} + +async fn list_queue_ws_handler(state: Arc, guild_id: String, mut socket: WebSocket) { + info!("new websocket connection for guild_id: {}", guild_id); + + let mut rx = state.queue.subscribe(); + + // Initial state + let queue = state.queue.list(&guild_id).await; + let msg = serde_json::to_string(&queue).unwrap(); + if socket.send(Message::Text(msg.into())).await.is_err() { + return; + } + + while let Ok(event) = rx.recv().await { + match event { + QueueEvent::Updated { guild_id: gid } => { + if gid == guild_id { + let queue = state.queue.list(&guild_id).await; + let msg = serde_json::to_string(&queue).unwrap(); + if socket.send(Message::Text(msg.into())).await.is_err() { + break; + } + } + } + } + } +} diff --git a/src/domain/mod.rs b/src/domain/mod.rs index 5508f99..0128bfe 100644 --- a/src/domain/mod.rs +++ b/src/domain/mod.rs @@ -2,4 +2,4 @@ pub mod order; pub mod queue; pub use order::{DailyStats, OrderRepository}; -pub use queue::{QueueEntry, QueueRepository}; +pub use queue::{QueueEntry, QueueEvent, QueueRepository}; diff --git a/src/domain/queue.rs b/src/domain/queue.rs index e9f18a8..e098e3f 100644 --- a/src/domain/queue.rs +++ b/src/domain/queue.rs @@ -13,6 +13,11 @@ impl QueueEntry { } } +#[derive(Debug, Clone)] +pub enum QueueEvent { + Updated { guild_id: String }, +} + #[async_trait::async_trait] pub trait QueueRepository: Send + Sync { /// Open the queue to allow new entries @@ -47,4 +52,7 @@ pub trait QueueRepository: Send + Sync { /// Clear the queue async fn clear(&self, guild_id: &str); + + /// Subscribe to queue change events + fn subscribe(&self) -> tokio::sync::broadcast::Receiver; } diff --git a/src/infrastructure/redis_queue_repository.rs b/src/infrastructure/redis_queue_repository.rs index 1267ffc..1f9a4ff 100644 --- a/src/infrastructure/redis_queue_repository.rs +++ b/src/infrastructure/redis_queue_repository.rs @@ -1,9 +1,10 @@ use std::{collections::HashSet, sync::RwLock}; use redis::AsyncCommands; +use tokio::sync::broadcast; use tracing::{debug, error, info, instrument, warn}; -use crate::domain::{QueueEntry, QueueRepository}; +use crate::domain::{QueueEntry, QueueEvent, QueueRepository}; fn queue_key(guild_id: &str) -> String { format!("queue:{guild_id}") @@ -12,13 +13,16 @@ fn queue_key(guild_id: &str) -> String { pub struct RedisQueueRepository { redis: redis::Client, open_guilds: RwLock>, + event_tx: broadcast::Sender, } impl RedisQueueRepository { pub fn new(redis: redis::Client) -> Self { + let (event_tx, _) = broadcast::channel(64); Self { redis, open_guilds: RwLock::new(HashSet::new()), + event_tx, } } } @@ -39,6 +43,7 @@ impl QueueRepository for RedisQueueRepository { info!(guild_id, "Closing queue for guild"); self.open_guilds.write().unwrap().remove(guild_id); self.clear(guild_id).await; + self.broadcast_update(guild_id); } #[instrument(skip(self), fields(guild_id))] @@ -111,6 +116,7 @@ impl QueueRepository for RedisQueueRepository { 0 }); info!(guild_id, user_id = %entry.user_id, queue_size = new_size, "Added user to queue"); + self.broadcast_update(guild_id); new_size } @@ -134,6 +140,9 @@ impl QueueRepository for RedisQueueRepository { Some(e) => info!(guild_id, user_id = %e.user_id, "Popped user from queue"), None => debug!(guild_id, "No entry to pop from queue"), } + if entry.is_some() { + self.broadcast_update(guild_id); + } entry } @@ -159,6 +168,9 @@ impl QueueRepository for RedisQueueRepository { .filter_map(|json_str| serde_json::from_str(&json_str).ok()) .collect(); info!(guild_id, count = entries.len(), "Popped entries from queue"); + if !entries.is_empty() { + self.broadcast_update(guild_id); + } entries } @@ -201,6 +213,26 @@ impl QueueRepository for RedisQueueRepository { } else { error!(guild_id, "Failed to get Redis connection for clear"); } + self.broadcast_update(guild_id); + } + + fn subscribe(&self) -> tokio::sync::broadcast::Receiver { + self.event_tx.subscribe() + } +} + +impl RedisQueueRepository { + fn broadcast_update(&self, guild_id: &str) { + if let Err(err) = self.event_tx.send(QueueEvent::Updated { + guild_id: guild_id.to_string(), + }) { + error!( + guild_id, + "error" = ?err, + "guild_id" = guild_id, + "Failed to broadcast queue update event" + ); + } } } diff --git a/src/main.rs b/src/main.rs index 28f3f5b..61fb927 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,7 +2,7 @@ use std::env; use tracing::{error, info}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; -use vaffelbot_rs::{VaffelBot, config::Config}; +use vaffelbot_rs::{config::Config, VaffelBot}; #[tokio::main] async fn main() {