Skip to content

Commit 5a91766

Browse files
committed
Add mul_hi function for bit integers
Move the _mul_high function from base/multinverses.jl to base/int.jl. Rename it to mul_hi.
1 parent 99fd5d9 commit 5a91766

File tree

2 files changed

+41
-23
lines changed

2 files changed

+41
-23
lines changed

base/int.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,44 @@ inv(x::Integer) = float(one(x)) / float(x)
9696
# skip promotion for system integer types
9797
(/)(x::BitInteger, y::BitInteger) = float(x) / float(y)
9898

99+
100+
"""
101+
mul_hi(a::T, b::T) where {T<:Integer}
102+
103+
Returns the higher half of the product of `a` and `b`.
104+
105+
# Examples
106+
```jldoctest
107+
julia> mul_hi(12345678987654321, 123456789)
108+
82624
109+
110+
julia> (widen(12345678987654321) * 123456789) >> 64
111+
82624
112+
113+
julia> mul_hi(0xff, 0xff)
114+
0xfe
115+
```
116+
"""
117+
function mul_hi(a::T, b::T) where {T<:BitInteger}
118+
((widen(a)*b) >>> (sizeof(a)*8)) % T
119+
end
120+
121+
function mul_hi(a::UInt128, b::UInt128)
122+
shift = sizeof(a)*4
123+
mask = typemax(UInt128) >> shift
124+
a1, a2 = a >>> shift, a & mask
125+
b1, b2 = b >>> shift, b & mask
126+
a1b1, a1b2, a2b1, a2b2 = a1*b1, a1*b2, a2*b1, a2*b2
127+
carry = ((a1b2 & mask) + (a2b1 & mask) + (a2b2 >>> shift)) >>> shift
128+
a1b1 + (a1b2 >>> shift) + (a2b1 >>> shift) + carry
129+
end
130+
131+
function mul_hi(a::Int128, b::Int128)
132+
shift = sizeof(a)*8 - 1
133+
t1, t2 = (a >> shift) & b % UInt128, (b >> shift) & a % UInt128
134+
(mul_hi(a % UInt128, b % UInt128) - t1 - t2) % Int128
135+
end
136+
99137
"""
100138
isodd(x::Number) -> Bool
101139

base/multinverses.jl

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
module MultiplicativeInverses
44

5-
import Base: div, divrem, rem, unsigned
5+
import Base: div, divrem, mul_hi, rem, unsigned
66
using Base: IndexLinear, IndexCartesian, tail
77
export multiplicativeinverse
88

@@ -134,33 +134,13 @@ struct UnsignedMultiplicativeInverse{T<:Unsigned} <: MultiplicativeInverse{T}
134134
end
135135
UnsignedMultiplicativeInverse(x::Unsigned) = UnsignedMultiplicativeInverse{typeof(x)}(x)
136136

137-
# Returns the higher half of the product a*b
138-
function _mul_high(a::T, b::T) where {T<:Union{Signed, Unsigned}}
139-
((widen(a)*b) >>> (sizeof(a)*8)) % T
140-
end
141-
142-
function _mul_high(a::UInt128, b::UInt128)
143-
shift = sizeof(a)*4
144-
mask = typemax(UInt128) >> shift
145-
a1, a2 = a >>> shift, a & mask
146-
b1, b2 = b >>> shift, b & mask
147-
a1b1, a1b2, a2b1, a2b2 = a1*b1, a1*b2, a2*b1, a2*b2
148-
carry = ((a1b2 & mask) + (a2b1 & mask) + (a2b2 >>> shift)) >>> shift
149-
a1b1 + (a1b2 >>> shift) + (a2b1 >>> shift) + carry
150-
end
151-
function _mul_high(a::Int128, b::Int128)
152-
shift = sizeof(a)*8 - 1
153-
t1, t2 = (a >> shift) & b % UInt128, (b >> shift) & a % UInt128
154-
(_mul_high(a % UInt128, b % UInt128) - t1 - t2) % Int128
155-
end
156-
157137
function div(a::T, b::SignedMultiplicativeInverse{T}) where T
158-
x = _mul_high(a, b.multiplier)
138+
x = mul_hi(a, b.multiplier)
159139
x += (a*b.addmul) % T
160140
ifelse(abs(b.divisor) == 1, a*b.divisor, (signbit(x) + (x >> b.shift)) % T)
161141
end
162142
function div(a::T, b::UnsignedMultiplicativeInverse{T}) where T
163-
x = _mul_high(a, b.multiplier)
143+
x = mul_hi(a, b.multiplier)
164144
x = ifelse(b.add, convert(T, convert(T, (convert(T, a - x) >>> 1)) + x), x)
165145
ifelse(b.divisor == 1, a, x >>> b.shift)
166146
end

0 commit comments

Comments
 (0)