Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
307 changes: 301 additions & 6 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ bytes = "1"
chrono = "0.4"
clap = { version = "4", features = ["derive"] }
dotenvy = "0.15"
fjall = "3"
indexmap = "2"
reqwest = { version = "0.13", features = ["hickory-dns"] }
rustls = "0.23"
Expand All @@ -40,6 +41,7 @@ url = "2"
wasmtime = "41"
wasmtime-wasi = "41"
wasmtime-wasi-http = "41"
uuid = "1"

[profile.release]
lto = true
Expand Down
3 changes: 3 additions & 0 deletions src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ pub struct Cli {
#[arg(action=ArgAction::Set, default_value_t = true, short = 'C', long, value_name = "BOOL", help = "Enable the usage of cached plugins", long_help = None, hide_possible_values = true)]
pub cache: bool,

#[arg(default_value = "./database", short, long, value_name = "DIRECTORY PATH", help = "The path to the program its database", long_help = None)]
pub database_directory: PathBuf,

#[arg(default_value_t = 15, short = 't', long, value_name = "SECONDS", help = "The amount of seconds after which the HTTP client should timeout", long_help = None)]
pub http_client_timeout_seconds: u64,
}
Expand Down
92 changes: 92 additions & 0 deletions src/database.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/* SPDX-License-Identifier: GPL-3.0-or-later */
/* Copyright © 2026 Eduard Smet */

use std::{
fs::{self},
io::ErrorKind,
path::Path,
};

use anyhow::{Result, bail};
use fjall::{Database, KeyspaceCreateOptions, PersistMode, Slice};

use crate::utils::channels::DatabaseMessages;

pub enum Keyspaces {
Plugins,
PluginStore,
DependencyFunctions,
ScheduledJobs,
DiscordEvents,
DiscordApplicationCommands,
DiscordMessageComponents,
DiscordModals,
}

pub fn new(database_directory_path: &Path) -> Result<Database> {
if let Err(err) = fs::create_dir_all(database_directory_path)
&& err.kind() != ErrorKind::AlreadyExists
{
bail!(err);
}

Ok(Database::builder(database_directory_path).open()?)
}

pub fn handle_action(database: Database, message: DatabaseMessages) {
match message {
DatabaseMessages::GetState(keyspace, key, response_sender) => {
response_sender.send(get(database, keyspace, key));
}
DatabaseMessages::InsertState(keyspace, key, value, response_sender) => {
response_sender.send(insert(database, keyspace, key, value));
}
DatabaseMessages::DeleteState(keyspace, key, response_sender) => {
response_sender.send(remove(database, keyspace, key));
}
DatabaseMessages::ContainsKey(keyspace, key, response_sender) => {
response_sender.send(contains_key(database, keyspace, key));
}
}
}

pub fn get(database: Database, keyspace: Keyspaces, key: Vec<u8>) -> Result<Option<Slice>> {
let keyspace = database.keyspace(get_keyspace(keyspace), KeyspaceCreateOptions::default)?;

Ok(keyspace.get(key)?)
}

pub fn insert(database: Database, keyspace: Keyspaces, key: Vec<u8>, value: Vec<u8>) -> Result<()> {
let keyspace = database.keyspace(get_keyspace(keyspace), KeyspaceCreateOptions::default)?;

Ok(keyspace.insert(key, value)?)
}

pub fn remove(database: Database, keyspace: Keyspaces, key: Vec<u8>) -> Result<()> {
let keyspace = database.keyspace(get_keyspace(keyspace), KeyspaceCreateOptions::default)?;

Ok(keyspace.remove(key)?)
}

pub fn contains_key(database: Database, keyspace: Keyspaces, key: Vec<u8>) -> Result<bool> {
let keyspace = database.keyspace(get_keyspace(keyspace), KeyspaceCreateOptions::default)?;

Ok(keyspace.contains_key(key)?)
}

pub fn persist(database: Database, persist_mode: PersistMode) -> Result<()> {
Ok(database.persist(persist_mode)?)
}

fn get_keyspace(keyspace: Keyspaces) -> &'static str {
match keyspace {
Keyspaces::Plugins => "plugins",
Keyspaces::PluginStore => "plugin_store",
Keyspaces::DependencyFunctions => "dependency_functions",
Keyspaces::ScheduledJobs => "scheduled_jobs",
Keyspaces::DiscordEvents => "discord_events",
Keyspaces::DiscordApplicationCommands => "discord_application_commands",
Keyspaces::DiscordMessageComponents => "discord_message_componets",
Keyspaces::DiscordModals => "discord_modals",
}
}
167 changes: 81 additions & 86 deletions src/discord.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,22 @@
/* SPDX-License-Identifier: GPL-3.0-or-later */
/* Copyright © 2026 Eduard Smet */

use std::{collections::HashMap, sync::Arc};
use std::sync::Arc;

use tokio::{
sync::{
Mutex, RwLock,
mpsc::{Receiver, Sender},
},
sync::mpsc::{UnboundedReceiver, UnboundedSender},
task::JoinHandle,
};
use tracing::{error, info};
use twilight_cache_inmemory::{DefaultInMemoryCache, InMemoryCache, ResourceType};
use twilight_cache_inmemory::{DefaultInMemoryCache, InMemoryCache};
use twilight_gateway::{
CloseFrame, Config, Event, EventType, EventTypeFlags, Intents, MessageSender, Shard, StreamExt,
CloseFrame, Config, EventType, EventTypeFlags, Intents, MessageSender, Shard, StreamExt,
};
use twilight_http::Client;
use twilight_model::id::{Id, marker::GuildMarker};

use crate::{
SHUTDOWN,
plugins::PluginRegistrations,
utils::channels::{DiscordBotClientMessages, RuntimeMessages},
utils::channels::{CoreMessages, DiscordBotClientMessages},
};

mod events;
Expand All @@ -30,23 +25,22 @@ mod requests;

pub struct DiscordBotClient {
http_client: Arc<Client>,
shard_message_senders: Arc<RwLock<HashMap<Id<GuildMarker>, Arc<MessageSender>>>>,
shards: Vec<Shard>,
shard_message_senders: Arc<Vec<MessageSender>>,
cache: Arc<InMemoryCache>,
plugin_registrations: Arc<RwLock<PluginRegistrations>>,
runtime_tx: Arc<Sender<RuntimeMessages>>,
runtime_rx: Arc<Mutex<Receiver<DiscordBotClientMessages>>>,
core_tx: Arc<UnboundedSender<CoreMessages>>,
rx: UnboundedReceiver<DiscordBotClientMessages>,
}

impl DiscordBotClient {
pub async fn new(
token: String,
plugin_registrations: Arc<RwLock<PluginRegistrations>>,
runtime_tx: Sender<RuntimeMessages>,
runtime_rx: Receiver<DiscordBotClientMessages>,
) -> Result<(Self, Box<dyn ExactSizeIterator<Item = Shard> + Send>), ()> {
core_tx: UnboundedSender<CoreMessages>,
rx: UnboundedReceiver<DiscordBotClientMessages>,
) -> Result<Self, ()> {
info!("Creating the Discord bot client");

let intents = Intents::all();
let intents = Intents::all(); // TODO: Make this configurable

rustls::crypto::aws_lc_rs::default_provider()
.install_default()
Expand All @@ -56,14 +50,14 @@ impl DiscordBotClient {

let config = Config::new(token, intents);

let shards = match twilight_gateway::create_recommended(
let (shards, shard_message_senders) = match twilight_gateway::create_recommended(
&http_client,
config,
|_, builder| builder.build(),
)
.await
{
Ok(shards) => Box::new(shards),
Ok(shard_iterator) => Self::shard_message_senders(Box::new(shard_iterator)),
Err(err) => {
error!(
"Something went wrong while getting the recommended amount of shards from Discord, error: {}",
Expand All @@ -73,74 +67,66 @@ impl DiscordBotClient {
}
};

let shard_message_senders = Arc::new(RwLock::new(HashMap::new()));

let cache = Arc::new(
DefaultInMemoryCache::builder()
.resource_types(ResourceType::all())
.build(),
);

Ok((
DiscordBotClient {
http_client: Arc::new(http_client),
shard_message_senders,
cache,
plugin_registrations,
runtime_tx: Arc::new(runtime_tx),
runtime_rx: Arc::new(Mutex::new(runtime_rx)),
},
let cache = Arc::new(DefaultInMemoryCache::default()); // TODO: Make this configurable

Ok(DiscordBotClient {
http_client: Arc::new(http_client),
shards,
))
shard_message_senders: Arc::new(shard_message_senders),
cache,
core_tx: Arc::new(core_tx),
rx,
})
}

pub fn start(self, shards: Box<dyn ExactSizeIterator<Item = Shard> + Send>) -> JoinHandle<()> {
let mut tasks = Vec::with_capacity(shards.len());

let discord_bot_client = Arc::new(self);
pub fn start(mut self) -> JoinHandle<()> {
let mut tasks = Vec::with_capacity(self.shards.len());

for shard in shards {
for shard in self.shards.drain(..) {
tasks.push(tokio::spawn(Self::shard_runner(
discord_bot_client.clone(),
self.cache.clone(),
self.core_tx.clone(),
shard,
)));
}

tokio::spawn(async move {
while let Some(message) = discord_bot_client.runtime_rx.lock().await.recv().await {
while let Some(message) = self.rx.recv().await {
match message {
DiscordBotClientMessages::RegisterApplicationCommands(commands) => {
let _ = discord_bot_client
.application_command_registrations(commands)
.await;
DiscordBotClientMessages::RegisterApplicationCommands(
commands,
response_sender,
) => {
let http_client = self.http_client.clone();
tokio::spawn(async {
response_sender.send(
Self::application_command_registrations(http_client, commands)
.await,
);
});
}
DiscordBotClientMessages::Request(request, response_sender) => {
let _ = response_sender.send(discord_bot_client.request(request).await);
}
DiscordBotClientMessages::Shutdown(is_done) => {
for sender in discord_bot_client
.shard_message_senders
.read()
.await
.values()
{
_ = sender.close(CloseFrame::NORMAL);
}

for task in tasks.drain(..) {
let _ = task.await;
}

let _ = is_done.send(());
let http_client = self.http_client.clone();
let shard_message_senders = self.shard_message_senders.clone();

tokio::spawn(async {
response_sender.send(
Self::request(http_client, shard_message_senders, request).await,
);
});
}
}
}

self.shutdown(tasks);
})
}

pub async fn shard_runner(discord_bot_client: Arc<DiscordBotClient>, mut shard: Shard) {
let shard_message_sender = Arc::new(shard.sender());

async fn shard_runner(
cache: Arc<InMemoryCache>,
core_tx: Arc<UnboundedSender<CoreMessages>>,
mut shard: Shard,
) {
while let Some(item) = shard.next_event(EventTypeFlags::all()).await {
let Ok(event) = item else {
error!(
Expand All @@ -155,24 +141,33 @@ impl DiscordBotClient {
break;
}

discord_bot_client.cache.update(&event);
cache.update(&event);

match event {
Event::Ready(ready) => {
info!("Shard is ready, logged in as {}", &ready.user.name);
tokio::spawn(Self::handle_event(core_tx.clone(), event));
}
}

for guild in ready.guilds {
discord_bot_client
.shard_message_senders
.write()
.await
.insert(guild.id, shard_message_sender.clone());
}
}
_ => {
tokio::spawn(Self::handle_event(discord_bot_client.clone(), event));
}
}
fn shard_message_senders(
shard_iterator: Box<dyn ExactSizeIterator<Item = Shard>>,
) -> (Vec<Shard>, Vec<MessageSender>) {
let mut shards = vec![];
let mut shard_message_senders = vec![];

for shard in shard_iterator {
shard_message_senders.push(shard.sender());
shards.push(shard);
}

(shards, shard_message_senders)
}

async fn shutdown(&self, mut tasks: Vec<JoinHandle<()>>) {
for shard_message_sender in self.shard_message_senders.iter() {
_ = shard_message_sender.close(CloseFrame::NORMAL);
}

for task in tasks.drain(..) {
let _ = task.await;
}
}
}
Loading
Loading