diff --git a/src/passthrough/mod.rs b/src/passthrough/mod.rs index 32efee1f7..0c27c6217 100644 --- a/src/passthrough/mod.rs +++ b/src/passthrough/mod.rs @@ -22,8 +22,8 @@ use std::os::fd::{AsFd, BorrowedFd}; use std::os::unix::ffi::OsStringExt; use std::os::unix::io::{AsRawFd, RawFd}; use std::path::PathBuf; -use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering}; -use std::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockWriteGuard}; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::{Arc, RwLock, RwLockWriteGuard}; use std::time::Duration; use vm_memory::{bitmap::BitmapSlice, ByteValued}; @@ -35,7 +35,7 @@ use self::mount_fd::MountFds; use self::statx::{statx, StatExt}; use self::util::{ ebadf, einval, enosys, eperm, is_dir, is_safe_inode, openat, reopen_fd_through_proc, stat_fd, - UniqueInodeGenerator, + FileFlagGuard, UniqueInodeGenerator, }; use crate::abi::fuse_abi as fuse; use crate::abi::fuse_abi::Opcode; @@ -256,8 +256,7 @@ impl InodeMap { struct HandleData { inode: Inode, file: File, - lock: Mutex<()>, - open_flags: AtomicU32, + open_flags: RwLock<u32>, } impl HandleData { @@ -265,8 +264,7 @@ impl HandleData { HandleData { inode, file, - lock: Mutex::new(()), - open_flags: AtomicU32::new(flags), + open_flags: RwLock::new(flags), } } @@ -274,21 +272,14 @@ impl HandleData { &self.file } - fn get_file_mut(&self) -> (MutexGuard<()>, &File) { - (self.lock.lock().unwrap(), &self.file) + fn get_file_mut(&self) -> (FileFlagGuard<u32>, &File) { + let guard = self.open_flags.write().unwrap(); + (FileFlagGuard::Writer(guard), &self.file) } fn borrow_fd(&self) -> BorrowedFd { self.file.as_fd() } - - fn get_flags(&self) -> u32 { - self.open_flags.load(Ordering::Relaxed) - } - - fn set_flags(&self, flags: u32) { - self.open_flags.store(flags, Ordering::Relaxed); - } } struct HandleMap { diff --git a/src/passthrough/sync_io.rs b/src/passthrough/sync_io.rs index 97d24c563..d3448b377 100644 --- a/src/passthrough/sync_io.rs +++ 
b/src/passthrough/sync_io.rs @@ -43,16 +43,34 @@ impl PassthroughFs { /// if these do not match update the file descriptor flags and store the new /// result in the HandleData entry #[inline(always)] - fn check_fd_flags(&self, data: Arc<HandleData>, fd: RawFd, flags: u32) -> io::Result<()> { - let open_flags = data.get_flags(); - if open_flags != flags { - let ret = unsafe { libc::fcntl(fd, libc::F_SETFL, flags) }; + fn ensure_file_flags<'a>( + &self, + data: &'a Arc<HandleData>, + fd: &impl AsRawFd, + mut flags: u32, + ) -> io::Result<FileFlagGuard<'a, u32>> { + let guard = data.open_flags.read().unwrap(); + if *guard & libc::O_DIRECT as u32 == flags & libc::O_DIRECT as u32 { + return Ok(FileFlagGuard::Reader(guard)); + } + drop(guard); + + let mut guard = data.open_flags.write().unwrap(); + // Update the O_DIRECT flag if needed + if *guard & libc::O_DIRECT as u32 != flags & libc::O_DIRECT as u32 { + if flags & libc::O_DIRECT as u32 != 0 { + flags = *guard | libc::O_DIRECT as u32; + } else { + flags = *guard & !libc::O_DIRECT as u32; + } + let ret = unsafe { libc::fcntl(fd.as_raw_fd(), libc::F_SETFL, flags) }; if ret != 0 { return Err(io::Error::last_os_error()); } - data.set_flags(flags); + *guard = flags; } - Ok(()) + + Ok(FileFlagGuard::Writer(guard)) } fn do_readdir( @@ -116,8 +134,7 @@ impl PassthroughFs { let (front, back) = rem.split_at(size_of::<LinuxDirent64>()); - let dirent64 = LinuxDirent64::from_slice(front) - .expect("fuse: unable to get LinuxDirent64 from slice"); + let dirent64 = LinuxDirent64::from_slice(front).ok_or_else(einval)?; let namelen = dirent64.d_reclen as usize - size_of::<LinuxDirent64>(); debug_assert!( @@ -659,16 +676,14 @@ impl FileSystem for PassthroughFs { flags: u32, ) -> io::Result<usize> { let data = self.get_data(handle, inode, libc::O_RDONLY)?; + let fd = data.borrow_fd(); + + self.ensure_file_flags(&data, &fd, flags)?; // Manually implement File::try_clone() by borrowing fd of data.file instead of dup(). 
// It's safe because the `data` variable's lifetime spans the whole function, // so data.file won't be closed. - let f = unsafe { File::from_raw_fd(data.borrow_fd().as_raw_fd()) }; - - self.check_fd_flags(data, f.as_raw_fd(), flags)?; - - let mut f = ManuallyDrop::new(f); - + let mut f = unsafe { ManuallyDrop::new(File::from_raw_fd(fd.as_raw_fd())) }; w.write_from(&mut *f, size as usize, offset) } @@ -686,21 +701,14 @@ impl FileSystem for PassthroughFs { fuse_flags: u32, ) -> io::Result<usize> { let data = self.get_data(handle, inode, libc::O_RDWR)?; + let fd = data.borrow_fd(); - // Manually implement File::try_clone() by borrowing fd of data.file instead of dup(). - // It's safe because the `data` variable's lifetime spans the whole function, - // so data.file won't be closed. - let f = unsafe { File::from_raw_fd(data.borrow_fd().as_raw_fd()) }; - - self.check_fd_flags(data, f.as_raw_fd(), flags)?; - + self.ensure_file_flags(&data, &fd, flags)?; if self.seal_size.load(Ordering::Relaxed) { - let st = stat_fd(&f, None)?; + let st = stat_fd(&fd, None)?; self.seal_size_check(Opcode::Write, st.st_size as u64, offset, size as u64, 0)?; } - let mut f = ManuallyDrop::new(f); - // Cap restored when _killpriv is dropped let _killpriv = if self.killpriv_v2.load(Ordering::Relaxed) && (fuse_flags & WRITE_KILL_PRIV != 0) { @@ -709,6 +717,10 @@ impl FileSystem for PassthroughFs { None }; + // Manually implement File::try_clone() by borrowing fd of data.file instead of dup(). + // It's safe because the `data` variable's lifetime spans the whole function, + // so data.file won't be closed. 
+ let mut f = unsafe { ManuallyDrop::new(File::from_raw_fd(fd.as_raw_fd())) }; r.read_to(&mut *f, size as usize, offset) } diff --git a/src/passthrough/util.rs b/src/passthrough/util.rs index 08a31031f..172bee861 100644 --- a/src/passthrough/util.rs +++ b/src/passthrough/util.rs @@ -7,9 +7,10 @@ use std::ffi::{CStr, CString}; use std::fs::File; use std::io; use std::mem::MaybeUninit; +use std::ops::Deref; use std::os::unix::io::{AsRawFd, FromRawFd}; use std::sync::atomic::{AtomicU64, AtomicU8, Ordering}; -use std::sync::Mutex; +use std::sync::{Mutex, RwLockReadGuard, RwLockWriteGuard}; use super::inode_store::InodeId; use super::MAX_HOST_INO; @@ -225,6 +226,23 @@ pub fn eperm() -> io::Error { io::Error::from_raw_os_error(libc::EPERM) } +/// A helper structure to hold RwLock guard. +pub enum FileFlagGuard<'a, T> { + Reader(RwLockReadGuard<'a, T>), + Writer(RwLockWriteGuard<'a, T>), +} + +impl<'a, T> Deref for FileFlagGuard<'a, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + match self { + FileFlagGuard::Reader(v) => v.deref(), + FileFlagGuard::Writer(v) => v.deref(), + } + } +} + #[cfg(test)] mod tests { use super::*;