-
Notifications
You must be signed in to change notification settings - Fork 53
add support for leading_zeros, trailing_zeros and fix count_ones #213
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 12 commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
15f5a5b
leading_zeros: add support for leading_zeros and trailing_zeros, limi…
Firestar99 ce566fa
leading_zeros: add tests for non 32bit integers, failing
Firestar99 70507e9
leading_zeros: support 8bit, 16bit and emulate 64bit
Firestar99 31ce03f
leading_zeros: fix unused warning for `enabled_extensions`
Firestar99 7d3452c
count_ones: add failing tests for count_ones and bit_reverse
Firestar99 5083b9e
count_ones: fix pointer cast errors
Firestar99 4c046f7
count_ones: fix count_ones, must be u32-only in vulkan
Firestar99 6733b1a
count_ones: fix bit_reverse, must be u32-only in vulkan
Firestar99 ef1d3ff
count_ones: fix u64 bitshifts in all new functions
Firestar99 ac49f2f
count_ones: fix mismatched error messages to methods containing them
Firestar99 1449527
count_ones: cargo fmt
Firestar99 57c7d9d
leading_zeros: fix leading zeros for u32
Firestar99 ee1904b
leading_zeros: remove sign conversions, always unsigned, except for l…
Firestar99 ad3ace9
leading_zeros: make all conversions unsigned
Firestar99 00f41df
leading_zeros: code cleanup
Firestar99 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -211,46 +211,15 @@ impl<'a, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'a, 'tcx> { | |
self.rotate(val, shift, is_left) | ||
} | ||
|
||
// TODO: Do we want to manually implement these instead of using intel instructions? | ||
sym::ctlz | sym::ctlz_nonzero => { | ||
let result = self | ||
.emit() | ||
.u_count_leading_zeros_intel( | ||
args[0].immediate().ty, | ||
None, | ||
args[0].immediate().def(self), | ||
) | ||
.unwrap(); | ||
self.ext_inst | ||
.borrow_mut() | ||
.require_integer_functions_2_intel(self, result); | ||
result.with_type(args[0].immediate().ty) | ||
} | ||
sym::cttz | sym::cttz_nonzero => { | ||
let result = self | ||
.emit() | ||
.u_count_trailing_zeros_intel( | ||
args[0].immediate().ty, | ||
None, | ||
args[0].immediate().def(self), | ||
) | ||
.unwrap(); | ||
self.ext_inst | ||
.borrow_mut() | ||
.require_integer_functions_2_intel(self, result); | ||
result.with_type(args[0].immediate().ty) | ||
sym::ctlz => self.count_leading_trailing_zeros(args[0].immediate(), false, false), | ||
sym::ctlz_nonzero => { | ||
self.count_leading_trailing_zeros(args[0].immediate(), false, true) | ||
} | ||
sym::cttz => self.count_leading_trailing_zeros(args[0].immediate(), true, false), | ||
sym::cttz_nonzero => self.count_leading_trailing_zeros(args[0].immediate(), true, true), | ||
|
||
sym::ctpop => self | ||
.emit() | ||
.bit_count(args[0].immediate().ty, None, args[0].immediate().def(self)) | ||
.unwrap() | ||
.with_type(args[0].immediate().ty), | ||
sym::bitreverse => self | ||
.emit() | ||
.bit_reverse(args[0].immediate().ty, None, args[0].immediate().def(self)) | ||
.unwrap() | ||
.with_type(args[0].immediate().ty), | ||
sym::ctpop => self.count_ones(args[0].immediate()), | ||
sym::bitreverse => self.bit_reverse(args[0].immediate()), | ||
sym::bswap => { | ||
// https://github.com/KhronosGroup/SPIRV-LLVM/pull/221/files | ||
// TODO: Definitely add tests to make sure this impl is right. | ||
|
@@ -398,6 +367,254 @@ impl<'a, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'a, 'tcx> { | |
} | ||
|
||
impl Builder<'_, '_> { | ||
pub fn count_ones(&self, arg: SpirvValue) -> SpirvValue { | ||
let ty = arg.ty; | ||
match self.cx.lookup_type(ty) { | ||
SpirvType::Integer(bits, signed) => { | ||
let u32 = SpirvType::Integer(32, false).def(self.span(), self); | ||
|
||
match bits { | ||
8 | 16 => { | ||
let arg = arg.def(self); | ||
let arg = if signed { | ||
let unsigned = SpirvType::Integer(bits, false).def(self.span(), self); | ||
self.emit().bitcast(unsigned, None, arg).unwrap() | ||
} else { | ||
arg | ||
}; | ||
let arg = self.emit().u_convert(u32, None, arg).unwrap(); | ||
self.emit().bit_count(u32, None, arg).unwrap() | ||
} | ||
32 => self.emit().bit_count(u32, None, arg.def(self)).unwrap(), | ||
64 => { | ||
let u32_32 = self.constant_u32(self.span(), 32).def(self); | ||
let arg = arg.def(self); | ||
let lower = self.emit().s_convert(u32, None, arg).unwrap(); | ||
let higher = self | ||
.emit() | ||
.shift_right_logical(ty, None, arg, u32_32) | ||
.unwrap(); | ||
let higher = self.emit().s_convert(u32, None, higher).unwrap(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Doesn't this need to change depending on sign? |
||
|
||
let lower_bits = self.emit().bit_count(u32, None, lower).unwrap(); | ||
let higher_bits = self.emit().bit_count(u32, None, higher).unwrap(); | ||
self.emit() | ||
.i_add(u32, None, lower_bits, higher_bits) | ||
.unwrap() | ||
} | ||
_ => { | ||
let undef = self.undef(ty).def(self); | ||
self.zombie( | ||
undef, | ||
&format!("count_ones() on unsupported {ty:?} bit integer type"), | ||
); | ||
undef | ||
} | ||
} | ||
.with_type(u32) | ||
} | ||
_ => self.fatal("count_ones() on a non-integer type"), | ||
} | ||
} | ||
|
||
pub fn bit_reverse(&self, arg: SpirvValue) -> SpirvValue { | ||
let ty = arg.ty; | ||
match self.cx.lookup_type(ty) { | ||
SpirvType::Integer(bits, signed) => { | ||
let u32 = SpirvType::Integer(32, false).def(self.span(), self); | ||
let uint = SpirvType::Integer(bits, false).def(self.span(), self); | ||
|
||
match (bits, signed) { | ||
(8 | 16, signed) => { | ||
let arg = arg.def(self); | ||
let arg = if signed { | ||
self.emit().bitcast(uint, None, arg).unwrap() | ||
} else { | ||
arg | ||
}; | ||
let arg = self.emit().u_convert(u32, None, arg).unwrap(); | ||
|
||
let reverse = self.emit().bit_reverse(u32, None, arg).unwrap(); | ||
let shift = self.constant_u32(self.span(), 32 - bits).def(self); | ||
let reverse = self | ||
.emit() | ||
.shift_right_logical(u32, None, reverse, shift) | ||
.unwrap(); | ||
let reverse = self.emit().u_convert(uint, None, reverse).unwrap(); | ||
if signed { | ||
self.emit().bitcast(ty, None, reverse).unwrap() | ||
} else { | ||
reverse | ||
} | ||
} | ||
(32, false) => self.emit().bit_reverse(u32, None, arg.def(self)).unwrap(), | ||
(32, true) => { | ||
let arg = self.emit().bitcast(u32, None, arg.def(self)).unwrap(); | ||
let reverse = self.emit().bit_reverse(u32, None, arg).unwrap(); | ||
self.emit().bitcast(ty, None, reverse).unwrap() | ||
} | ||
(64, signed) => { | ||
let u32_32 = self.constant_u32(self.span(), 32).def(self); | ||
let arg = arg.def(self); | ||
let lower = self.emit().s_convert(u32, None, arg).unwrap(); | ||
let higher = self | ||
.emit() | ||
.shift_right_logical(ty, None, arg, u32_32) | ||
.unwrap(); | ||
let higher = self.emit().s_convert(u32, None, higher).unwrap(); | ||
|
||
// note that higher and lower have swapped | ||
let higher_bits = self.emit().bit_reverse(u32, None, lower).unwrap(); | ||
let lower_bits = self.emit().bit_reverse(u32, None, higher).unwrap(); | ||
|
||
let higher_bits = self.emit().u_convert(uint, None, higher_bits).unwrap(); | ||
let shift = self.constant_u32(self.span(), 32).def(self); | ||
let higher_bits = self | ||
.emit() | ||
.shift_left_logical(uint, None, higher_bits, shift) | ||
.unwrap(); | ||
let lower_bits = self.emit().u_convert(uint, None, lower_bits).unwrap(); | ||
|
||
let result = self | ||
.emit() | ||
.bitwise_or(ty, None, lower_bits, higher_bits) | ||
.unwrap(); | ||
if signed { | ||
self.emit().bitcast(ty, None, result).unwrap() | ||
} else { | ||
result | ||
} | ||
} | ||
_ => { | ||
let undef = self.undef(ty).def(self); | ||
self.zombie( | ||
undef, | ||
&format!("bit_reverse() on unsupported {ty:?} bit integer type"), | ||
); | ||
undef | ||
} | ||
} | ||
.with_type(ty) | ||
} | ||
_ => self.fatal("bit_reverse() on a non-integer type"), | ||
} | ||
} | ||
|
||
pub fn count_leading_trailing_zeros( | ||
&self, | ||
arg: SpirvValue, | ||
trailing: bool, | ||
non_zero: bool, | ||
) -> SpirvValue { | ||
let ty = arg.ty; | ||
match self.cx.lookup_type(ty) { | ||
SpirvType::Integer(bits, signed) => { | ||
let bool = SpirvType::Bool.def(self.span(), self); | ||
let u32 = SpirvType::Integer(32, false).def(self.span(), self); | ||
|
||
let glsl = self.ext_inst.borrow_mut().import_glsl(self); | ||
let find_xsb = |arg| { | ||
if trailing { | ||
self.emit() | ||
.ext_inst(u32, None, glsl, GLOp::FindILsb as u32, [Operand::IdRef( | ||
arg, | ||
)]) | ||
.unwrap() | ||
} else { | ||
// rust is always unsigned, so FindUMsb | ||
let bla = self | ||
.emit() | ||
.ext_inst(u32, None, glsl, GLOp::FindUMsb as u32, [Operand::IdRef( | ||
arg, | ||
)]) | ||
.unwrap(); | ||
// the glsl op returns the Msb bit, not the amount of leading zeros of this u32 | ||
// leading zeros = 31 - Msb bit | ||
let u32_31 = self.constant_u32(self.span(), 31).def(self); | ||
self.emit().i_sub(u32, None, u32_31, bla).unwrap() | ||
} | ||
}; | ||
|
||
let converted = match bits { | ||
8 | 16 => { | ||
if trailing { | ||
let arg = self.emit().s_convert(u32, None, arg.def(self)).unwrap(); | ||
find_xsb(arg) | ||
} else { | ||
let arg = arg.def(self); | ||
let arg = if signed { | ||
let unsigned = | ||
SpirvType::Integer(bits, false).def(self.span(), self); | ||
self.emit().bitcast(unsigned, None, arg).unwrap() | ||
} else { | ||
arg | ||
}; | ||
let arg = self.emit().u_convert(u32, None, arg).unwrap(); | ||
let xsb = find_xsb(arg); | ||
let subtrahend = self.constant_u32(self.span(), 32 - bits).def(self); | ||
self.emit().i_sub(u32, None, xsb, subtrahend).unwrap() | ||
} | ||
} | ||
32 => find_xsb(arg.def(self)), | ||
64 => { | ||
let u32_0 = self.constant_int(u32, 0).def(self); | ||
let u32_32 = self.constant_u32(self.span(), 32).def(self); | ||
|
||
let arg = arg.def(self); | ||
let lower = self.emit().s_convert(u32, None, arg).unwrap(); | ||
let higher = self | ||
.emit() | ||
.shift_right_logical(ty, None, arg, u32_32) | ||
.unwrap(); | ||
let higher = self.emit().s_convert(u32, None, higher).unwrap(); | ||
|
||
let lower_bits = find_xsb(lower); | ||
let higher_bits = find_xsb(higher); | ||
|
||
if trailing { | ||
let use_lower = self.emit().i_equal(bool, None, higher, u32_0).unwrap(); | ||
let lower_bits = | ||
self.emit().i_add(u32, None, lower_bits, u32_32).unwrap(); | ||
self.emit() | ||
.select(u32, None, use_lower, lower_bits, higher_bits) | ||
.unwrap() | ||
} else { | ||
let use_higher = self.emit().i_equal(bool, None, lower, u32_0).unwrap(); | ||
let higher_bits = | ||
self.emit().i_add(u32, None, higher_bits, u32_32).unwrap(); | ||
self.emit() | ||
.select(u32, None, use_higher, higher_bits, lower_bits) | ||
.unwrap() | ||
} | ||
} | ||
_ => { | ||
let undef = self.undef(ty).def(self); | ||
self.zombie(undef, &format!( | ||
"count_leading_trailing_zeros() on unsupported {ty:?} bit integer type" | ||
)); | ||
undef | ||
} | ||
}; | ||
|
||
if non_zero { | ||
converted | ||
} else { | ||
let int_0 = self.constant_int(ty, 0).def(self); | ||
let int_bits = self.constant_int(u32, bits as u128).def(self); | ||
let is_0 = self | ||
.emit() | ||
.i_equal(bool, None, arg.def(self), int_0) | ||
.unwrap(); | ||
self.emit() | ||
.select(u32, None, is_0, int_bits, converted) | ||
.unwrap() | ||
} | ||
.with_type(u32) | ||
} | ||
_ => self.fatal("count_leading_trailing_zeros() on a non-integer type"), | ||
} | ||
} | ||
|
||
pub fn abort_with_kind_and_message_debug_printf( | ||
&mut self, | ||
kind: &str, | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a style thing, but I generally feel it is clearer to match on `signed` as well; it makes all the cases clearer, similar to what you did in bitcast.