Skip to content

Commit

Permalink
Add support for mbind, get_mempolicy and set_mempolicy (bytecod…
Browse files Browse the repository at this point in the history
…ealliance#938)

This adds support for the `mbind`, `set_mempolicy` and `get_mempolicy`
NUMA syscalls.  The `get_mempolicy` syscall has a few different modes
of operation, depending on the flags, which is demultiplexed into
`get_mempolicy_node` and `get_mempolicy_next_node` for now.  There's a
couple of other modes that writes into the variable length bit array,
which aren't implemented for now.
  • Loading branch information
krh committed Nov 20, 2023
1 parent 496792e commit 6fe1aa4
Show file tree
Hide file tree
Showing 11 changed files with 377 additions and 0 deletions.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ termios = []
# Enable `rustix::mm::*`.
mm = []

# Enable `rustix::numa::*`.
numa = []

# Enable `rustix::pipe::*`.
pipe = []

Expand All @@ -194,6 +197,7 @@ all-apis = [
"mm",
"mount",
"net",
"numa",
"param",
"pipe",
"process",
Expand Down
16 changes: 16 additions & 0 deletions src/backend/linux_raw/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,22 @@ impl<'a, Num: ArgNumber> From<Option<crate::net::Protocol>> for ArgReg<'a, Num>
}
}

#[cfg(feature = "numa")]
impl<'a, Num: ArgNumber> From<crate::numa::Mode> for ArgReg<'a, Num> {
#[inline]
fn from(flags: crate::numa::Mode) -> Self {
c_uint(flags.bits())
}
}

#[cfg(feature = "numa")]
impl<'a, Num: ArgNumber> From<crate::numa::ModeFlags> for ArgReg<'a, Num> {
#[inline]
fn from(flags: crate::numa::ModeFlags) -> Self {
c_uint(flags.bits())
}
}

impl<'a, Num: ArgNumber, T> From<&'a mut MaybeUninit<T>> for ArgReg<'a, Num> {
#[inline]
fn from(t: &'a mut MaybeUninit<T>) -> Self {
Expand Down
2 changes: 2 additions & 0 deletions src/backend/linux_raw/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ pub(crate) mod mount;
pub(crate) mod mount; // for deprecated mount functions in "fs"
#[cfg(feature = "net")]
pub(crate) mod net;
#[cfg(feature = "numa")]
pub(crate) mod numa;
#[cfg(any(
feature = "param",
feature = "process",
Expand Down
2 changes: 2 additions & 0 deletions src/backend/linux_raw/numa/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub(crate) mod syscalls;
pub(crate) mod types;
87 changes: 87 additions & 0 deletions src/backend/linux_raw/numa/syscalls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
//! linux_raw syscalls supporting `rustix::numa`.
//!
//! # Safety
//!
//! See the `rustix::backend` module documentation for details.
#![allow(unsafe_code)]
#![allow(clippy::undocumented_unsafe_blocks)]

use super::types::{Mode, ModeFlags};

use core::ptr::null_mut;
use core::mem::MaybeUninit;
use crate::backend::c;
use crate::backend::conv::{c_uint, no_fd, pass_usize, ret, ret_owned_fd, ret_void_star, zero};
use crate::io;

/// # Safety
///
/// `mbind` is primarily unsafe due to the `addr` parameter, as anything
/// working with memory pointed to by raw pointers is unsafe.
#[inline]
pub(crate) unsafe fn mbind(addr: *mut c::c_void, length: usize, mode: Mode, nodemask: &[u64], flags: ModeFlags) -> io::Result<()> {
ret(syscall!(
__NR_mbind,
addr,
pass_usize(length),
mode,
nodemask.as_ptr(),
pass_usize(nodemask.len() * u64::BITS as usize),
flags
))
}

/// # Safety
///
/// `set_mempolicy` is primarily unsafe due to the `addr` parameter,
/// as anything working with memory pointed to by raw pointers is
/// unsafe.
#[inline]
pub(crate) unsafe fn set_mempolicy(mode: Mode, nodemask: &[u64]) -> io::Result<()> {
ret(syscall!(
__NR_set_mempolicy,
mode,
nodemask.as_ptr(),
pass_usize(nodemask.len() * u64::BITS as usize)
))
}

/// # Safety
///
/// `get_mempolicy` is primarily unsafe due to the `addr` parameter,
/// as anything working with memory pointed to by raw pointers is
/// unsafe.
#[inline]
pub(crate) unsafe fn get_mempolicy_node(addr: *mut c::c_void) -> io::Result<usize> {
let mut mode = MaybeUninit::<usize>::uninit();

ret(syscall!(
__NR_get_mempolicy,
&mut mode,
zero(),
zero(),
addr,
c_uint(linux_raw_sys::general::MPOL_F_NODE | linux_raw_sys::general::MPOL_F_ADDR)
))?;

Ok(mode.assume_init())
}

#[inline]
pub(crate) fn get_mempolicy_next_node() -> io::Result<usize> {
let mut mode = MaybeUninit::<usize>::uninit();

unsafe {
ret(syscall!(
__NR_get_mempolicy,
&mut mode,
zero(),
zero(),
zero(),
c_uint(linux_raw_sys::general::MPOL_F_NODE)
))?;

Ok(mode.assume_init())
}
}
52 changes: 52 additions & 0 deletions src/backend/linux_raw/numa/types.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use bitflags::bitflags;

bitflags! {
/// `MPOL_*` and `MPOL_F_*` flags for use with [`mbind`].
///
/// [`mbind`]: crate::io::mbind
#[repr(transparent)]
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct Mode: u32 {
/// `MPOL_F_STATIC_NODES`
const STATIC_NODES = linux_raw_sys::general::MPOL_F_STATIC_NODES;
/// `MPOL_F_RELATIVE_NODES`
const RELATIVE_NODES = linux_raw_sys::general::MPOL_F_RELATIVE_NODES;
/// `MPOL_F_NUMA_BALANCING`
const NUMA_BALANCING = linux_raw_sys::general::MPOL_F_NUMA_BALANCING;

/// `MPOL_DEFAULT`
const DEFAULT = linux_raw_sys::general::MPOL_DEFAULT as u32;
/// `MPOL_PREFERRED`
const PREFERRED = linux_raw_sys::general::MPOL_PREFERRED as u32;
/// `MPOL_BIND`
const BIND = linux_raw_sys::general::MPOL_BIND as u32;
/// `MPOL_INTERLEAVE`
const INTERLEAVE = linux_raw_sys::general::MPOL_INTERLEAVE as u32;
/// `MPOL_LOCAL`
const LOCAL = linux_raw_sys::general::MPOL_LOCAL as u32;
/// `MPOL_PREFERRED_MANY`
const PREFERRED_MANY = linux_raw_sys::general::MPOL_PREFERRED_MANY as u32;

/// <https://docs.rs/bitflags/*/bitflags/#externally-defined-flags>
const _ = !0;
}
}

bitflags! {
/// `MPOL_MF_*` flags for use with [`mbind`].
///
/// [`mbind`]: crate::io::mbind
#[repr(transparent)]
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct ModeFlags: u32 {
/// `MPOL_MF_STRICT`
const STRICT = linux_raw_sys::general::MPOL_MF_STRICT;
/// `MPOL_MF_MOVE`
const MOVE = linux_raw_sys::general::MPOL_MF_MOVE;
/// `MPOL_MF_MOVE_ALL`
const MOVE_ALL = linux_raw_sys::general::MPOL_MF_MOVE_ALL;

/// <https://docs.rs/bitflags/*/bitflags/#externally-defined-flags>
const _ = !0;
}
}
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,10 @@ pub mod mount;
#[cfg(feature = "net")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "net")))]
pub mod net;
#[cfg(linux_kernel)]
#[cfg(feature = "numa")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "numa")))]
pub mod numa;
#[cfg(not(any(windows, target_os = "espidf")))]
#[cfg(feature = "param")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "param")))]
Expand Down
104 changes: 104 additions & 0 deletions src/numa/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
//! The `numa` API.
//!
//! # Safety
//!
//! `mbind` and related functions manipulate raw pointers and have special
//! semantics and are wildly unsafe.
#![allow(unsafe_code)]

use crate::{backend, io};
use core::ffi::c_void;

pub use backend::numa::types::{Mode, ModeFlags};

/// `mbind(addr, len, mode, nodemask)`-Set memory policy for a memory range.
///
/// # Safety
///
/// This function operates on raw pointers, but it should only be used
/// on memory which the caller owns.
///
/// # References
/// - [Linux]
///
/// [Linux]: https://man7.org/linux/man-pages/man2/mbind.2.html
#[cfg(linux_kernel)]
#[inline]
pub unsafe fn mbind(addr: *mut c_void, len: usize, mode: Mode, nodemask: &[u64], flags: ModeFlags) -> io::Result<()> {
backend::numa::syscalls::mbind(addr, len, mode, nodemask, flags)
}


/// `set_mempolicy(mode, nodemask)`-Set default NUMA memory policy for
/// a thread and its children.
///
/// # Safety
///
/// This function operates on raw pointers, but it should only be used
/// on memory which the caller owns.
///
/// # References
/// - [Linux]
///
/// [Linux]: https://man7.org/linux/man-pages/man2/set_mempolicy.2.html
#[cfg(linux_kernel)]
#[inline]
pub unsafe fn set_mempolicy(mode: Mode, nodemask: &[u64]) -> io::Result<()> {
backend::numa::syscalls::set_mempolicy(mode, nodemask)
}

/// `get_mempolicy_node(addr)`-Return the node ID of the node on which
/// the address addr is allocated.
///
/// If flags specifies both MPOL_F_NODE and MPOL_F_ADDR,
/// get_mempolicy() will return the node ID of the node on which the
/// address addr is allocated into the location pointed to by mode.
/// If no page has yet been allocated for the specified address,
/// get_mempolicy() will allocate a page as if the thread had
/// performed a read (load) access to that address, and return the ID
/// of the node where that page was allocated.
///
/// # Safety
///
/// This function operates on raw pointers, but it should only be used
/// on memory which the caller owns.
///
/// # References
/// - [Linux]
///
/// [Linux]: https://man7.org/linux/man-pages/man2/get_mempolicy.2.html
#[cfg(linux_kernel)]
#[inline]
pub unsafe fn get_mempolicy_node(addr: *mut c_void) -> io::Result<usize> {
backend::numa::syscalls::get_mempolicy_node(addr)
}

/// `get_mempolicy_next_node(addr)`-Return node ID of the next node
/// that will be used for interleaving of internal kernel pages
/// allocated on behalf of the thread.
///
/// If flags specifies MPOL_F_NODE, but not MPOL_F_ADDR, and the
/// thread's current policy is MPOL_INTERLEAVE, then get_mempolicy()
/// will return in the location pointed to by a non-NULL mode
/// argument, the node ID of the next node that will be used for
/// interleaving of internal kernel pages allocated on behalf of the
/// thread. These allocations include pages for memory-mapped files
/// in process memory ranges mapped using the mmap(2) call with the
/// MAP_PRIVATE flag for read accesses, and in memory ranges mapped
/// with the MAP_SHARED flag for all accesses.
///
/// # Safety
///
/// This function operates on raw pointers, but it should only be used
/// on memory which the caller owns.
///
/// # References
/// - [Linux]
///
/// [Linux]: https://man7.org/linux/man-pages/man2/get_mempolicy.2.html
#[cfg(linux_kernel)]
#[inline]
pub unsafe fn get_mempolicy_next_node() -> io::Result<usize> {
backend::numa::syscalls::get_mempolicy_next_node()
}

31 changes: 31 additions & 0 deletions tests/numa/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#[cfg(all(feature = "mm", feature = "fs"))]
#[test]
fn test_mbind() {
let size = 8192;

unsafe {
let vaddr = rustix::mm::mmap_anonymous(
std::ptr::null_mut(),
size,
rustix::mm::ProtFlags::READ | rustix::mm::ProtFlags::WRITE,
rustix::mm::MapFlags::PRIVATE,
).unwrap();

vaddr.cast::<usize>().write(100);

let mask = &[1];
rustix::numa::mbind(vaddr, size, rustix::numa::Mode::BIND | rustix::numa::Mode::STATIC_NODES,
mask, rustix::numa::ModeFlags::empty()).unwrap();

rustix::numa::get_mempolicy_node(vaddr).unwrap();

match rustix::numa::get_mempolicy_next_node() {
Err(rustix::io::Errno::INVAL) => (),
_ => panic!("rustix::numa::get_mempolicy_next_node() should return EINVAL for MPOL_DEFAULT")
}

rustix::numa::set_mempolicy(rustix::numa::Mode::INTERLEAVE, mask).unwrap();

rustix::numa::get_mempolicy_next_node().unwrap();
}
}
25 changes: 25 additions & 0 deletions tests/numa/main.rs~
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#[cfg(feature = "numa")]
#[test]
fn test_mbind() {
let size = 8192;
let fd = rustix::fs::memfd_create(
"memfd",
rustix::fs::MemfdFlags::CLOEXEC
| rustix::fs::MemfdFlags::ALLOW_SEALING,
).unwarp()

rustix::fs::ftruncate(&fd, size as u64).unwrap()

let vaddr = rustix::mm::mmap(
std::ptr::null_mut(),
size,
rustix::mm::ProtFlags::empty(),
rustix::mm::MapFlags::SHARED,
&fd,
0,
)?;

let mask = &[1_usize];
rustix::numa::mbind(vaddr, size, rustix::numa::Mode::BIND | rustix::numa::Mode::STATIC_NODES,
&mask, rustix::numa::ModeFlags::default()).unwrap();
}
Loading

0 comments on commit 6fe1aa4

Please sign in to comment.