From 7a466f62a7cfa66dc9146209b7d9bdf10c7c1b52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20Fr=C3=B8yland?= <81354124+Andreas-Froyland@users.noreply.github.com> Date: Sun, 18 Jan 2026 21:48:54 +0100 Subject: [PATCH] skip digital only devices as they are not needed in the context of volume control --- src/lib.rs | 150 +++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 110 insertions(+), 40 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index bab0a9e..cbb4fee 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,18 +1,16 @@ use process_api::get_process_info; use session::{ApplicationSession, EndPointSession, Session}; use windows::{ - core::{Interface, PWSTR}, + core::Interface, Win32::{ Media::Audio::{ - eCapture, eMultimedia, eRender, Endpoints::IAudioEndpointVolume, IAudioSessionControl, IAudioSessionControl2, IAudioSessionEnumerator, IAudioSessionManager2, IMMDevice, IMMDeviceCollection, IMMDeviceEnumerator, ISimpleAudioVolume, MMDeviceEnumerator, DEVICE_STATE_ACTIVE + eCapture, eMultimedia, eRender, Endpoints::IAudioEndpointVolume, IAudioSessionControl, IAudioSessionControl2, IAudioSessionEnumerator, IAudioSessionManager2, IMMDevice, IMMDeviceCollection, IMMDeviceEnumerator, ISimpleAudioVolume, MMDeviceEnumerator, PKEY_AudioEndpoint_FormFactor, DEVICE_STATE_ACTIVE }, System::{ - Com::{CoCreateInstance, CoInitializeEx, CoUninitialize, CLSCTX_ALL, CLSCTX_INPROC_SERVER, COINIT_APARTMENTTHREADED, COINIT_MULTITHREADED, STGM_READ, CoTaskMemFree}, - ProcessStatus::K32GetProcessImageFileNameA, - Threading::{OpenProcess, PROCESS_QUERY_INFORMATION, PROCESS_VM_READ}, + Com::{CoCreateInstance, CoInitializeEx, CLSCTX_ALL, CLSCTX_INPROC_SERVER, COINIT_APARTMENTTHREADED, COINIT_MULTITHREADED, STGM_READ, CoTaskMemFree}, }, Devices::FunctionDiscovery::PKEY_Device_FriendlyName, - UI::Shell::PropertiesSystem::{PROPERTYKEY, PropVariantToStringAlloc}, + UI::Shell::PropertiesSystem::{IPropertyStore, PropVariantToStringAlloc}, }, }; use std::process::exit; @@ -22,6 +20,8 @@ mod process_api; mod session; +const FORM_FACTOR_SPDIF: i32 = 8; + // Helper function to get device friendly name using PropVariantToStringAlloc fn get_device_friendly_name(device: &IMMDevice, fallback_name: &str) -> String { unsafe { @@ -43,7 +43,7 @@ fn get_device_friendly_name(device: &IMMDevice, fallback_name: &str) -> String { CoTaskMemFree(Some(buffer.as_ptr().cast())); } } - Err(e) => { + Err(_err) => { // Log error if needed: eprintln!("PropVariantToStringAlloc failed: {:?}", e); // Keep the fallback name if conversion fails } @@ -56,6 +56,24 @@ fn get_device_friendly_name(device: &IMMDevice, fallback_name: &str) -> String { } } +fn get_device_form_factor(property_store: &IPropertyStore) -> Option { + unsafe { + if let Ok(prop_variant) = property_store.GetValue(&PKEY_AudioEndpoint_FormFactor) { + if let Ok(buffer) = PropVariantToStringAlloc(&prop_variant) { + if !buffer.is_null() { + let form_factor = buffer + .to_string() + .ok() + .and_then(|val| val.parse::().ok()); + CoTaskMemFree(Some(buffer.as_ptr().cast())); + return form_factor; + } + } + } + None + } +} + pub struct AudioController { default_device: Option, default_input_device: Option, @@ -65,6 +83,7 @@ pub struct AudioController { default_input_id: Option, } +#[derive(Debug)] pub enum CoinitMode { MultiTreaded, ApartmentThreaded @@ -79,6 +98,7 @@ impl AudioController { CoinitMode::MultiTreaded => {coinit = COINIT_MULTITHREADED} } } + CoInitializeEx(None, coinit).unwrap_or_else(|err| { eprintln!("ERROR: Couldn't initialize windows connection: {err}"); error!("ERROR: Couldn't initialize windows connection: {}", err); @@ -109,14 +129,9 @@ impl AudioController { pub unsafe fn GetAllProcessSessions(&mut self) { - // Initialize COM library - // if let Err(err) = CoInitializeEx(Some(std::ptr::null_mut()), COINIT_MULTITHREADED) { - // eprintln!("ERROR: Failed to initialize COM library... {err}"); - // return; - // } - // Get the device enumerator - let device_enumerator: IMMDeviceEnumerator = CoCreateInstance(&MMDeviceEnumerator, None, CLSCTX_INPROC_SERVER).unwrap_or_else(|err| { + let device_enumerator_result = CoCreateInstance::<_, IMMDeviceEnumerator>(&MMDeviceEnumerator, None, CLSCTX_INPROC_SERVER); + let device_enumerator: IMMDeviceEnumerator = device_enumerator_result.unwrap_or_else(|err| { eprintln!("ERROR: Couldn't create device enumerator... {err}"); error!("ERROR: Couldn't create device enumerator... {}", err); exit(1); @@ -136,36 +151,93 @@ impl AudioController { }); for device_index in 0..device_count { - let device: IMMDevice = device_collection.Item(device_index).unwrap_or_else(|err| { - eprintln!("ERROR: Couldn't get device at index {device_index}... {err}"); - error!("ERROR: Couldn't get device at index {}... {}", device_index, err); - exit(1); - }); + let device: IMMDevice = match device_collection.Item(device_index) { + Ok(dev) => dev, + Err(err) => { + eprintln!("WARNING: Skipping device {} - couldn't get device: {}", device_index, err); + error!("WARNING: Skipping device {} - couldn't get device: {}", device_index, err); + continue; + } + }; + + // Get device ID for error reporting + let device_id = match device.GetId() { + Ok(id) => match id.to_string() { + Ok(id_str) => id_str, + Err(_) => { + eprintln!("WARNING: Skipping device {} - couldn't convert device ID to string", device_index); + error!("WARNING: Skipping device {} - couldn't convert device ID to string", device_index); + continue; + } + }, + Err(err) => { + eprintln!("WARNING: Skipping device {} - couldn't get device ID: {}", device_index, err); + error!("WARNING: Skipping device {} - couldn't get device ID: {}", device_index, err); + continue; + } + }; + + // Attempt to open property store - if this fails, do not proceed + let property_store = match device.OpenPropertyStore(STGM_READ) { + Ok(store) => store, + Err(err) => { + let device_name = get_device_friendly_name(&device, &format!("Device {}", device_index)); + eprintln!("ERROR: Skipping device '{}' (ID: {}) - couldn't open property store: {}", device_name, device_id, err); + error!("ERROR: Skipping device '{}' (ID: {}) - couldn't open property store: {}", device_name, device_id, err); + continue; + } + }; + + let device_name = get_device_friendly_name(&device, &format!("Device {}", device_index)); + + // Skip SPDIF/digital outputs if detected by form factor + if let Some(form_factor) = get_device_form_factor(&property_store) { + if form_factor == FORM_FACTOR_SPDIF { + continue; + } + } - let session_manager2: IAudioSessionManager2 = device.Activate(CLSCTX_INPROC_SERVER, None).unwrap_or_else(|err| { - eprintln!("ERROR: Couldn't get AudioSessionManager for enumerating over processes... {err}"); - error!("ERROR: Couldn't get AudioSessionManager for enumerating over processes... {}", err); - exit(1); - }); + let session_manager2: IAudioSessionManager2 = match device.Activate(CLSCTX_INPROC_SERVER, None) { + Ok(mgr) => mgr, + Err(err) => { + eprintln!("WARNING: Skipping device '{}' (ID: {}) - couldn't activate AudioSessionManager: {}", device_name, device_id, err); + error!("WARNING: Skipping device '{}' (ID: {}) - couldn't activate AudioSessionManager: {}", device_name, device_id, err); + continue; + } + }; - let session_enumerator: IAudioSessionEnumerator = session_manager2.GetSessionEnumerator().unwrap_or_else(|err| { - eprintln!("ERROR: Couldn't get session enumerator... {err}"); - error!("ERROR: Couldn't get session enumerator... {}", err); - exit(1); - }); + let session_enum_result = session_manager2.GetSessionEnumerator(); + // If GetSessionEnumerator fails for this device, skip it + let session_enumerator: IAudioSessionEnumerator = match session_enum_result { + Ok(enumerator) => enumerator, + Err(err) => { + eprintln!("WARNING: Skipping device '{}' (ID: {}) - couldn't get session enumerator: {}", device_name, device_id, err); + error!("WARNING: Skipping device '{}' (ID: {}) - couldn't get session enumerator: {}", device_name, device_id, err); + continue; + } + }; - for i in 0..session_enumerator.GetCount().unwrap() { + let session_count = match session_enumerator.GetCount() { + Ok(count) => count, + Err(err) => { + eprintln!("WARNING: Skipping device '{}' (ID: {}) - couldn't get session count: {}", device_name, device_id, err); + error!("WARNING: Skipping device '{}' (ID: {}) - couldn't get session count: {}", device_name, device_id, err); + continue; + } + }; + + for i in 0..session_count { let normal_session_control: Option = session_enumerator.GetSession(i).ok(); if normal_session_control.is_none() { - eprintln!("ERROR: Couldn't get session control of audio session..."); - error!("ERROR: Couldn't get session control of audio session..."); + eprintln!("ERROR: Couldn't get session control of audio session for device '{}' (ID: {})", device_name, device_id); + error!("ERROR: Couldn't get session control of audio session for device '{}' (ID: {})", device_name, device_id); continue; } let session_control: Option = normal_session_control.unwrap().cast().ok(); if session_control.is_none() { - eprintln!("ERROR: Couldn't convert from normal session control to session control 2"); - error!("ERROR: Couldn't convert from normal session control to session control 2"); + eprintln!("ERROR: Couldn't convert from normal session control to session control 2 for device '{}' (ID: {})", device_name, device_id); + error!("ERROR: Couldn't convert from normal session control to session control 2 for device '{}' (ID: {})", device_name, device_id); continue; } @@ -179,19 +251,17 @@ impl AudioController { info.process_name.clone() }, Err(_err) => { - eprintln!("ERROR: Couldn't get process info for pid {}", pid); - error!("ERROR: Couldn't get process info for pid {}", pid); + eprintln!("ERROR: Couldn't get process info for pid {} on device '{}' (ID: {})", pid, device_name, device_id); + error!("ERROR: Couldn't get process info for pid {} on device '{}' (ID: {})", pid, device_name, device_id); continue; } }; - - let audio_control: ISimpleAudioVolume = match session_control.unwrap().cast() { Ok(data) => data, Err(err) => { - eprintln!("ERROR: Couldn't get the simpleaudiovolume from session controller: {err}"); - error!("ERROR: Couldn't get the simpleaudiovolume from session controller: {}", err); + eprintln!("ERROR: Couldn't get the simpleaudiovolume from session controller for device '{}' (ID: {}): {}", device_name, device_id, err); + error!("ERROR: Couldn't get the simpleaudiovolume from session controller for device '{}' (ID: {}): {}", device_name, device_id, err); continue; } };