diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f3ccdb64..64a9d2a5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -54,6 +54,7 @@ jobs: mingw-w64-x86_64-libhandy mingw-w64-x86_64-opus mingw-w64-x86_64-libsodium + mingw-w64-x86_64-openssl mingw-w64-x86_64-spdlog - name: Setup MSYS2 (2) @@ -151,10 +152,11 @@ jobs: brew install libsodium brew install spdlog brew install libhandy + brew install openssl@3 - name: Build run: | - cmake -Bbuild -DCMAKE_BUILD_TYPE=${{ matrix.buildtype }} + cmake -Bbuild -DCMAKE_BUILD_TYPE=${{ matrix.buildtype }} -DOPENSSL_ROOT_DIR=$(brew --prefix openssl@3) cmake --build build - name: Setup artifact files @@ -188,22 +190,15 @@ jobs: - name: Fetch dependencies run: | sudo apt-get update - mkdir deps - cd deps + sudo apt-get install -y libgtkmm-3.0-dev libcurl4-gnutls-dev libopus-dev libsodium-dev libspdlog-dev libhandy-1-dev libssl-dev + mkdir deps && cd deps git clone https://github.com/nlohmann/json cd json git checkout 55f93686c01528224f448c19128836e7df245f72 - mkdir build - cd build + mkdir build && cd build cmake .. -DJSON_BuildTests=OFF make sudo make install - sudo apt-get install libgtkmm-3.0-dev - sudo apt-get install libcurl4-gnutls-dev - sudo apt-get install libopus-dev - sudo apt-get install libsodium-dev - sudo apt-get install libspdlog-dev - sudo apt-get install libhandy-1-dev - name: Build run: | diff --git a/.gitmodules b/.gitmodules index e0d062ad..870486de 100644 --- a/.gitmodules +++ b/.gitmodules @@ -16,3 +16,9 @@ [submodule "subprojects/qrcodegen"] path = subprojects/qrcodegen url = https://github.com/nayuki/QR-Code-generator +[submodule "subprojects/libdave"] + path = subprojects/libdave + url = https://github.com/discord/libdave.git +[submodule "subprojects/mlspp"] + path = subprojects/mlspp + url = https://github.com/cisco/mlspp.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 8fb2070e..2c8e8d99 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -168,6 +168,60 @@ if (ENABLE_VOICE) target_link_libraries(abaddon ${CMAKE_DL_LIBS}) + # mlspp and libdave need nlohmann_json::nlohmann_json target + if (NOT TARGET nlohmann_json::nlohmann_json) + add_library(nlohmann_json::nlohmann_json INTERFACE IMPORTED) + set_target_properties(nlohmann_json::nlohmann_json PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${NLOHMANN_JSON_INCLUDE_DIRS}") + endif () + + find_package(MLSPP QUIET) + if (NOT MLSPP_FOUND) + message(STATUS "mlspp not found, using subproject") + set(DISABLE_GREASE ON CACHE BOOL "" FORCE) + set(TESTING OFF CACHE BOOL "" FORCE) + set(MLS_CXX_NAMESPACE "mlspp" CACHE STRING "" FORCE) + set(CMAKE_EXPORT_PACKAGE_REGISTRY OFF CACHE BOOL "" FORCE) + if (CMAKE_SYSTEM_NAME STREQUAL "OpenBSD") + message(STATUS "Forcing mlspp to use BoringSSL on OpenBSD") + set(REQUIRE_BORINGSSL ON CACHE BOOL "" FORCE) + set(OPENSSL_ROOT_DIR "/usr/local/eboringssl" CACHE PATH "" FORCE) + set(OPENSSL_INCLUDE_DIR "/usr/local/eboringssl/include" CACHE PATH "" FORCE) + set(OPENSSL_CRYPTO_LIBRARY "/usr/local/eboringssl/lib/libcrypto.a" CACHE FILEPATH "" FORCE) + set(OPENSSL_SSL_LIBRARY "/usr/local/eboringssl/lib/libssl.a" CACHE FILEPATH "" FORCE) + endif() + add_subdirectory(subprojects/mlspp EXCLUDE_FROM_ALL) + if (NOT TARGET MLSPP::mlspp) + add_library(MLSPP::mlspp ALIAS mlspp) + endif () + if (NOT TARGET MLSPP::hpke) + add_library(MLSPP::hpke ALIAS hpke) + endif () + set(MLSPP_FOUND TRUE CACHE BOOL "" FORCE) + # Write a shim config so libdave's find_dependency(MLSPP) finds this + # instead of the broken build-tree config that add_subdirectory generates + file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/mlspp-shim/MLSPPConfig.cmake" + "set(MLSPP_FOUND TRUE)\n") + set(MLSPP_DIR "${CMAKE_CURRENT_BINARY_DIR}/mlspp-shim" CACHE PATH "" FORCE) + else () + message(STATUS "Found system mlspp") + endif () + + # Prevent find_package from finding mlspp's broken build-tree config via registry + set(CMAKE_FIND_PACKAGE_NO_PACKAGE_REGISTRY ON) + + find_package(libdave QUIET) + if (libdave_FOUND) + message(STATUS "Found system libdave") + target_link_libraries(abaddon libdave) + else () + message(STATUS "libdave not found, using subproject") + add_subdirectory(subprojects/libdave/cpp EXCLUDE_FROM_ALL) + target_link_libraries(abaddon libdave) + endif () + + set(CMAKE_FIND_PACKAGE_NO_PACKAGE_REGISTRY OFF) + if (ENABLE_RNNOISE) target_compile_definitions(abaddon PRIVATE WITH_RNNOISE) diff --git a/README.md b/README.md index f8e48f5f..1d4c21a7 100644 --- a/README.md +++ b/README.md @@ -68,12 +68,13 @@ the result of fundamental issues with Discord's thread implementation. * mingw-w64-x86_64-libhandy * mingw-w64-x86_64-opus * mingw-w64-x86_64-libsodium + * mingw-w64-x86_64-openssl * mingw-w64-x86_64-spdlog 2. `git clone --recurse-submodules="subprojects" https://github.com/uowuo/abaddon && cd abaddon` -3. `mkdir build && cd build` -4. `cmake -GNinja -DCMAKE_BUILD_TYPE=RelWithDebInfo ..` -5. `ninja` -6. [Copy resources](#resources) +4. `mkdir build && cd build` +5. `cmake -GNinja -DCMAKE_BUILD_TYPE=RelWithDebInfo ..` +6. `ninja` +7. [Copy resources](#resources) #### Mac: @@ -185,6 +186,8 @@ spam filter's wrath: * [libopus](https://opus-codec.org/) (optional, required for voice) * [libsodium](https://doc.libsodium.org/) (optional, required for voice) * [rnnoise](https://gitlab.xiph.org/xiph/rnnoise) (optional, provided as submodule, noise suppression and improved VAD) +* [libdave](https://github.com/discord/libdave) (provided as submodule, required for voice E2EE) +* [mlspp](https://github.com/cisco/mlspp) (provided as submodule, required for voice E2EE) ### TODO: diff --git a/ci/msys-deps.txt b/ci/msys-deps.txt index 6f96c5d3..5b96a868 100644 --- a/ci/msys-deps.txt +++ b/ci/msys-deps.txt @@ -59,7 +59,7 @@ /bin/libsharpyuv-0.dll /bin/libsigc-2.0-0.dll /bin/libsodium-26.dll -/bin/libspdlog-1.15.dll +/bin/libspdlog-1.17.dll /bin/libsqlite3-0.dll /bin/libssh2-1.dll /bin/libssl-3-x64.dll diff --git a/src/discord/dave.cpp b/src/discord/dave.cpp new file mode 100644 index 00000000..15be34b6 --- /dev/null +++ b/src/discord/dave.cpp @@ -0,0 +1,320 @@ +#ifdef WITH_VOICE +// clang-format off + +#include "dave.hpp" +#include "voiceclient.hpp" +#include +#include +// clang-format on + +static bool s_dave_log_sink_set = false; + +DaveSession::DaveSession(Snowflake channelId, Snowflake userId, + const std::unordered_map &ssrcUserMap) + : m_channel_id(channelId) + , m_user_id(userId) + , m_ssrc_user_map(ssrcUserMap) + , m_log(spdlog::get("voice")) { + if (!s_dave_log_sink_set) { + s_dave_log_sink_set = true; + discord::dave::SetLogSink([](discord::dave::LoggingSeverity severity, + const char *file, int line, + const std::string &message) { + auto log = spdlog::get("voice"); + if (!log) return; + switch (severity) { + case discord::dave::LS_ERROR: + log->error("[DAVE] {}:{} {}", file, line, message); + break; + case discord::dave::LS_WARNING: + log->warn("[DAVE] {}:{} {}", file, line, message); + break; + case discord::dave::LS_INFO: + log->debug("[DAVE] {}", message); + break; + case discord::dave::LS_VERBOSE: + log->debug("[DAVE] {}", message); + break; + default: + break; + } + }); + } + + m_mls_session = discord::dave::mls::CreateSession( + nullptr, "", + [this](const std::string &reason, const std::string &detail) { + m_log->warn("MLS failure: {} {}", reason, detail); + }); +} + +DaveSession::~DaveSession() = default; + +void DaveSession::Init(uint16_t version) { + m_protocol_version = version; + m_pending_protocol_version = version; + Reinit(); +} + +void DaveSession::Reinit() { + m_enabled = false; + m_downgraded = false; + + auto self_id = std::to_string(static_cast(m_user_id)); + m_connected_users.insert(self_id); + + m_log->info("Initializing DAVE session: channel={} version={} users={}", + static_cast(m_channel_id), m_protocol_version, m_connected_users.size()); + + m_mls_session->Init( + m_protocol_version, + static_cast(m_channel_id), + self_id, + m_transient_key); + + m_decryptors.clear(); + m_pending_transition_ready = false; + + m_encryptor = discord::dave::CreateEncryptor(); + if (m_local_ssrc != 0) + m_encryptor->AssignSsrcToCodec(m_local_ssrc, discord::dave::Codec::Opus); + + auto keyPackage = m_mls_session->GetMarshalledKeyPackage(); + if (!keyPackage.empty()) { + m_log->info("Sending MLS key package, size={}", keyPackage.size()); + m_signal_send_binary.emit(static_cast(VoiceGatewayOp::MlsKeyPackage), keyPackage); + } +} + +void DaveSession::OnExternalSenderPackage(const uint8_t *data, size_t size) { + m_log->info("Received external sender package, size={}", size); + std::vector payload(data, data + size); + m_mls_session->SetExternalSender(payload); +} + +void DaveSession::OnProposals(const uint8_t *data, size_t size) { + m_log->info("Received proposals, size={} connectedUsers={}", size, m_connected_users.size()); + std::vector payload(data, data + size); + auto response = m_mls_session->ProcessProposals(std::move(payload), m_connected_users); + + if (response) { + m_log->info("Sending commit+welcome, size={}", response->size()); + m_signal_send_binary.emit(static_cast(VoiceGatewayOp::MlsCommitWelcome), *response); + } +} + +void DaveSession::OnAnnounceCommitTransition(const uint8_t *data, size_t size) { + if (size < 2) return; + + int transitionId = (data[0] << 8) | data[1]; + m_pending_transition_id = transitionId; + + m_log->debug("Received announce commit transition: transitionId={} size={}", transitionId, size); + + std::vector payload(data + 2, data + size); + auto result = m_mls_session->ProcessCommit(std::move(payload)); + + if (auto *roster = std::get_if(&result)) { + m_log->info("ProcessCommit succeeded, roster size={}", roster->size()); + m_pending_transition_ready = true; + m_signal_send_ready.emit(transitionId); + if (transitionId == 0) + CompleteTransition(); + } else if (std::holds_alternative(result)) { + m_log->warn("ProcessCommit failed (hard reject)"); + m_signal_send_invalid.emit(transitionId); + } else { + m_log->debug("ProcessCommit ignored (soft reject)"); + } +} + +void DaveSession::OnWelcome(const uint8_t *data, size_t size) { + if (size < 2) return; + + int transitionId = (data[0] << 8) | data[1]; + m_pending_transition_id = transitionId; + + m_log->info("Received welcome: transitionId={} size={} connectedUsers={}", + transitionId, size, m_connected_users.size()); + + std::vector payload(data + 2, data + size); + auto roster = m_mls_session->ProcessWelcome(std::move(payload), m_connected_users); + + if (roster) { + m_log->info("ProcessWelcome succeeded, roster size={}", roster->size()); + m_pending_transition_ready = true; + m_signal_send_ready.emit(transitionId); + if (transitionId == 0) + CompleteTransition(); + } else { + m_log->warn("ProcessWelcome failed"); + m_signal_send_invalid.emit(transitionId); + } +} + +void DaveSession::OnPrepareTransition(int version, int transitionId) { + m_log->info("Prepare transition: version={} transitionId={}", version, transitionId); + m_pending_transition_id = transitionId; + m_pending_protocol_version = static_cast(version); + m_signal_send_ready.emit(transitionId); +} + +void DaveSession::OnExecuteTransition(int transitionId) { + m_log->info("Execute transition: transitionId={}", transitionId); + + if (m_pending_protocol_version != m_protocol_version) { + m_protocol_version = m_pending_protocol_version; + if (m_protocol_version == 0) { + m_log->info("DAVE downgrade to version 0, disabling"); + m_enabled = false; + m_downgraded = true; + m_signal_state_changed.emit(false); + return; + } + } + + if (!m_pending_transition_ready) { + m_log->warn("Execute transition {} but no pending commit/welcome, reinitializing", transitionId); + Reinit(); + return; + } + + CompleteTransition(); +} + +void DaveSession::CompleteTransition() { + if (!m_pending_transition_ready) { + m_log->warn("completeTransition but no pending commit/welcome!"); + return; + } + m_pending_transition_ready = false; + + auto selfRatchet = m_mls_session->GetKeyRatchet(std::to_string(static_cast(m_user_id))); + if (selfRatchet) { + m_encryptor->SetKeyRatchet(std::move(selfRatchet)); + m_log->info("Refreshed encryptor key ratchet for new epoch"); + } else { + m_log->warn("Could not get own key ratchet from MLS session"); + } + + for (auto &[ssrc, dec] : m_decryptors) { + auto it = m_ssrc_user_map.find(ssrc); + if (it == m_ssrc_user_map.end()) + continue; + auto ratchet = m_mls_session->GetKeyRatchet(std::to_string(static_cast(it->second))); + if (ratchet) { + dec->TransitionToKeyRatchet(std::move(ratchet)); + m_log->debug("Refreshed decryptor key ratchet for SSRC={}", ssrc); + } + } + + if (!m_enabled) { + m_enabled = true; + m_downgraded = false; + m_log->info("DAVE encryption enabled"); + } + + m_signal_state_changed.emit(true); + m_pending_transition_id = -1; +} + +void DaveSession::OnPrepareEpoch(int version, int epoch) { + m_log->info("Prepare epoch: version={} epoch={}", version, epoch); + if (epoch == 1) { + m_protocol_version = static_cast(version); + Reinit(); + } +} + +discord::dave::IEncryptor *DaveSession::GetEncryptor() { + return m_encryptor.get(); +} + +discord::dave::IDecryptor *DaveSession::GetOrCreateDecryptor(uint32_t ssrc, Snowflake uid) { + auto it = m_decryptors.find(ssrc); + if (it != m_decryptors.end()) + return it->second.get(); + + auto decryptor = discord::dave::CreateDecryptor(); + + bool hasRatchet = false; + if (static_cast(uid) != 0) { + auto keyRatchet = m_mls_session->GetKeyRatchet(std::to_string(static_cast(uid))); + if (keyRatchet) { + decryptor->TransitionToKeyRatchet(std::move(keyRatchet)); + hasRatchet = true; + } + } + + auto *ptr = decryptor.get(); + m_decryptors.emplace(ssrc, std::move(decryptor)); + + m_log->debug("Created decryptor for SSRC={} hasRatchet={}", ssrc, hasRatchet); + return ptr; +} + +void DaveSession::SetLocalSSRC(uint32_t ssrc) { + m_local_ssrc = ssrc; + if (m_encryptor) + m_encryptor->AssignSsrcToCodec(ssrc, discord::dave::Codec::Opus); +} + +void DaveSession::AddConnectedUser(const std::string &id) { + m_connected_users.insert(id); +} + +void DaveSession::RemoveConnectedUser(const std::string &id) { + m_connected_users.erase(id); + + uint64_t uid = std::stoull(id); + for (auto it = m_ssrc_user_map.begin(); it != m_ssrc_user_map.end(); ++it) { + if (static_cast(it->second) == uid) { + m_decryptors.erase(it->first); + break; + } + } +} + +void DaveSession::GetPairwiseFingerprint(const std::string &userId, FingerprintCallback callback) const { + if (!m_mls_session || !m_enabled) { + callback({}); + return; + } + m_mls_session->GetPairwiseFingerprint(0, userId, std::move(callback)); +} + +std::vector DaveSession::GetLastEpochAuthenticator() const { + if (!m_mls_session || !m_enabled) + return {}; + return m_mls_session->GetLastEpochAuthenticator(); +} + +void DaveSession::ApplyKeyRatchetForSSRC(uint32_t ssrc, Snowflake uid) { + auto it = m_decryptors.find(ssrc); + if (it == m_decryptors.end()) + return; + + auto keyRatchet = m_mls_session->GetKeyRatchet(std::to_string(static_cast(uid))); + if (keyRatchet) { + it->second->TransitionToKeyRatchet(std::move(keyRatchet)); + m_log->info("Applied key ratchet for SSRC={} user={}", ssrc, static_cast(uid)); + } +} + +DaveSession::type_signal_send_binary DaveSession::signal_send_binary() { + return m_signal_send_binary; +} + +DaveSession::type_signal_send_ready DaveSession::signal_send_ready_for_transition() { + return m_signal_send_ready; +} + +DaveSession::type_signal_send_invalid DaveSession::signal_send_invalid_commit_welcome() { + return m_signal_send_invalid; +} + +DaveSession::type_signal_state_changed DaveSession::signal_dave_state_changed() { + return m_signal_state_changed; +} + +#endif diff --git a/src/discord/dave.hpp b/src/discord/dave.hpp new file mode 100644 index 00000000..f8cd0cf9 --- /dev/null +++ b/src/discord/dave.hpp @@ -0,0 +1,98 @@ +#pragma once +#ifdef WITH_VOICE +// clang-format off + +#include "snowflake.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include +// clang-format on + +namespace mlspp { +struct SignaturePrivateKey; +} + +class DaveSession { +public: + DaveSession(Snowflake channelId, Snowflake userId, + const std::unordered_map &ssrcUserMap); + ~DaveSession(); + + void Init(uint16_t protocolVersion); + + // binary payloads from server + void OnExternalSenderPackage(const uint8_t *data, size_t size); + void OnProposals(const uint8_t *data, size_t size); + void OnAnnounceCommitTransition(const uint8_t *data, size_t size); + void OnWelcome(const uint8_t *data, size_t size); + + // json payloads from server + void OnPrepareTransition(int protocolVersion, int transitionId); + void OnExecuteTransition(int transitionId); + void OnPrepareEpoch(int protocolVersion, int epoch); + + discord::dave::IEncryptor *GetEncryptor(); + discord::dave::IDecryptor *GetOrCreateDecryptor(uint32_t ssrc, Snowflake userId); + + bool IsEnabled() const { return m_enabled; } + bool IsDowngraded() const { return m_downgraded; } + + using FingerprintCallback = std::function &)>; + void GetPairwiseFingerprint(const std::string &userId, FingerprintCallback callback) const; + std::vector GetLastEpochAuthenticator() const; + + void SetLocalSSRC(uint32_t ssrc); + void AddConnectedUser(const std::string &userId); + void RemoveConnectedUser(const std::string &userId); + void ApplyKeyRatchetForSSRC(uint32_t ssrc, Snowflake userId); + + // signal types + using type_signal_send_binary = sigc::signal &data)>; + using type_signal_send_ready = sigc::signal; + using type_signal_send_invalid = sigc::signal; + using type_signal_state_changed = sigc::signal; + + type_signal_send_binary signal_send_binary(); + type_signal_send_ready signal_send_ready_for_transition(); + type_signal_send_invalid signal_send_invalid_commit_welcome(); + type_signal_state_changed signal_dave_state_changed(); + +private: + void Reinit(); + void CompleteTransition(); + + std::unique_ptr m_mls_session; + std::unique_ptr m_encryptor; + std::unordered_map> m_decryptors; + + uint16_t m_protocol_version = 0; + uint16_t m_pending_protocol_version = 0; + Snowflake m_channel_id; + Snowflake m_user_id; + uint32_t m_local_ssrc = 0; + + std::set m_connected_users; + int m_pending_transition_id = -1; + bool m_enabled = false; + bool m_downgraded = false; + bool m_pending_transition_ready = false; + + const std::unordered_map &m_ssrc_user_map; + + std::shared_ptr<::mlspp::SignaturePrivateKey> m_transient_key; + + std::shared_ptr m_log; + + type_signal_send_binary m_signal_send_binary; + type_signal_send_ready m_signal_send_ready; + type_signal_send_invalid m_signal_send_invalid; + type_signal_state_changed m_signal_state_changed; +}; + +#endif diff --git a/src/discord/discord.cpp b/src/discord/discord.cpp index 34600ea3..b5ddfbc3 100644 --- a/src/discord/discord.cpp +++ b/src/discord/discord.cpp @@ -2476,6 +2476,7 @@ void DiscordClient::HandleGatewayVoiceServerUpdate(const GatewayMessage &msg) { spdlog::get("discord")->error("No guild or channel ID in voice server?"); } m_voice.SetUserID(m_user_data.ID); + m_voice.SetChannelID(m_voice_channel_id); m_voice.Start(); } diff --git a/src/discord/voiceclient.cpp b/src/discord/voiceclient.cpp index f8505239..cca297dc 100644 --- a/src/discord/voiceclient.cpp +++ b/src/discord/voiceclient.cpp @@ -8,6 +8,8 @@ #include #include "abaddon.hpp" #include "audio/manager.hpp" +#include +#include #ifdef _WIN32 #define S_ADDR(var) (var).sin_addr.S_un.S_addr @@ -157,18 +159,40 @@ DiscordVoiceClient::DiscordVoiceClient() OnUDPData(data); }); + m_ws.SetSeparateBinaryMessages(true); m_ws.signal_open().connect(sigc::mem_fun(*this, &DiscordVoiceClient::OnWebsocketOpen)); m_ws.signal_close().connect(sigc::mem_fun(*this, &DiscordVoiceClient::OnWebsocketClose)); m_ws.signal_message().connect(sigc::mem_fun(*this, &DiscordVoiceClient::OnWebsocketMessage)); + m_ws.signal_binary_message().connect(sigc::mem_fun(*this, &DiscordVoiceClient::OnWebsocketBinaryMessage)); m_dispatcher.connect(sigc::mem_fun(*this, &DiscordVoiceClient::OnDispatch)); + m_binary_dispatcher.connect(sigc::mem_fun(*this, &DiscordVoiceClient::OnBinaryDispatch)); // idle or else singleton deadlock Glib::signal_idle().connect_once([this]() { auto &audio = Abaddon::Get().GetAudio(); audio.SetOpusBuffer(m_opus_buffer.data()); audio.signal_opus_packet().connect([this](int payload_size) { - if (IsConnected()) { + if (!IsConnected()) return; + + if (m_dave && m_dave->IsEnabled()) { + auto *enc = m_dave->GetEncryptor(); + if (!enc) return; + auto max_size = enc->GetMaxCiphertextByteSize(discord::dave::MediaType::Audio, payload_size); + std::vector dave_encrypted(max_size); + size_t bytes_written = 0; + auto result = enc->Encrypt( + discord::dave::MediaType::Audio, + m_ssrc, + discord::dave::MakeArrayView(const_cast(m_opus_buffer.data()), static_cast(payload_size)), + discord::dave::MakeArrayView(dave_encrypted.data(), dave_encrypted.size()), + &bytes_written); + if (result == discord::dave::IEncryptor::Success) { + m_udp.SendEncrypted(dave_encrypted.data(), bytes_written); + } else { + m_log->warn("DAVE encrypt failed: result={}", static_cast(result)); + } + } else { m_udp.SendEncrypted(m_opus_buffer.data(), payload_size); } }); @@ -186,9 +210,12 @@ void DiscordVoiceClient::Start() { SetState(State::ConnectingToWebsocket); m_ssrc_map.clear(); + m_ssrc_user_map.clear(); + m_connected_users.clear(); + m_dave.reset(); m_heartbeat_waiter.revive(); m_keepalive_waiter.revive(); - m_ws.StartConnection("wss://" + m_endpoint + "/?v=7"); + m_ws.StartConnection("wss://" + m_endpoint + "/?v=9"); m_signal_connected.emit(); } @@ -209,6 +236,9 @@ void DiscordVoiceClient::Stop() { if (m_keepalive_thread.joinable()) m_keepalive_thread.join(); m_ssrc_map.clear(); + m_ssrc_user_map.clear(); + m_connected_users.clear(); + m_dave.reset(); m_signal_disconnected.emit(); } @@ -229,6 +259,10 @@ void DiscordVoiceClient::SetServerID(Snowflake id) { m_server_id = id; } +void DiscordVoiceClient::SetChannelID(Snowflake id) { + m_channel_id = id; +} + void DiscordVoiceClient::SetUserID(Snowflake id) { m_user_id = id; } @@ -264,7 +298,10 @@ bool DiscordVoiceClient::IsConnecting() const noexcept { void DiscordVoiceClient::OnGatewayMessage(const std::string &str) { m_log->trace("IN: {}", str); - VoiceGatewayMessage msg = nlohmann::json::parse(str); + auto j = nlohmann::json::parse(str); + if (j.contains("seq") && !j["seq"].is_null()) + m_last_received_seq = j["seq"].get(); + VoiceGatewayMessage msg = j; switch (msg.Opcode) { case VoiceGatewayOp::Hello: HandleGatewayHello(msg); @@ -280,6 +317,24 @@ void DiscordVoiceClient::OnGatewayMessage(const std::string &str) { break; case VoiceGatewayOp::HeartbeatAck: break; // stfu + case VoiceGatewayOp::SecureFramesPrepareProtocolTransition: + HandleGatewayDavePrepareTransition(msg); + break; + case VoiceGatewayOp::SecureFramesExecuteTransition: + HandleGatewayDaveExecuteTransition(msg); + break; + case VoiceGatewayOp::SecureFramesPrepareEpoch: + HandleGatewayDavePrepareEpoch(msg); + break; + case VoiceGatewayOp::ClientConnect: + HandleGatewayClientConnect(msg); + break; + case VoiceGatewayOp::ClientDisconnect: + HandleGatewayClientDisconnect(msg); + break; + case VoiceGatewayOp::SessionUpdate: + HandleGatewaySessionUpdate(msg); + break; default: const auto opcode_int = static_cast(msg.Opcode); m_log->warn("Unhandled opcode: {}", opcode_int); @@ -339,15 +394,20 @@ void DiscordVoiceClient::HandleGatewaySessionDescription(const VoiceGatewayMessa const auto key_hex = spdlog::to_hex(d.SecretKey.begin(), d.SecretKey.end()); m_log->debug("Received session description (mode: {}) (key: {:ns}) ", d.Mode, key_hex); + m_secret_key = d.SecretKey; + m_udp.SetSSRC(m_ssrc); + m_udp.SetSecretKey(m_secret_key); + + m_log->info("dave_protocol_version={}", d.DaveProtocolVersion); + if (d.DaveProtocolVersion > 0) + EnsureDaveSession(static_cast(d.DaveProtocolVersion)); + VoiceSpeakingMessage msg; msg.Delay = 0; msg.SSRC = m_ssrc; msg.Speaking = VoiceSpeakingType::Microphone; m_ws.Send(msg); - m_secret_key = d.SecretKey; - m_udp.SetSSRC(m_ssrc); - m_udp.SetSecretKey(m_secret_key); m_udp.SendEncrypted({ 0xF8, 0xFF, 0xFE }); m_udp.Run(); @@ -365,6 +425,18 @@ void DiscordVoiceClient::HandleGatewaySpeaking(const VoiceGatewayMessage &m) { } m_ssrc_map[d.UserID] = d.SSRC; + + // track for DAVE + if (d.SSRC != 0) { + m_ssrc_user_map[d.SSRC] = d.UserID; + std::string uid_str = std::to_string(static_cast(d.UserID)); + m_connected_users.insert(uid_str); + if (m_dave) { + m_dave->AddConnectedUser(uid_str); + m_dave->ApplyKeyRatchetForSSRC(d.SSRC, d.UserID); + } + } + m_signal_speaking.emit(d); } @@ -452,9 +524,11 @@ void DiscordVoiceClient::HeartbeatThread() { m_log->trace("Heartbeat: {}", ms); - VoiceHeartbeatMessage msg; - msg.Nonce = ms; - m_ws.Send(msg); + nlohmann::json hb; + hb["op"] = VoiceGatewayOp::Heartbeat; + hb["d"]["t"] = ms; + hb["d"]["seq_ack"] = m_last_received_seq.load(); + m_ws.Send(hb); } } @@ -490,10 +564,21 @@ size_t GetPayloadOffset(const uint8_t *buf, size_t num_bytes) { } void DiscordVoiceClient::OnUDPData(std::vector data) { + if (data.size() < 44) return; + + // RTP version must be 2 + if (((data[0] >> 6) & 0x03) != 2) return; + + // only opus (payload type 120) + if ((data[1] & 0x7F) != 120) return; + uint32_t ssrc = (data[8] << 24) | (data[9] << 16) | (data[10] << 8) | (data[11] << 0); + + // ignore our own packets + if (ssrc == m_ssrc) return; std::array nonce = {}; std::memcpy(nonce.data(), data.data() + data.size() - sizeof(uint32_t), sizeof(uint32_t)); @@ -502,11 +587,51 @@ void DiscordVoiceClient::OnUDPData(std::vector data) { unsigned long long mlen = 0; if (crypto_aead_xchacha20poly1305_ietf_decrypt(data.data() + 12 + ext_size, &mlen, nullptr, data.data() + 12 + ext_size, data.size() - 12 - ext_size - sizeof(uint32_t), data.data(), 12 + ext_size, nonce.data(), m_secret_key.data())) { - // spdlog::get("voice")->trace("UDP payload decryption failure"); - } else { - const auto opus_offset = GetPayloadOffset(data.data(), data.size()); - Abaddon::Get().GetAudio().FeedMeOpus(ssrc, { data.data() + opus_offset, data.data() + 12 + ext_size + mlen }); + return; + } + + const auto opus_offset = GetPayloadOffset(data.data(), data.size()); + const uint8_t *payload_start = data.data() + opus_offset; + size_t payload_size = static_cast(12 + ext_size + mlen) - opus_offset; + + static const uint8_t OPUS_SILENCE[] = { 0xF8, 0xFF, 0xFE }; + + if (m_dave) { + // silence packets bypass DAVE per spec + if (payload_size == sizeof(OPUS_SILENCE) && + std::memcmp(payload_start, OPUS_SILENCE, sizeof(OPUS_SILENCE)) == 0) { + Abaddon::Get().GetAudio().FeedMeOpus(ssrc, { payload_start, payload_start + payload_size }); + return; + } + + if (m_dave->IsEnabled()) { + Snowflake uid; + if (auto it = m_ssrc_user_map.find(ssrc); it != m_ssrc_user_map.end()) + uid = it->second; + + auto *dec = m_dave->GetOrCreateDecryptor(ssrc, uid); + auto max_size = dec->GetMaxPlaintextByteSize(discord::dave::MediaType::Audio, payload_size); + std::vector plaintext(max_size); + size_t bytes_written = 0; + auto result = dec->Decrypt( + discord::dave::MediaType::Audio, + discord::dave::MakeArrayView(payload_start, payload_size), + discord::dave::MakeArrayView(plaintext.data(), plaintext.size()), + &bytes_written); + + if (result == discord::dave::IDecryptor::Success) { + Abaddon::Get().GetAudio().FeedMeOpus(ssrc, { plaintext.data(), plaintext.data() + bytes_written }); + } + return; + } else if (m_dave->IsDowngraded()) { + // passthrough + } else { + // DAVE session exists but not yet enabled, drop + return; + } } + + Abaddon::Get().GetAudio().FeedMeOpus(ssrc, { payload_start, payload_start + payload_size }); } void DiscordVoiceClient::OnDispatch() { @@ -521,6 +646,169 @@ void DiscordVoiceClient::OnDispatch() { OnGatewayMessage(msg); } +void DiscordVoiceClient::OnWebsocketBinaryMessage(const std::string &data) { + m_binary_dispatch_mutex.lock(); + m_binary_dispatch_queue.push(data); + m_binary_dispatcher.emit(); + m_binary_dispatch_mutex.unlock(); +} + +void DiscordVoiceClient::OnBinaryDispatch() { + m_binary_dispatch_mutex.lock(); + if (m_binary_dispatch_queue.empty()) { + m_binary_dispatch_mutex.unlock(); + return; + } + auto msg = std::move(m_binary_dispatch_queue.front()); + m_binary_dispatch_queue.pop(); + m_binary_dispatch_mutex.unlock(); + + // voice gateway v9 binary format: [2-byte seq BE][1-byte opcode][payload...] + if (msg.size() < 3) return; + + const auto *raw = reinterpret_cast(msg.data()); + int opcode = raw[2]; + const uint8_t *payload = raw + 3; + size_t payload_size = msg.size() - 3; + + if (!m_dave) return; + + switch (opcode) { + case static_cast(VoiceGatewayOp::MlsExternalSenderPackage): + m_dave->OnExternalSenderPackage(payload, payload_size); + break; + case static_cast(VoiceGatewayOp::MlsProposals): + m_dave->OnProposals(payload, payload_size); + break; + case static_cast(VoiceGatewayOp::MlsPrepareCommitTransition): + m_dave->OnAnnounceCommitTransition(payload, payload_size); + break; + case static_cast(VoiceGatewayOp::MlsWelcome): + m_dave->OnWelcome(payload, payload_size); + break; + default: + m_log->debug("Unhandled DAVE binary opcode: {}", opcode); + break; + } +} + +void DiscordVoiceClient::SendBinaryPayload(int opcode, const std::vector &data) { + std::string frame(1 + data.size(), '\0'); + frame[0] = static_cast(opcode); + std::memcpy(frame.data() + 1, data.data(), data.size()); + m_ws.SendBinary(frame); +} + +void DiscordVoiceClient::SendDaveReadyForTransition(int transitionId) { + nlohmann::json j; + j["op"] = VoiceGatewayOp::SecureFramesReadyForTransition; + j["d"]["transition_id"] = transitionId; + m_ws.Send(j); +} + +void DiscordVoiceClient::SendDaveInvalidCommitWelcome(int transitionId) { + nlohmann::json j; + j["op"] = VoiceGatewayOp::MlsInvalidCommitWelcome; + j["d"]["transition_id"] = transitionId; + m_ws.Send(j); +} + +void DiscordVoiceClient::HandleGatewayDavePrepareTransition(const VoiceGatewayMessage &m) { + if (!m_dave) return; + int version = m.Data.value("protocol_version", 0); + int transitionId = m.Data.value("transition_id", 0); + m_dave->OnPrepareTransition(version, transitionId); +} + +void DiscordVoiceClient::HandleGatewayDaveExecuteTransition(const VoiceGatewayMessage &m) { + if (!m_dave) return; + int transitionId = m.Data.value("transition_id", 0); + m_dave->OnExecuteTransition(transitionId); +} + +void DiscordVoiceClient::HandleGatewayDavePrepareEpoch(const VoiceGatewayMessage &m) { + int version = m.Data.value("protocol_version", 0); + int epoch = m.Data.value("epoch", 0); + if (!m_dave && version > 0) + EnsureDaveSession(static_cast(version)); + if (m_dave) + m_dave->OnPrepareEpoch(version, epoch); +} + +void DiscordVoiceClient::HandleGatewayClientConnect(const VoiceGatewayMessage &m) { + if (!m.Data.contains("user_ids")) return; + for (const auto &val : m.Data["user_ids"]) { + std::string uid_str = val.get(); + m_connected_users.insert(uid_str); + if (m_dave) + m_dave->AddConnectedUser(uid_str); + } +} + +void DiscordVoiceClient::HandleGatewaySessionUpdate(const VoiceGatewayMessage &m) { + if (!m.Data.contains("user_id")) return; + Snowflake uid = m.Data["user_id"].get(); + if (!uid.IsValid()) return; + + std::string uid_str = std::to_string(static_cast(uid)); + m_connected_users.insert(uid_str); + + uint32_t audio_ssrc = m.Data.value("audio_ssrc", 0u); + if (audio_ssrc != 0) { + m_ssrc_map[uid] = audio_ssrc; + m_ssrc_user_map[audio_ssrc] = uid; + } + + if (m_dave) + m_dave->AddConnectedUser(uid_str); +} + +void DiscordVoiceClient::HandleGatewayClientDisconnect(const VoiceGatewayMessage &m) { + if (!m.Data.contains("user_id")) return; + Snowflake uid = m.Data["user_id"].get(); + if (!uid.IsValid()) return; + + std::string uid_str = std::to_string(static_cast(uid)); + m_connected_users.erase(uid_str); + if (m_dave) + m_dave->RemoveConnectedUser(uid_str); + + // clean up ssrc mappings + for (auto it = m_ssrc_user_map.begin(); it != m_ssrc_user_map.end(); ++it) { + if (it->second == uid) { + m_ssrc_user_map.erase(it); + break; + } + } +} + +void DiscordVoiceClient::EnsureDaveSession(uint16_t protocolVersion) { + if (m_dave) return; + + m_log->info("Creating DAVE session, protocol version={}", protocolVersion); + + m_dave = std::make_unique(m_channel_id, m_user_id, m_ssrc_user_map); + m_dave->SetLocalSSRC(m_ssrc); + + for (const auto &uid : m_connected_users) + m_dave->AddConnectedUser(uid); + + m_dave->signal_send_binary().connect( + sigc::mem_fun(*this, &DiscordVoiceClient::SendBinaryPayload)); + m_dave->signal_send_ready_for_transition().connect( + sigc::mem_fun(*this, &DiscordVoiceClient::SendDaveReadyForTransition)); + m_dave->signal_send_invalid_commit_welcome().connect( + sigc::mem_fun(*this, &DiscordVoiceClient::SendDaveInvalidCommitWelcome)); + m_dave->signal_dave_state_changed().connect([this](bool enabled) { + if (enabled) + m_log->info("DAVE E2EE active"); + else + m_log->info("DAVE E2EE disabled"); + }); + + m_dave->Init(protocolVersion); +} + DiscordVoiceClient::type_signal_disconnected DiscordVoiceClient::signal_connected() { return m_signal_connected; } @@ -551,6 +839,14 @@ void to_json(nlohmann::json &j, const VoiceHeartbeatMessage &m) { j["d"] = m.Nonce; } +static nlohmann::json MakeHeartbeatV9(uint64_t nonce, int seq_ack) { + nlohmann::json j; + j["op"] = VoiceGatewayOp::Heartbeat; + j["d"]["t"] = nonce; + j["d"]["seq_ack"] = seq_ack; + return j; +} + void to_json(nlohmann::json &j, const VoiceIdentifyMessage &m) { j["op"] = VoiceGatewayOp::Identify; j["d"]["server_id"] = m.ServerID; @@ -561,6 +857,7 @@ void to_json(nlohmann::json &j, const VoiceIdentifyMessage &m) { j["d"]["streams"][0]["type"] = "video"; j["d"]["streams"][0]["rid"] = "100"; j["d"]["streams"][0]["quality"] = 100; + j["d"]["max_dave_protocol_version"] = 1; } void from_json(const nlohmann::json &j, VoiceReadyData::VoiceStream &m) { @@ -595,6 +892,7 @@ void to_json(nlohmann::json &j, const VoiceSelectProtocolMessage &m) { void from_json(const nlohmann::json &j, VoiceSessionDescriptionData &m) { JS_D("mode", m.Mode); JS_D("secret_key", m.SecretKey); + JS_ON("dave_protocol_version", m.DaveProtocolVersion); } void to_json(nlohmann::json &j, const VoiceSpeakingMessage &m) { diff --git a/src/discord/voiceclient.hpp b/src/discord/voiceclient.hpp index aa1014cf..ac1ea7f3 100644 --- a/src/discord/voiceclient.hpp +++ b/src/discord/voiceclient.hpp @@ -5,9 +5,11 @@ #include "snowflake.hpp" #include "waiter.hpp" #include "websocket.hpp" +#include "dave.hpp" #include #include #include +#include #include #include #include @@ -42,6 +44,7 @@ enum class VoiceGatewayOp : int { Resume = 7, Hello = 8, Resumed = 9, + ClientConnect = 11, ClientDisconnect = 13, SessionUpdate = 14, MediaSinkWants = 15, @@ -60,6 +63,7 @@ enum class VoiceGatewayOp : int { MlsCommitWelcome = 28, MlsPrepareCommitTransition = 29, MlsWelcome = 30, + MlsInvalidCommitWelcome = 31, }; struct VoiceGatewayMessage { @@ -129,6 +133,7 @@ struct VoiceSessionDescriptionData { // std::string MediaSessionID; std::string Mode; std::array SecretKey; + int DaveProtocolVersion = 0; friend void from_json(const nlohmann::json &j, VoiceSessionDescriptionData &m); }; @@ -209,6 +214,7 @@ class DiscordVoiceClient { void SetEndpoint(std::string_view endpoint); void SetToken(std::string_view token); void SetServerID(Snowflake id); + void SetChannelID(Snowflake id); void SetUserID(Snowflake id); // todo serialize @@ -236,6 +242,12 @@ class DiscordVoiceClient { void HandleGatewayReady(const VoiceGatewayMessage &m); void HandleGatewaySessionDescription(const VoiceGatewayMessage &m); void HandleGatewaySpeaking(const VoiceGatewayMessage &m); + void HandleGatewayDavePrepareTransition(const VoiceGatewayMessage &m); + void HandleGatewayDaveExecuteTransition(const VoiceGatewayMessage &m); + void HandleGatewayDavePrepareEpoch(const VoiceGatewayMessage &m); + void HandleGatewayClientConnect(const VoiceGatewayMessage &m); + void HandleGatewayClientDisconnect(const VoiceGatewayMessage &m); + void HandleGatewaySessionUpdate(const VoiceGatewayMessage &m); void Identify(); void Discovery(); @@ -244,6 +256,13 @@ class DiscordVoiceClient { void OnWebsocketOpen(); void OnWebsocketClose(const ix::WebSocketCloseInfo &info); void OnWebsocketMessage(const std::string &str); + void OnWebsocketBinaryMessage(const std::string &data); + + void OnBinaryDispatch(); + void SendBinaryPayload(int opcode, const std::vector &data); + void SendDaveReadyForTransition(int transitionId); + void SendDaveInvalidCommitWelcome(int transitionId); + void EnsureDaveSession(uint16_t protocolVersion); void HeartbeatThread(); void KeepaliveThread(); @@ -269,6 +288,7 @@ class DiscordVoiceClient { uint32_t m_ssrc; int m_heartbeat_msec; + std::atomic m_last_received_seq { -1 }; Waiter m_heartbeat_waiter; std::thread m_heartbeat_thread; @@ -282,8 +302,16 @@ class DiscordVoiceClient { std::queue m_dispatch_queue; std::mutex m_dispatch_mutex; + Glib::Dispatcher m_binary_dispatcher; + std::queue m_binary_dispatch_queue; + std::mutex m_binary_dispatch_mutex; + void OnDispatch(); + std::unique_ptr m_dave; + std::unordered_map m_ssrc_user_map; + std::set m_connected_users; + std::array m_opus_buffer; std::shared_ptr m_log; diff --git a/src/discord/websocket.cpp b/src/discord/websocket.cpp index 565c94cf..6acf0171 100644 --- a/src/discord/websocket.cpp +++ b/src/discord/websocket.cpp @@ -5,7 +5,6 @@ #include #include - Websocket::Websocket(const std::string &id) : m_close_info { 1000, "Normal", false } { if (m_log = spdlog::get(id); !m_log) { @@ -46,6 +45,10 @@ void Websocket::SetPrintMessages(bool show) noexcept { m_print_messages = show; } +void Websocket::SetSeparateBinaryMessages(bool separate) noexcept { + m_separate_binary = separate; +} + void Websocket::Stop() { m_log->debug("Stopping with default close code"); Stop(ix::WebSocketCloseConstants::kNormalClosureCode); @@ -71,6 +74,10 @@ void Websocket::Send(const nlohmann::json &j) { Send(j.dump()); } +void Websocket::SendBinary(const std::string &data) { + m_websocket->sendBinary(data); +} + void Websocket::OnMessage(const ix::WebSocketMessagePtr &msg) { switch (msg->type) { case ix::WebSocketMessageType::Open: { @@ -84,7 +91,10 @@ void Websocket::OnMessage(const ix::WebSocketMessagePtr &msg) { m_close_dispatcher.emit(); } break; case ix::WebSocketMessageType::Message: { - m_signal_message.emit(msg->str); + if (m_separate_binary && msg->binary) + m_signal_binary_message.emit(msg->str); + else + m_signal_message.emit(msg->str); } break; case ix::WebSocketMessageType::Error: { m_log->error("Websocket error: Status: {} Reason: {}", msg->errorInfo.http_status, msg->errorInfo.reason); @@ -105,3 +115,7 @@ Websocket::type_signal_close Websocket::signal_close() { Websocket::type_signal_message Websocket::signal_message() { return m_signal_message; } + +Websocket::type_signal_binary_message Websocket::signal_binary_message() { + return m_signal_binary_message; +} diff --git a/src/discord/websocket.hpp b/src/discord/websocket.hpp index a77bf553..77b6cfe1 100644 --- a/src/discord/websocket.hpp +++ b/src/discord/websocket.hpp @@ -17,9 +17,11 @@ class Websocket { bool GetPrintMessages() const noexcept; void SetPrintMessages(bool show) noexcept; + void SetSeparateBinaryMessages(bool separate) noexcept; void Send(const std::string &str); void Send(const nlohmann::json &j); + void SendBinary(const std::string &data); void Stop(); void Stop(uint16_t code); @@ -33,17 +35,21 @@ class Websocket { using type_signal_open = sigc::signal; using type_signal_close = sigc::signal; using type_signal_message = sigc::signal; + using type_signal_binary_message = sigc::signal; type_signal_open signal_open(); type_signal_close signal_close(); type_signal_message signal_message(); + type_signal_binary_message signal_binary_message(); private: type_signal_open m_signal_open; type_signal_close m_signal_close; type_signal_message m_signal_message; + type_signal_binary_message m_signal_binary_message; bool m_print_messages = true; + bool m_separate_binary = false; Glib::Dispatcher m_open_dispatcher; Glib::Dispatcher m_close_dispatcher; diff --git a/subprojects/libdave b/subprojects/libdave new file mode 160000 index 00000000..52cd56dc --- /dev/null +++ b/subprojects/libdave @@ -0,0 +1 @@ +Subproject commit 52cd56dc550f447fb354b3a06c9e2d2e2a4309c6 diff --git a/subprojects/mlspp b/subprojects/mlspp new file mode 160000 index 00000000..1cc50a12 --- /dev/null +++ b/subprojects/mlspp @@ -0,0 +1 @@ +Subproject commit 1cc50a124a3bc4e143a787ec934280dc70c1034d