This repository was archived by the owner on Apr 28, 2025. It is now read-only.

Improve sqrt and sqrtf by using hardware more often. #222

Closed
wants to merge 17 commits
Changes from 7 commits
2 changes: 2 additions & 0 deletions README.md
@@ -1,5 +1,7 @@
 # `libm`

+[![crates.io](https://img.shields.io/crates/v/libm.svg)](https://crates.io/crates/libm)
+[![docs.rs](https://docs.rs/libm/badge.svg)](https://docs.rs/libm/)
 [![Build Status](https://dev.azure.com/rust-lang/libm/_apis/build/status/rust-lang-nursery.libm?branchName=master)](https://dev.azure.com/rust-lang/libm/_build/latest?definitionId=7&branchName=master)

 A port of [MUSL]'s libm to Rust.
7 changes: 3 additions & 4 deletions src/lib.rs
@@ -1,10 +1,9 @@
 //! libm in pure Rust
 #![deny(warnings)]
 #![no_std]
-#![cfg_attr(
-    all(target_arch = "wasm32", not(feature = "stable")),
-    feature(core_intrinsics)
-)]
+#![cfg_attr(not(feature = "stable"), feature(core_intrinsics))]
+#![allow(clippy::unreadable_literal)]
+#![allow(clippy::many_single_char_names)]

 mod math;

300 changes: 158 additions & 142 deletions src/math/sqrt.rs
@@ -1,3 +1,53 @@
+/// [Square root](https://en.wikipedia.org/wiki/Square_root) of an `f64`.
+///
+/// This function is intended to exactly match the
+/// [`sqrt`](https://en.cppreference.com/w/c/numeric/math/sqrt) function as
+/// defined by the C/C++ spec.
+#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
+#[allow(unreachable_code)]
+pub fn sqrt(x: f64) -> f64 {
+    // On most targets LLVM will issue a hardware sqrt instruction instead of a
+    // call to this function when you call `core::intrinsics::sqrtf64(x)`.
+    // However, not all targets have hardware sqrt, and also people might end up
+    // calling this function directly, so we must still do our best to get a
+    // hardware instruction used, and then go with a software fallback if
+    // necessary.
+
+    // Nightly: if we're *sure* that LLVM supports a hardware instruction then
+    // we'll call the LLVM intrinsic. We have to be conservative in our
+    // selection for when to do this, because if the intrinsic usage ends up
+    // calling back here it's infinite recursion.
+    #[cfg(all(
+        not(feature = "stable"),
+        any(
+            all(target_arch = "x86", not(target_feature = "soft_float")),
+            all(target_arch = "x86_64", not(target_feature = "soft_float")),
+            target_arch = "aarch64",
+            target_arch = "wasm32",
+            target_arch = "powerpc64"
+        )
+    ))]
+    {
+        return unsafe { core::intrinsics::sqrtf64(x) };
+    }
+    // Stable: We can use `sse2` if available. As more intrinsic sets stabilize
+    // we can expand this to use hardware more often.
+    #[cfg(target_feature = "sse2")]
+    {
+        #[cfg(target_arch = "x86")]
+        use core::arch::x86::*;
+        #[cfg(target_arch = "x86_64")]
+        use core::arch::x86_64::*;
+        return unsafe {
+            let m = _mm_set_sd(x);
+            let m_sqrt = _mm_sqrt_pd(m);
+            _mm_cvtsd_f64(m_sqrt)
+        };
+    }
+    // Finally, if we must, we use the software version (below).
+    software_sqrt(x)
+}
+
 /* origin: FreeBSD /usr/src/lib/msun/src/e_sqrt.c */
 /*
  * ====================================================
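The new `sqrt` above tries three tiers in order: the nightly-only LLVM intrinsic, the stable SSE2 intrinsics, and finally the software fallback (`#[allow(unreachable_code)]` is needed because on targets where one of the cfg'd blocks returns, the trailing call is dead code). Whichever tier a given build lands in, the result has to honor the IEEE edge cases that the FreeBSD comment below spells out. A minimal sanity-check sketch, assuming it sits in the same module as `sqrt`; the test module and its cases are illustrative, not part of this diff:

```rust
#[cfg(test)]
mod dispatch_sanity {
    use super::sqrt;
    use core::f64::{INFINITY, NAN};

    #[test]
    fn edge_cases_match_the_c_spec() {
        assert!(sqrt(NAN).is_nan()); // sqrt(NaN) = NaN
        assert!(sqrt(-1.0).is_nan()); // sqrt(-ve) = NaN
        assert_eq!(sqrt(INFINITY), INFINITY); // sqrt(+inf) = +inf
        // sqrt(+-0) = +-0, preserving the sign bit of -0.0
        assert_eq!(sqrt(0.0).to_bits(), (0.0f64).to_bits());
        assert_eq!(sqrt(-0.0).to_bits(), (-0.0f64).to_bits());
        assert_eq!(sqrt(4.0), 2.0); // exact for perfect squares
    }
}
```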
@@ -75,167 +125,133 @@
  * sqrt(-ve) = NaN ... with invalid signal
  * sqrt(NaN) = NaN ... with invalid signal for signaling NaN
  */
-use core::f64;
-
-#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
-pub fn sqrt(x: f64) -> f64 {
-    // On wasm32 we know that LLVM's intrinsic will compile to an optimized
-    // `f64.sqrt` native instruction, so we can leverage this for both code size
-    // and speed.
-    llvm_intrinsically_optimized! {
-        #[cfg(target_arch = "wasm32")] {
-            return if x < 0.0 {
-                f64::NAN
-            } else {
-                unsafe { ::core::intrinsics::sqrtf64(x) }
-            }
-        }
-    }
-    #[cfg(target_feature = "sse2")]
-    {
-        // Note: This path is unlikely since LLVM will usually have already
-        // optimized sqrt calls into hardware instructions if sse2 is available,
-        // but if someone does end up here they'll apprected the speed increase.
-        #[cfg(target_arch = "x86")]
-        use core::arch::x86::*;
-        #[cfg(target_arch = "x86_64")]
-        use core::arch::x86_64::*;
-        unsafe {
-            let m = _mm_set_sd(x);
-            let m_sqrt = _mm_sqrt_pd(m);
-            _mm_cvtsd_f64(m_sqrt)
-        }
-    }
-    #[cfg(not(target_feature = "sse2"))]
-    {
-        use core::num::Wrapping;
+fn software_sqrt(x: f64) -> f64 {
+    use core::num::Wrapping;

     const TINY: f64 = 1.0e-300;

     let mut z: f64;
     let sign: Wrapping<u32> = Wrapping(0x80000000);
     let mut ix0: i32;
     let mut s0: i32;
     let mut q: i32;
     let mut m: i32;
     let mut t: i32;
     let mut i: i32;
     let mut r: Wrapping<u32>;
     let mut t1: Wrapping<u32>;
     let mut s1: Wrapping<u32>;
     let mut ix1: Wrapping<u32>;
     let mut q1: Wrapping<u32>;

     ix0 = (x.to_bits() >> 32) as i32;
     ix1 = Wrapping(x.to_bits() as u32);

     /* take care of Inf and NaN */
     if (ix0 & 0x7ff00000) == 0x7ff00000 {
         return x * x + x; /* sqrt(NaN)=NaN, sqrt(+inf)=+inf, sqrt(-inf)=sNaN */
     }
     /* take care of zero */
     if ix0 <= 0 {
         if ((ix0 & !(sign.0 as i32)) | ix1.0 as i32) == 0 {
             return x; /* sqrt(+-0) = +-0 */
         }
         if ix0 < 0 {
             return (x - x) / (x - x); /* sqrt(-ve) = sNaN */
         }
     }
     /* normalize x */
     m = ix0 >> 20;
     if m == 0 {
         /* subnormal x */
         while ix0 == 0 {
             m -= 21;
             ix0 |= (ix1 >> 11).0 as i32;
             ix1 <<= 21;
         }
         i = 0;
         while (ix0 & 0x00100000) == 0 {
             i += 1;
             ix0 <<= 1;
         }
         m -= i - 1;
         ix0 |= (ix1 >> (32 - i) as usize).0 as i32;
         ix1 = ix1 << i as usize;
     }
     m -= 1023; /* unbias exponent */
     ix0 = (ix0 & 0x000fffff) | 0x00100000;
     if (m & 1) == 1 {
         /* odd m, double x to make it even */
         ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
         ix1 += ix1;
     }
     m >>= 1; /* m = [m/2] */

     /* generate sqrt(x) bit by bit */
     ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
     ix1 += ix1;
     q = 0; /* [q,q1] = sqrt(x) */
     q1 = Wrapping(0);
     s0 = 0;
     s1 = Wrapping(0);
     r = Wrapping(0x00200000); /* r = moving bit from right to left */

     while r != Wrapping(0) {
         t = s0 + r.0 as i32;
         if t <= ix0 {
             s0 = t + r.0 as i32;
             ix0 -= t;
             q += r.0 as i32;
         }
         ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
         ix1 += ix1;
         r >>= 1;
     }

     r = sign;
     while r != Wrapping(0) {
         t1 = s1 + r;
         t = s0;
         if t < ix0 || (t == ix0 && t1 <= ix1) {
             s1 = t1 + r;
             if (t1 & sign) == sign && (s1 & sign) == Wrapping(0) {
                 s0 += 1;
             }
             ix0 -= t;
             if ix1 < t1 {
                 ix0 -= 1;
             }
             ix1 -= t1;
             q1 += r;
         }
         ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
         ix1 += ix1;
         r >>= 1;
     }

     /* use floating add to find out rounding direction */
     if (ix0 as u32 | ix1.0) != 0 {
         z = 1.0 - TINY; /* raise inexact flag */
         if z >= 1.0 {
             z = 1.0 + TINY;
             if q1.0 == 0xffffffff {
                 q1 = Wrapping(0);
                 q += 1;
             } else if z > 1.0 {
                 if q1.0 == 0xfffffffe {
                     q += 1;
                 }
                 q1 += Wrapping(2);
             } else {
                 q1 += q1 & Wrapping(1);
             }
         }
     }
     ix0 = (q >> 1) + 0x3fe00000;
     ix1 = q1 >> 1;
     if (q & 1) == 1 {
         ix1 |= sign;
     }
     ix0 += m << 20;
     f64::from_bits((ix0 as u64) << 32 | ix1.0 as u64)
 }
-}
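For readers new to the FreeBSD routine: `software_sqrt` first forces the exponent even (doubling the mantissa when `m` is odd, since sqrt(2^(2k) * f) = 2^k * sqrt(f)), then the two `while r != Wrapping(0)` loops run a classic restoring, digit-by-digit square root over the mantissa split across `(ix0, ix1)`, emitting one result bit per iteration into the pair `(q, q1)`. The closing `z = 1.0 - TINY` step probes the active rounding mode (under round-to-nearest, `1.0 - TINY` rounds back up to `1.0`) while raising the inexact flag, and the leftover remainder decides whether the last bit rounds up. The same digit-by-digit recurrence is easier to see on a plain integer; a textbook sketch, not code from this PR:

```rust
/// Digit-by-digit (restoring) square root of an integer: propose each
/// result bit from high to low and keep it whenever the partial square
/// still fits under `n`. `software_sqrt` runs the same recurrence on
/// the 53-bit mantissa.
fn isqrt(n: u64) -> u64 {
    let mut rem = n; // n minus the square built so far
    let mut res: u64 = 0; // the root built so far
    let mut bit: u64 = 1 << 62; // highest power of four a u64 can hold
    while bit > n {
        bit >>= 2; // drop to the highest power of four <= n
    }
    while bit != 0 {
        if rem >= res + bit {
            // Setting this bit keeps (root so far)^2 <= n.
            rem -= res + bit;
            res = (res >> 1) + bit;
        } else {
            res >>= 1;
        }
        bit >>= 2;
    }
    res // floor(sqrt(n)); the f64 routine also keeps `rem` for rounding
}
```

For example, `isqrt(10) == 3` and `isqrt(1 << 52) == 1 << 26`; `software_sqrt` runs the same bit-proposal loop, just with the value spread across two 32-bit halves and the remainder retained for the rounding step.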