diff --git a/src/rust/src/webrtc/audio_device_module.rs b/src/rust/src/webrtc/audio_device_module.rs index 26266e59..389cc44c 100644 --- a/src/rust/src/webrtc/audio_device_module.rs +++ b/src/rust/src/webrtc/audio_device_module.rs @@ -6,14 +6,17 @@ use crate::webrtc; use crate::webrtc::audio_device_module_utils::{copy_and_truncate_string, DeviceCollectionWrapper}; use crate::webrtc::ffi::audio_device_module::RffiAudioTransport; -use anyhow::anyhow; +use anyhow::{anyhow, bail}; use cubeb::{Context, DeviceId, DeviceType, MonoFrame, Stream, StreamPrefs}; use cubeb_core::{log_enabled, set_logging, LogLevel}; use lazy_static::lazy_static; use regex::Regex; use std::collections::{HashMap, VecDeque}; -use std::ffi::{c_uchar, c_void, CStr, CString}; -use std::sync::{Arc, Mutex, RwLock}; +use std::ffi::{c_uchar, c_void, CStr}; +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, Mutex, RwLock, +}; use std::time::{Duration, Instant}; #[cfg(target_os = "windows")] use windows::Win32::System::Com; @@ -65,12 +68,22 @@ pub struct AudioDeviceModule { audio_transport: Arc>, cubeb_ctx: Option, initialized: bool, + // Note that the DeviceIds must not outlive the cubeb_ctx. playout_device: Option, recording_device: Option, + // Note that the streams must not outlive the cubeb_ctx. output_stream: Option>, input_stream: Option>, playing: bool, recording: bool, + // Note that the caches must not outlive the cubeb_ctx. + input_device_cache: Option, + output_device_cache: Option, + // Flags to track whether we need to refresh caches. + // As these are shared with threads in cubeb, we create them once at ADM init + // and never free them. + pending_input_device_refresh: &'static AtomicBool, + pending_output_device_refresh: &'static AtomicBool, } impl Default for AudioDeviceModule { @@ -87,6 +100,12 @@ impl Default for AudioDeviceModule { input_stream: None, playing: false, recording: false, + input_device_cache: None, + output_device_cache: None, + // Start these both as true to request a cache refresh, and leak them for reasons + // mentioned at the struct declaration site. + pending_input_device_refresh: Box::leak(Box::new(AtomicBool::new(true))), + pending_output_device_refresh: Box::leak(Box::new(AtomicBool::new(true))), } } } @@ -108,7 +127,7 @@ const ADM_MAX_DEVICE_NAME_SIZE: usize = 128; const ADM_MAX_GUID_SIZE: usize = 128; /// Arbitrary string to uniquely identify ringrtc for creating the cubeb object. -const ADM_CONTEXT: &str = "ringrtc"; +const ADM_CONTEXT: &CStr = c"ringrtc"; const SAMPLE_FREQUENCY: u32 = 48_000; // Target sample latency. The actual sample latency will @@ -213,6 +232,34 @@ impl AudioDeviceModule { 0 } + /// Safety: Must be called with a valid |flag| pointer. (NULL is okay.) + unsafe extern "C" fn device_changed(_ctx: *mut cubeb::ffi::cubeb, flag: *mut c_void) { + // Flag that an update is needed; this will be processed on the next enumerate_devices call. + if let Some(b) = (flag as *mut AtomicBool).as_ref() { + b.store(true, Ordering::SeqCst) + } + } + + fn register_device_collection_changed( + &mut self, + device_type: DeviceType, + flag: &'static AtomicBool, + ) -> anyhow::Result<()> { + let ctx = match &self.cubeb_ctx { + Some(ctx) => ctx, + None => bail!("Cannot register device changed callback without a ctx"), + }; + unsafe { + // Safety: |callback| will remain a valid pointer for the lifetime of the program, + // as will |flag| (since it's static) + Ok(ctx.register_device_collection_changed( + device_type, + Some(AudioDeviceModule::device_changed), + flag.as_ptr() as *mut c_void, + )?) + } + } + // Main initialization and termination pub fn init(&mut self) -> i32 { // Don't bother re-initializing. @@ -238,8 +285,7 @@ impl AudioDeviceModule { return -1; } } - let ctx_name = CString::new(ADM_CONTEXT).unwrap(); - match Context::init(Some(ctx_name.as_c_str()), None) { + match Context::init(Some(ADM_CONTEXT), None) { Ok(ctx) => { info!( "Successfully initialized cubeb backend {}", @@ -247,13 +293,33 @@ impl AudioDeviceModule { ); self.cubeb_ctx = Some(ctx); self.initialized = true; - 0 } Err(e) => { error!("Failed to initialize: {}", e); - -1 + return -1; } } + if let Err(e) = self.register_device_collection_changed( + DeviceType::INPUT, + self.pending_input_device_refresh, + ) { + error!("Failed to register input device callback: {}", e); + return -1; + } + if let Err(e) = self.register_device_collection_changed( + DeviceType::OUTPUT, + self.pending_output_device_refresh, + ) { + error!("Failed to register input device callback: {}", e); + return -1; + } + // Caches are not set up, so request a refresh. + self.pending_input_device_refresh + .store(true, Ordering::SeqCst); + self.pending_output_device_refresh + .store(true, Ordering::SeqCst); + self.initialized = true; + 0 } pub fn backend_name(&self) -> Option { @@ -269,9 +335,37 @@ impl AudioDeviceModule { if self.playing { self.stop_playout(); } - // Cause these to Drop + // Cause these to Drop. self.input_stream = None; self.output_stream = None; + // Ensure these are not reused. + self.playout_device = None; + self.recording_device = None; + self.input_device_cache = None; + self.output_device_cache = None; + if let Some(ctx) = &self.cubeb_ctx { + // Clear callbacks. + unsafe { + // Safety: We are calling this with None, which will unset the callback, + // so passing null is safe. + if let Err(e) = ctx.register_device_collection_changed( + DeviceType::INPUT, + None, + std::ptr::null_mut(), + ) { + warn!("failed to reset input callback: {}", e); + } + if let Err(e) = ctx.register_device_collection_changed( + DeviceType::OUTPUT, + None, + std::ptr::null_mut(), + ) { + warn!("failed to reset output callback: {}", e); + } + } + } + // Invalidate the ctx (note that any references to it, like `DeviceId`s, + // must have already been dropped). self.cubeb_ctx = None; self.initialized = false; #[cfg(target_os = "windows")] @@ -288,15 +382,52 @@ impl AudioDeviceModule { self.initialized } + fn refresh_device_cache(&mut self, device_type: DeviceType) -> anyhow::Result<()> { + let ctx = match &self.cubeb_ctx { + Some(ctx) => ctx, + None => bail!("cannot refresh device cache without a ctx"), + }; + let devices = ctx.enumerate_devices(device_type)?; + for device in devices.iter() { + info!( + "{:?} device: ({})", + device_type, + AudioDeviceModule::device_str(device) + ); + } + let collection = DeviceCollectionWrapper::new(devices); + match device_type { + DeviceType::INPUT => self.input_device_cache = Some(collection), + DeviceType::OUTPUT => self.output_device_cache = Some(collection), + _ => bail!("Bad device type {:?}", device_type), + } + Ok(()) + } + fn enumerate_devices( - &self, + &mut self, device_type: DeviceType, - ) -> anyhow::Result { - match &self.cubeb_ctx { - Some(ctx) => Ok(DeviceCollectionWrapper::new( - ctx.enumerate_devices(device_type)?, - )), - None => Err(anyhow!("Cannot enumerate devices without a cubeb ctx"))?, + ) -> anyhow::Result<&DeviceCollectionWrapper> { + let pending_update = match device_type { + DeviceType::INPUT => self + .pending_input_device_refresh + .swap(false, Ordering::SeqCst), + DeviceType::OUTPUT => self + .pending_output_device_refresh + .swap(false, Ordering::SeqCst), + _ => bail!("Bad device type {:?}", device_type), + }; + if pending_update { + self.refresh_device_cache(device_type)?; + } + let collection = match device_type { + DeviceType::INPUT => self.input_device_cache.as_ref(), + DeviceType::OUTPUT => self.output_device_cache.as_ref(), + _ => bail!("Bad device type {:?}", device_type), + }; + match collection { + Some(c) => Ok(c), + None => Err(anyhow!("No {:?} collection found", device_type)), } } @@ -345,7 +476,7 @@ impl AudioDeviceModule { } // Device enumeration - pub fn playout_devices(&self) -> i16 { + pub fn playout_devices(&mut self) -> i16 { match self.enumerate_devices(DeviceType::OUTPUT) { Ok(device_collection) => device_collection.count().try_into().unwrap_or(-1), Err(e) => { @@ -355,7 +486,7 @@ impl AudioDeviceModule { } } - pub fn recording_devices(&self) -> i16 { + pub fn recording_devices(&mut self) -> i16 { match self.enumerate_devices(DeviceType::INPUT) { Ok(device_collection) => device_collection.count().try_into().unwrap_or(-1), Err(e) => { @@ -367,12 +498,12 @@ impl AudioDeviceModule { fn copy_name_and_id( index: u16, - devices: DeviceCollectionWrapper, + devices: &DeviceCollectionWrapper, name_out: webrtc::ptr::Borrowed, guid_out: webrtc::ptr::Borrowed, ) -> anyhow::Result<()> { if let Some(d) = devices.get(index.into()) { - if let Some(name) = d.friendly_name() { + if let Some(name) = &d.friendly_name { let mut name_copy = name.to_string(); // TODO(mutexlox): Localize these strings. #[cfg(not(target_os = "windows"))] @@ -391,7 +522,7 @@ impl AudioDeviceModule { } else { return Err(anyhow!("Could not get device name")); } - if let Some(id) = d.device_id() { + if let Some(id) = &d.device_id { copy_and_truncate_string(id, guid_out, ADM_MAX_GUID_SIZE)?; } else { return Err(anyhow!("Could not get device ID")); @@ -407,7 +538,7 @@ impl AudioDeviceModule { } pub fn playout_device_name( - &self, + &mut self, index: u16, name_out: webrtc::ptr::Borrowed, guid_out: webrtc::ptr::Borrowed, @@ -430,7 +561,7 @@ impl AudioDeviceModule { } pub fn recording_device_name( - &self, + &mut self, index: u16, name_out: webrtc::ptr::Borrowed, guid_out: webrtc::ptr::Borrowed, @@ -455,26 +586,17 @@ impl AudioDeviceModule { // Device selection pub fn set_playout_device(&mut self, index: u16) -> i32 { let device = match self.enumerate_devices(DeviceType::OUTPUT) { - Ok(devices) => { - for device in devices.iter() { - info!( - "Playout device: ({})", - AudioDeviceModule::device_str(device) + Ok(devices) => match devices.get(index as usize) { + Some(device) => device.devid, + None => { + error!( + "Invalid playout device index {} requested (len {})", + index, + devices.count() ); + return -1; } - - match devices.get(index as usize) { - Some(device) => device.devid(), - None => { - error!( - "Invalid device index {} requested (len {})", - index, - devices.count() - ); - return -1; - } - } - } + }, Err(e) => { error!("failed to enumerate devices for playout device: {}", e); return -1; @@ -495,25 +617,17 @@ impl AudioDeviceModule { pub fn set_recording_device(&mut self, index: u16) -> i32 { let device = match self.enumerate_devices(DeviceType::INPUT) { - Ok(devices) => { - for device in devices.iter() { - info!( - "Recording device: ({})", - AudioDeviceModule::device_str(device) + Ok(devices) => match devices.get(index as usize) { + Some(device) => device.devid, + None => { + error!( + "Invalid recording device index {} requested (len {})", + index, + devices.count() ); + return -1; } - match devices.get(index as usize) { - Some(device) => device.devid(), - None => { - error!( - "Invalid device index {} requested (len {})", - index, - devices.count() - ); - return -1; - } - } - } + }, Err(e) => { error!("failed to enumerate devices for playout device: {}", e); return -1; diff --git a/src/rust/src/webrtc/audio_device_module_utils.rs b/src/rust/src/webrtc/audio_device_module_utils.rs index 90d60970..01fdccbf 100644 --- a/src/rust/src/webrtc/audio_device_module_utils.rs +++ b/src/rust/src/webrtc/audio_device_module_utils.rs @@ -8,51 +8,85 @@ use crate::webrtc; use anyhow::anyhow; -use cubeb::{DeviceCollection, DeviceInfo, DeviceState}; +use cubeb::{DeviceCollection, DeviceState}; use cubeb_core::DevicePref; -use std::ffi::{c_uchar, CString}; +#[cfg(target_os = "linux")] +use cubeb_core::DeviceType; +use std::ffi::{c_uchar, c_void, CString}; + +pub struct MinimalDeviceInfo { + pub devid: *const c_void, + pub device_id: Option, + pub friendly_name: Option, + #[cfg(target_os = "linux")] + device_type: DeviceType, + preferred: DevicePref, + state: DeviceState, +} /// Wrapper struct for DeviceCollection that handles default devices. -pub struct DeviceCollectionWrapper<'a> { - device_collection: DeviceCollection<'a>, +/// +/// Rather than storing the DeviceCollection directly, which raises complex +/// lifetime issues, store just the fields we need. +/// +/// Note that, in some cases, `devid` may be a pointer to state in the cubeb ctx, +/// so in no event should this outlive the associated ctx. +pub struct DeviceCollectionWrapper { + device_collection: Vec, } #[cfg(target_os = "linux")] -fn device_is_monitor(device: &DeviceInfo) -> bool { - device.device_type() == cubeb::DeviceType::INPUT +fn device_is_monitor(device: &MinimalDeviceInfo) -> bool { + device.device_type == DeviceType::INPUT && device - .device_id() + .device_id .as_ref() .map_or(false, |s| s.ends_with(".monitor")) } -impl DeviceCollectionWrapper<'_> { - pub fn new(device_collection: DeviceCollection<'_>) -> DeviceCollectionWrapper<'_> { - DeviceCollectionWrapper { device_collection } +impl DeviceCollectionWrapper { + pub fn new(device_collection: DeviceCollection<'_>) -> DeviceCollectionWrapper { + let mut out = Vec::new(); + for device in device_collection.iter() { + out.push(MinimalDeviceInfo { + devid: device.devid(), + device_id: device.device_id().as_ref().map(|s| s.to_string()), + friendly_name: device.friendly_name().as_ref().map(|s| s.to_string()), + #[cfg(target_os = "linux")] + device_type: device.device_type(), + preferred: device.preferred(), + state: device.state(), + }) + } + DeviceCollectionWrapper { + device_collection: out, + } } /// Iterate over all Enabled devices (those that are plugged in and not disabled by the OS) pub fn iter( &self, - ) -> std::iter::Filter, fn(&&DeviceInfo) -> bool> { + ) -> std::iter::Filter, fn(&&MinimalDeviceInfo) -> bool> + { self.device_collection .iter() - .filter(|d| d.state() == DeviceState::Enabled) + .filter(|d| d.state == DeviceState::Enabled) } // For linux only, a method that will ignore "monitor" devices. #[cfg(target_os = "linux")] pub fn iter_non_monitor( &self, - ) -> std::iter::Filter, fn(&&DeviceInfo) -> bool> { + ) -> std::iter::Filter, fn(&&MinimalDeviceInfo) -> bool> + { self.device_collection .iter() - .filter(|&d| d.state() == DeviceState::Enabled && !device_is_monitor(d)) + .filter(|&d| d.state == DeviceState::Enabled && !device_is_monitor(d)) } #[cfg(target_os = "windows")] /// Get a specified device index, accounting for the two default devices. - pub fn get(&self, idx: usize) -> Option<&DeviceInfo> { + pub fn get(&self, idx: usize) -> Option<&MinimalDeviceInfo> { // 0 should be "default device" and 1 should be "default communications device". // Note: On windows, CUBEB_DEVICE_PREF_VOICE will be set for default communications device, // and CUBEB_DEVICE_PREF_MULTIMEDIA | CUBEB_DEVICE_PREF_NOTIFICATION for default device. @@ -64,17 +98,17 @@ impl DeviceCollectionWrapper<'_> { } else if idx == 1 { // Find a device that's preferred for VOICE -- device 1 is the "default communications" self.iter() - .find(|&device| device.preferred().contains(DevicePref::VOICE)) + .find(|&device| device.preferred.contains(DevicePref::VOICE)) } else { // Find a device that's preferred for MULTIMEDIA -- device 0 is the "default" self.iter() - .find(|&device| device.preferred().contains(DevicePref::MULTIMEDIA)) + .find(|&device| device.preferred.contains(DevicePref::MULTIMEDIA)) } } #[cfg(not(target_os = "windows"))] /// Get a specified device index, accounting for the default device. - pub fn get(&self, idx: usize) -> Option<&DeviceInfo> { + pub fn get(&self, idx: usize) -> Option<&MinimalDeviceInfo> { if self.count() == 0 { None } else if idx > 0 { @@ -92,7 +126,7 @@ impl DeviceCollectionWrapper<'_> { // Even on linux, we do *NOT* filter monitor devices -- if the user specified that as // default, we respect it. self.iter() - .find(|&device| device.preferred().contains(DevicePref::VOICE)) + .find(|&device| device.preferred.contains(DevicePref::VOICE)) } }