Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 110 additions & 40 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
use process_api::get_process_info;
use session::{ApplicationSession, EndPointSession, Session};
use windows::{
core::{Interface, PWSTR},
core::Interface,
Win32::{
Media::Audio::{
eCapture, eMultimedia, eRender, Endpoints::IAudioEndpointVolume, IAudioSessionControl, IAudioSessionControl2, IAudioSessionEnumerator, IAudioSessionManager2, IMMDevice, IMMDeviceCollection, IMMDeviceEnumerator, ISimpleAudioVolume, MMDeviceEnumerator, DEVICE_STATE_ACTIVE
eCapture, eMultimedia, eRender, Endpoints::IAudioEndpointVolume, IAudioSessionControl, IAudioSessionControl2, IAudioSessionEnumerator, IAudioSessionManager2, IMMDevice, IMMDeviceCollection, IMMDeviceEnumerator, ISimpleAudioVolume, MMDeviceEnumerator, PKEY_AudioEndpoint_FormFactor, DEVICE_STATE_ACTIVE
},
System::{
Com::{CoCreateInstance, CoInitializeEx, CoUninitialize, CLSCTX_ALL, CLSCTX_INPROC_SERVER, COINIT_APARTMENTTHREADED, COINIT_MULTITHREADED, STGM_READ, CoTaskMemFree},
ProcessStatus::K32GetProcessImageFileNameA,
Threading::{OpenProcess, PROCESS_QUERY_INFORMATION, PROCESS_VM_READ},
Com::{CoCreateInstance, CoInitializeEx, CLSCTX_ALL, CLSCTX_INPROC_SERVER, COINIT_APARTMENTTHREADED, COINIT_MULTITHREADED, STGM_READ, CoTaskMemFree},
},
Devices::FunctionDiscovery::PKEY_Device_FriendlyName,
UI::Shell::PropertiesSystem::{PROPERTYKEY, PropVariantToStringAlloc},
UI::Shell::PropertiesSystem::{IPropertyStore, PropVariantToStringAlloc},
},
};
use std::process::exit;
Expand All @@ -22,6 +20,8 @@ mod process_api;

mod session;

const FORM_FACTOR_SPDIF: i32 = 8;

// Helper function to get device friendly name using PropVariantToStringAlloc
fn get_device_friendly_name(device: &IMMDevice, fallback_name: &str) -> String {
unsafe {
Expand All @@ -43,7 +43,7 @@ fn get_device_friendly_name(device: &IMMDevice, fallback_name: &str) -> String {
CoTaskMemFree(Some(buffer.as_ptr().cast()));
}
}
Err(e) => {
Err(_err) => {
// Log error if needed: eprintln!("PropVariantToStringAlloc failed: {:?}", e);
// Keep the fallback name if conversion fails
}
Expand All @@ -56,6 +56,24 @@ fn get_device_friendly_name(device: &IMMDevice, fallback_name: &str) -> String {
}
}

fn get_device_form_factor(property_store: &IPropertyStore) -> Option<i32> {
unsafe {
if let Ok(prop_variant) = property_store.GetValue(&PKEY_AudioEndpoint_FormFactor) {
if let Ok(buffer) = PropVariantToStringAlloc(&prop_variant) {
if !buffer.is_null() {
let form_factor = buffer
.to_string()
.ok()
.and_then(|val| val.parse::<i32>().ok());
CoTaskMemFree(Some(buffer.as_ptr().cast()));
return form_factor;
}
}
}
None
}
}

pub struct AudioController {
default_device: Option<IMMDevice>,
default_input_device: Option<IMMDevice>,
Expand All @@ -65,6 +83,7 @@ pub struct AudioController {
default_input_id: Option<String>,
}

#[derive(Debug)]
pub enum CoinitMode {
MultiTreaded,
ApartmentThreaded
Expand All @@ -79,6 +98,7 @@ impl AudioController {
CoinitMode::MultiTreaded => {coinit = COINIT_MULTITHREADED}
}
}

CoInitializeEx(None, coinit).unwrap_or_else(|err| {
eprintln!("ERROR: Couldn't initialize windows connection: {err}");
error!("ERROR: Couldn't initialize windows connection: {}", err);
Expand Down Expand Up @@ -109,14 +129,9 @@ impl AudioController {


pub unsafe fn GetAllProcessSessions(&mut self) {
// Initialize COM library
// if let Err(err) = CoInitializeEx(Some(std::ptr::null_mut()), COINIT_MULTITHREADED) {
// eprintln!("ERROR: Failed to initialize COM library... {err}");
// return;
// }

// Get the device enumerator
let device_enumerator: IMMDeviceEnumerator = CoCreateInstance(&MMDeviceEnumerator, None, CLSCTX_INPROC_SERVER).unwrap_or_else(|err| {
let device_enumerator_result = CoCreateInstance::<_, IMMDeviceEnumerator>(&MMDeviceEnumerator, None, CLSCTX_INPROC_SERVER);
let device_enumerator: IMMDeviceEnumerator = device_enumerator_result.unwrap_or_else(|err| {
eprintln!("ERROR: Couldn't create device enumerator... {err}");
error!("ERROR: Couldn't create device enumerator... {}", err);
exit(1);
Expand All @@ -136,36 +151,93 @@ impl AudioController {
});

for device_index in 0..device_count {
let device: IMMDevice = device_collection.Item(device_index).unwrap_or_else(|err| {
eprintln!("ERROR: Couldn't get device at index {device_index}... {err}");
error!("ERROR: Couldn't get device at index {}... {}", device_index, err);
exit(1);
});
let device: IMMDevice = match device_collection.Item(device_index) {
Ok(dev) => dev,
Err(err) => {
eprintln!("WARNING: Skipping device {} - couldn't get device: {}", device_index, err);
error!("WARNING: Skipping device {} - couldn't get device: {}", device_index, err);
continue;
}
};

// Get device ID for error reporting
let device_id = match device.GetId() {
Ok(id) => match id.to_string() {
Ok(id_str) => id_str,
Err(_) => {
eprintln!("WARNING: Skipping device {} - couldn't convert device ID to string", device_index);
error!("WARNING: Skipping device {} - couldn't convert device ID to string", device_index);
continue;
}
},
Err(err) => {
eprintln!("WARNING: Skipping device {} - couldn't get device ID: {}", device_index, err);
error!("WARNING: Skipping device {} - couldn't get device ID: {}", device_index, err);
continue;
}
};

// Attempt to open property store - if this fails, do not proceed
let property_store = match device.OpenPropertyStore(STGM_READ) {
Ok(store) => store,
Err(err) => {
let device_name = get_device_friendly_name(&device, &format!("Device {}", device_index));
eprintln!("ERROR: Skipping device '{}' (ID: {}) - couldn't open property store: {}", device_name, device_id, err);
error!("ERROR: Skipping device '{}' (ID: {}) - couldn't open property store: {}", device_name, device_id, err);
continue;
}
};

let device_name = get_device_friendly_name(&device, &format!("Device {}", device_index));

// Skip SPDIF/digital outputs if detected by form factor
if let Some(form_factor) = get_device_form_factor(&property_store) {
if form_factor == FORM_FACTOR_SPDIF {
continue;
}
}

let session_manager2: IAudioSessionManager2 = device.Activate(CLSCTX_INPROC_SERVER, None).unwrap_or_else(|err| {
eprintln!("ERROR: Couldn't get AudioSessionManager for enumerating over processes... {err}");
error!("ERROR: Couldn't get AudioSessionManager for enumerating over processes... {}", err);
exit(1);
});
let session_manager2: IAudioSessionManager2 = match device.Activate(CLSCTX_INPROC_SERVER, None) {
Ok(mgr) => mgr,
Err(err) => {
eprintln!("WARNING: Skipping device '{}' (ID: {}) - couldn't activate AudioSessionManager: {}", device_name, device_id, err);
error!("WARNING: Skipping device '{}' (ID: {}) - couldn't activate AudioSessionManager: {}", device_name, device_id, err);
continue;
}
};

let session_enumerator: IAudioSessionEnumerator = session_manager2.GetSessionEnumerator().unwrap_or_else(|err| {
eprintln!("ERROR: Couldn't get session enumerator... {err}");
error!("ERROR: Couldn't get session enumerator... {}", err);
exit(1);
});
let session_enum_result = session_manager2.GetSessionEnumerator();
// If GetSessionEnumerator fails for this device, skip it
let session_enumerator: IAudioSessionEnumerator = match session_enum_result {
Ok(enumerator) => enumerator,
Err(err) => {
eprintln!("WARNING: Skipping device '{}' (ID: {}) - couldn't get session enumerator: {}", device_name, device_id, err);
error!("WARNING: Skipping device '{}' (ID: {}) - couldn't get session enumerator: {}", device_name, device_id, err);
continue;
}
};

for i in 0..session_enumerator.GetCount().unwrap() {
let session_count = match session_enumerator.GetCount() {
Ok(count) => count,
Err(err) => {
eprintln!("WARNING: Skipping device '{}' (ID: {}) - couldn't get session count: {}", device_name, device_id, err);
error!("WARNING: Skipping device '{}' (ID: {}) - couldn't get session count: {}", device_name, device_id, err);
continue;
}
};

for i in 0..session_count {
let normal_session_control: Option<IAudioSessionControl> = session_enumerator.GetSession(i).ok();
if normal_session_control.is_none() {
eprintln!("ERROR: Couldn't get session control of audio session...");
error!("ERROR: Couldn't get session control of audio session...");
eprintln!("ERROR: Couldn't get session control of audio session for device '{}' (ID: {})", device_name, device_id);
error!("ERROR: Couldn't get session control of audio session for device '{}' (ID: {})", device_name, device_id);
continue;
}

let session_control: Option<IAudioSessionControl2> = normal_session_control.unwrap().cast().ok();
if session_control.is_none() {
eprintln!("ERROR: Couldn't convert from normal session control to session control 2");
error!("ERROR: Couldn't convert from normal session control to session control 2");
eprintln!("ERROR: Couldn't convert from normal session control to session control 2 for device '{}' (ID: {})", device_name, device_id);
error!("ERROR: Couldn't convert from normal session control to session control 2 for device '{}' (ID: {})", device_name, device_id);
continue;
}

Expand All @@ -179,19 +251,17 @@ impl AudioController {
info.process_name.clone()
},
Err(_err) => {
eprintln!("ERROR: Couldn't get process info for pid {}", pid);
error!("ERROR: Couldn't get process info for pid {}", pid);
eprintln!("ERROR: Couldn't get process info for pid {} on device '{}' (ID: {})", pid, device_name, device_id);
error!("ERROR: Couldn't get process info for pid {} on device '{}' (ID: {})", pid, device_name, device_id);
continue;
}
};



let audio_control: ISimpleAudioVolume = match session_control.unwrap().cast() {
Ok(data) => data,
Err(err) => {
eprintln!("ERROR: Couldn't get the simpleaudiovolume from session controller: {err}");
error!("ERROR: Couldn't get the simpleaudiovolume from session controller: {}", err);
eprintln!("ERROR: Couldn't get the simpleaudiovolume from session controller for device '{}' (ID: {}): {}", device_name, device_id, err);
error!("ERROR: Couldn't get the simpleaudiovolume from session controller for device '{}' (ID: {}): {}", device_name, device_id, err);
continue;
}
};
Expand Down