Skip to content

Commit 9cf1a9e

Browse files
committed
leading_zeros: remove sign conversions, always unsigned, except for leading_zeros
1 parent 0df98d7 commit 9cf1a9e

File tree

1 file changed

+38
-47
lines changed

1 file changed

+38
-47
lines changed

crates/rustc_codegen_spirv/src/builder/intrinsics.rs

+38-47
Original file line numberDiff line numberDiff line change
@@ -370,18 +370,12 @@ impl Builder<'_, '_> {
370370
pub fn count_ones(&self, arg: SpirvValue) -> SpirvValue {
371371
let ty = arg.ty;
372372
match self.cx.lookup_type(ty) {
373-
SpirvType::Integer(bits, signed) => {
373+
SpirvType::Integer(bits, false) => {
374374
let u32 = SpirvType::Integer(32, false).def(self.span(), self);
375375

376376
match bits {
377377
8 | 16 => {
378378
let arg = arg.def(self);
379-
let arg = if signed {
380-
let unsigned = SpirvType::Integer(bits, false).def(self.span(), self);
381-
self.emit().bitcast(unsigned, None, arg).unwrap()
382-
} else {
383-
arg
384-
};
385379
let arg = self.emit().u_convert(u32, None, arg).unwrap();
386380
self.emit().bit_count(u32, None, arg).unwrap()
387381
}
@@ -413,25 +407,23 @@ impl Builder<'_, '_> {
413407
}
414408
.with_type(u32)
415409
}
416-
_ => self.fatal("count_ones() on a non-integer type"),
410+
_ => self.fatal(format!(
411+
"count_ones() expected an unsigned integer type, got {:?}",
412+
self.cx.lookup_type(ty)
413+
)),
417414
}
418415
}
419416

420417
pub fn bit_reverse(&self, arg: SpirvValue) -> SpirvValue {
421418
let ty = arg.ty;
422419
match self.cx.lookup_type(ty) {
423-
SpirvType::Integer(bits, signed) => {
420+
SpirvType::Integer(bits, false) => {
424421
let u32 = SpirvType::Integer(32, false).def(self.span(), self);
425422
let uint = SpirvType::Integer(bits, false).def(self.span(), self);
426423

427-
match (bits, signed) {
428-
(8 | 16, signed) => {
424+
match bits {
425+
8 | 16 => {
429426
let arg = arg.def(self);
430-
let arg = if signed {
431-
self.emit().bitcast(uint, None, arg).unwrap()
432-
} else {
433-
arg
434-
};
435427
let arg = self.emit().u_convert(u32, None, arg).unwrap();
436428

437429
let reverse = self.emit().bit_reverse(u32, None, arg).unwrap();
@@ -440,20 +432,10 @@ impl Builder<'_, '_> {
440432
.emit()
441433
.shift_right_logical(u32, None, reverse, shift)
442434
.unwrap();
443-
let reverse = self.emit().u_convert(uint, None, reverse).unwrap();
444-
if signed {
445-
self.emit().bitcast(ty, None, reverse).unwrap()
446-
} else {
447-
reverse
448-
}
449-
}
450-
(32, false) => self.emit().bit_reverse(u32, None, arg.def(self)).unwrap(),
451-
(32, true) => {
452-
let arg = self.emit().bitcast(u32, None, arg.def(self)).unwrap();
453-
let reverse = self.emit().bit_reverse(u32, None, arg).unwrap();
454-
self.emit().bitcast(ty, None, reverse).unwrap()
435+
self.emit().u_convert(uint, None, reverse).unwrap()
455436
}
456-
(64, signed) => {
437+
32 => self.emit().bit_reverse(u32, None, arg.def(self)).unwrap(),
438+
64 => {
457439
let u32_32 = self.constant_u32(self.span(), 32).def(self);
458440
let arg = arg.def(self);
459441
let lower = self.emit().s_convert(u32, None, arg).unwrap();
@@ -475,15 +457,9 @@ impl Builder<'_, '_> {
475457
.unwrap();
476458
let lower_bits = self.emit().u_convert(uint, None, lower_bits).unwrap();
477459

478-
let result = self
479-
.emit()
460+
self.emit()
480461
.bitwise_or(ty, None, lower_bits, higher_bits)
481-
.unwrap();
482-
if signed {
483-
self.emit().bitcast(ty, None, result).unwrap()
484-
} else {
485-
result
486-
}
462+
.unwrap()
487463
}
488464
_ => {
489465
let undef = self.undef(ty).def(self);
@@ -496,7 +472,10 @@ impl Builder<'_, '_> {
496472
}
497473
.with_type(ty)
498474
}
499-
_ => self.fatal("bit_reverse() on a non-integer type"),
475+
_ => self.fatal(format!(
476+
"bit_reverse() expected an unsigned integer type, got {:?}",
477+
self.cx.lookup_type(ty)
478+
)),
500479
}
501480
}
502481

@@ -508,7 +487,7 @@ impl Builder<'_, '_> {
508487
) -> SpirvValue {
509488
let ty = arg.ty;
510489
match self.cx.lookup_type(ty) {
511-
SpirvType::Integer(bits, signed) => {
490+
SpirvType::Integer(bits, false) => {
512491
let bool = SpirvType::Bool.def(self.span(), self);
513492
let u32 = SpirvType::Integer(32, false).def(self.span(), self);
514493

@@ -542,13 +521,6 @@ impl Builder<'_, '_> {
542521
find_xsb(arg)
543522
} else {
544523
let arg = arg.def(self);
545-
let arg = if signed {
546-
let unsigned =
547-
SpirvType::Integer(bits, false).def(self.span(), self);
548-
self.emit().bitcast(unsigned, None, arg).unwrap()
549-
} else {
550-
arg
551-
};
552524
let arg = self.emit().u_convert(u32, None, arg).unwrap();
553525
let xsb = find_xsb(arg);
554526
let subtrahend = self.constant_u32(self.span(), 32 - bits).def(self);
@@ -611,7 +583,26 @@ impl Builder<'_, '_> {
611583
}
612584
.with_type(u32)
613585
}
614-
_ => self.fatal("count_leading_trailing_zeros() on a non-integer type"),
586+
SpirvType::Integer(bits, true) => {
587+
// rustc wants `[i8,i16,i32,i64]::leading_zeros()` with `non_zero: true` for some reason. I do not know
588+
// how these are reachable, marking them as zombies makes none of our compiletests fail.
589+
let unsigned = SpirvType::Integer(bits, false).def(self.span(), self);
590+
let arg = self
591+
.emit()
592+
.bitcast(unsigned, None, arg.def(self))
593+
.unwrap()
594+
.with_type(unsigned);
595+
let result = self.count_leading_trailing_zeros(arg, trailing, non_zero);
596+
self.emit()
597+
.bitcast(ty, None, result.def(self))
598+
.unwrap()
599+
.with_type(ty)
600+
}
601+
e => {
602+
self.fatal(format!(
603+
"count_leading_trailing_zeros(trailing: {trailing}, non_zero: {non_zero}) expected an integer type, got {e:?}",
604+
));
605+
}
615606
}
616607
}
617608

0 commit comments

Comments
 (0)