Skip to content

Commit 37a0cff

Browse files
committed
factorize count leading and trailing zeros code
1 parent e4fd312 commit 37a0cff

File tree

1 file changed

+121
-176
lines changed

1 file changed

+121
-176
lines changed

src/intrinsic/mod.rs

Lines changed: 121 additions & 176 deletions
Original file line numberDiff line numberDiff line change
@@ -848,8 +848,8 @@ impl<'a, 'gcc, 'tcx> Builder<'a, 'gcc, 'tcx> {
848848
self.gcc_int_cast(result, result_type)
849849
}
850850

851-
fn count_leading_zeroes(&mut self, width: u64, arg: RValue<'gcc>) -> RValue<'gcc> {
852-
// if arg is 0, early return 0, else call count_leading_zeroes_nonzero to compute leading zeros
851+
fn count_zeroes(&mut self, width: u64, arg: RValue<'gcc>, count_leading: bool) -> RValue<'gcc> {
852+
// if arg is 0, early return 0, else call count_leading_zeroes_nonzero or count_trailing_zeroes_nonzero
853853
let func = self.current_func();
854854
let then_block = func.new_block("then");
855855
let else_block = func.new_block("else");
@@ -864,11 +864,15 @@ impl<'a, 'gcc, 'tcx> Builder<'a, 'gcc, 'tcx> {
864864
then_block.add_assignment(None, result, zero_result);
865865
then_block.end_with_jump(None, after_block);
866866

867-
// NOTE: since jumps were added in a place count_leading_zeroes_nonzero() does not expect,
867+
// NOTE: since jumps were added in a place count_xxxxing_zeroes_nonzero() does not expect,
868868
// the current block in the state need to be updated.
869869
self.switch_to_block(else_block);
870870

871-
let zeros = self.count_leading_zeroes_nonzero(width, arg);
871+
let zeros = if count_leading {
872+
self.count_leading_zeroes_nonzero(width, arg)
873+
} else {
874+
self.count_trailing_zeroes_nonzero(width, arg)
875+
};
872876
self.llbb().add_assignment(None, result, zeros);
873877
self.llbb().end_with_jump(None, after_block);
874878

@@ -879,7 +883,29 @@ impl<'a, 'gcc, 'tcx> Builder<'a, 'gcc, 'tcx> {
879883
result.to_rvalue()
880884
}
881885

882-
fn count_leading_zeroes_nonzero(&mut self, width: u64, arg: RValue<'gcc>) -> RValue<'gcc> {
886+
fn count_zeroes_nonzero(
887+
&mut self,
888+
width: u64,
889+
arg: RValue<'gcc>,
890+
count_leading: bool,
891+
) -> RValue<'gcc> {
892+
fn use_builtin_function<'a, 'gcc, 'tcx>(
893+
builder: &mut Builder<'a, 'gcc, 'tcx>,
894+
builtin: &str,
895+
arg: RValue<'gcc>,
896+
arg_type: Type<'gcc>,
897+
expected_type: Type<'gcc>,
898+
) -> RValue<'gcc> {
899+
let arg = if arg_type != expected_type {
900+
builder.context.new_cast(builder.location, arg, expected_type)
901+
} else {
902+
arg
903+
};
904+
let builtin = builder.context.get_builtin_function(builtin);
905+
let res = builder.context.new_call(builder.location, builtin, &[arg]);
906+
builder.context.new_cast(builder.location, res, builder.u32_type)
907+
}
908+
883909
// TODO(antoyo): use width?
884910
let result_type = self.u32_type;
885911
let mut arg_type = arg.get_type();
@@ -889,186 +915,105 @@ impl<'a, 'gcc, 'tcx> Builder<'a, 'gcc, 'tcx> {
889915
} else {
890916
arg
891917
};
892-
let count_leading_zeroes =
893-
// TODO(antoyo): write a new function Type::is_compatible_with(&Type) and use it here
894-
// instead of using is_uint().
895-
if arg_type.is_uchar(self.cx) || arg_type.is_ushort(self.cx) || arg_type.is_uint(self.cx) {
896-
"__builtin_clz"
897-
}
898-
else if arg_type.is_ulong(self.cx) {
899-
"__builtin_clzl"
900-
}
901-
else if arg_type.is_ulonglong(self.cx) {
902-
"__builtin_clzll"
903-
}
904-
else if width == 128 {
905-
// arg is guaranteed to not be 0, so either its 64 high or 64 low bits are not 0
906-
// __buildin_clzll is UB when called with 0, so call it on the 64 high bits if they are not 0,
907-
// else call it on the 64 low bits and add 64. In the else case, 64 low bits can't be 0
908-
// because arg is not 0.
909-
910-
let result = self.current_func()
911-
.new_local(None, result_type, "count_leading_zeroes_results");
912-
913-
let ctlz_then_block = self.current_func().new_block("ctlz_then");
914-
let ctlz_else_block = self.current_func().new_block("ctlz_else");
915-
let ctlz_after_block = self.current_func().new_block("ctlz_after")
916-
;
917-
let sixty_four = self.const_uint(arg_type, 64);
918-
let shift = self.lshr(arg, sixty_four);
919-
let high = self.gcc_int_cast(shift, self.u64_type);
920-
921-
let clzll = self.context.get_builtin_function("__builtin_clzll");
922-
923-
let zero_hi = self.const_uint(high.get_type(), 0);
924-
let cond = self.gcc_icmp(IntPredicate::IntNE, high, zero_hi);
925-
self.llbb().end_with_conditional(self.location, cond, ctlz_then_block, ctlz_else_block);
926-
self.switch_to_block(ctlz_then_block);
927-
928-
let result_128 =
929-
self.gcc_int_cast(self.context.new_call(None, clzll, &[high]), result_type);
930-
931-
ctlz_then_block.add_assignment(self.location, result, result_128);
932-
ctlz_then_block.end_with_jump(self.location, ctlz_after_block);
933-
934-
self.switch_to_block(ctlz_else_block);
935-
let low = self.gcc_int_cast(arg, self.u64_type);
936-
let low_leading_zeroes =
937-
self.gcc_int_cast(self.context.new_call(None, clzll, &[low]), result_type);
938-
let sixty_four_result_type = self.const_uint(result_type, 64);
939-
let result_128 = self.add(low_leading_zeroes, sixty_four_result_type);
940-
ctlz_else_block.add_assignment(self.location, result, result_128);
941-
ctlz_else_block.end_with_jump(self.location, ctlz_after_block);
942-
self.switch_to_block(ctlz_after_block);
943-
return result.to_rvalue();
944-
}
945-
else {
946-
let count_leading_zeroes = self.context.get_builtin_function("__builtin_clzll");
947-
let arg = self.context.new_cast(self.location, arg, self.ulonglong_type);
948-
let diff = self.ulonglong_type.get_size() as i64 - arg_type.get_size() as i64;
949-
let diff = self.context.new_rvalue_from_long(self.int_type, diff * 8);
950-
let res = self.context.new_call(self.location, count_leading_zeroes, &[arg]) - diff;
951-
return self.context.new_cast(self.location, res, result_type);
918+
// TODO(antoyo): write a new function Type::is_compatible_with(&Type) and use it here
919+
// instead of using is_uint().
920+
if arg_type.is_uchar(self.cx) || arg_type.is_ushort(self.cx) || arg_type.is_uint(self.cx) {
921+
let builtin = if count_leading { "__builtin_clz" } else { "__builtin_ctz" };
922+
use_builtin_function(self, builtin, arg, arg_type, self.cx.uint_type)
923+
} else if arg_type.is_ulong(self.cx) {
924+
let builtin = if count_leading { "__builtin_clzl" } else { "__builtin_ctzl" };
925+
use_builtin_function(self, builtin, arg, arg_type, self.cx.uint_type)
926+
} else if arg_type.is_ulonglong(self.cx) {
927+
let builtin = if count_leading { "__builtin_clzll" } else { "__builtin_ctzll" };
928+
use_builtin_function(self, builtin, arg, arg_type, self.cx.uint_type)
929+
} else if width == 128 {
930+
// arg is guaranteed to not be 0, so either its 64 high or 64 low bits are not 0
931+
// __buildin_clzll is UB when called with 0, so call it on the 64 high bits if they are not 0,
932+
// else call it on the 64 low bits and add 64. In the else case, 64 low bits can't be 0
933+
// because arg is not 0.
934+
// __buildin_ctzll is UB when called with 0, so call it on the 64 low bits if they are not 0,
935+
// else call it on the 64 high bits and add 64. In the else case, 64 high bits can't be 0
936+
// because arg is not 0.
937+
938+
let result = self.current_func().new_local(None, result_type, "count_zeroes_results");
939+
940+
let cz_then_block = self.current_func().new_block("cz_then");
941+
let cz_else_block = self.current_func().new_block("cz_else");
942+
let cz_after_block = self.current_func().new_block("cz_after");
943+
944+
let low = self.gcc_int_cast(arg, self.u64_type);
945+
let sixty_four = self.const_uint(arg_type, 64);
946+
let shift = self.lshr(arg, sixty_four);
947+
let high = self.gcc_int_cast(shift, self.u64_type);
948+
949+
let (first, second, builtin) = if count_leading {
950+
(low, high, self.context.get_builtin_function("__builtin_clzll"))
951+
} else {
952+
(high, low, self.context.get_builtin_function("__builtin_ctzll"))
952953
};
953-
let count_leading_zeroes = self.context.get_builtin_function(count_leading_zeroes);
954-
let res = self.context.new_call(self.location, count_leading_zeroes, &[arg]);
955-
self.context.new_cast(self.location, res, result_type)
956-
}
957954

958-
fn count_trailing_zeroes(&mut self, width: u64, arg: RValue<'gcc>) -> RValue<'gcc> {
959-
// if arg is 0, early return width, else call count_trailing_zeroes_nonzero to compute trailing zeros
960-
let func = self.current_func();
961-
let then_block = func.new_block("then");
962-
let else_block = func.new_block("else");
963-
let after_block = func.new_block("after");
964-
965-
let result = func.new_local(None, self.u32_type, "zeros");
966-
let zero = self.cx.gcc_zero(arg.get_type());
967-
let cond = self.gcc_icmp(IntPredicate::IntEQ, arg, zero);
968-
self.llbb().end_with_conditional(None, cond, then_block, else_block);
955+
let zero_64 = self.const_uint(self.u64_type, 0);
956+
let cond = self.gcc_icmp(IntPredicate::IntNE, second, zero_64);
957+
self.llbb().end_with_conditional(self.location, cond, cz_then_block, cz_else_block);
958+
self.switch_to_block(cz_then_block);
959+
960+
let result_128 =
961+
self.gcc_int_cast(self.context.new_call(None, builtin, &[second]), result_type);
962+
963+
cz_then_block.add_assignment(self.location, result, result_128);
964+
cz_then_block.end_with_jump(self.location, cz_after_block);
965+
966+
self.switch_to_block(cz_else_block);
967+
let count_more_zeroes =
968+
self.gcc_int_cast(self.context.new_call(None, builtin, &[first]), result_type);
969+
let sixty_four_result_type = self.const_uint(result_type, 64);
970+
let count_result_type = self.add(count_more_zeroes, sixty_four_result_type);
971+
cz_else_block.add_assignment(self.location, result, count_result_type);
972+
cz_else_block.end_with_jump(self.location, cz_after_block);
973+
self.switch_to_block(cz_after_block);
974+
result.to_rvalue()
975+
} else {
976+
let byte_diff = self.ulonglong_type.get_size() as i64 - arg_type.get_size() as i64;
977+
let diff = self.context.new_rvalue_from_long(self.int_type, byte_diff * 8);
978+
let ull_arg = self.context.new_cast(self.location, arg, self.ulonglong_type);
969979

970-
let zero_result = self.cx.gcc_uint(self.u32_type, width);
971-
then_block.add_assignment(None, result, zero_result);
972-
then_block.end_with_jump(None, after_block);
980+
let res = if count_leading {
981+
let count_leading_zeroes = self.context.get_builtin_function("__builtin_clzll");
982+
self.context.new_call(self.location, count_leading_zeroes, &[ull_arg]) - diff
983+
} else {
984+
let count_trailing_zeroes = self.context.get_builtin_function("__builtin_ctzll");
985+
let mask = self.context.new_rvalue_from_long(arg_type, -1); // To get the value with all bits set.
986+
let masked = mask
987+
& self.context.new_unary_op(
988+
self.location,
989+
UnaryOp::BitwiseNegate,
990+
arg_type,
991+
arg,
992+
);
993+
let cond =
994+
self.context.new_comparison(self.location, ComparisonOp::Equals, masked, mask);
995+
let diff = diff * self.context.new_cast(self.location, cond, self.int_type);
973996

974-
// NOTE: since jumps were added in a place count_trailing_zeroes_nonzero() does not expect,
975-
// the current block in the state need to be updated.
976-
self.switch_to_block(else_block);
997+
self.context.new_call(self.location, count_trailing_zeroes, &[ull_arg]) - diff
998+
};
999+
self.context.new_cast(self.location, res, result_type)
1000+
}
1001+
}
9771002

978-
let zeros = self.count_trailing_zeroes_nonzero(width, arg);
979-
self.llbb().add_assignment(None, result, zeros);
980-
self.llbb().end_with_jump(None, after_block);
1003+
fn count_leading_zeroes(&mut self, width: u64, arg: RValue<'gcc>) -> RValue<'gcc> {
1004+
self.count_zeroes(width, arg, true)
1005+
}
9811006

982-
// NOTE: since jumps were added in a place rustc does not
983-
// expect, the current block in the state need to be updated.
984-
self.switch_to_block(after_block);
1007+
fn count_leading_zeroes_nonzero(&mut self, width: u64, arg: RValue<'gcc>) -> RValue<'gcc> {
1008+
self.count_zeroes_nonzero(width, arg, true)
1009+
}
9851010

986-
result.to_rvalue()
1011+
fn count_trailing_zeroes(&mut self, width: u64, arg: RValue<'gcc>) -> RValue<'gcc> {
1012+
self.count_zeroes(width, arg, false)
9871013
}
9881014

989-
fn count_trailing_zeroes_nonzero(&mut self, _width: u64, arg: RValue<'gcc>) -> RValue<'gcc> {
990-
let result_type = self.u32_type;
991-
let mut arg_type = arg.get_type();
992-
let arg = if arg_type.is_signed(self.cx) {
993-
arg_type = arg_type.to_unsigned(self.cx);
994-
self.gcc_int_cast(arg, arg_type)
995-
} else {
996-
arg
997-
};
998-
let (count_trailing_zeroes, expected_type) =
999-
// TODO(antoyo): write a new function Type::is_compatible_with(&Type) and use it here
1000-
// instead of using is_uint().
1001-
if arg_type.is_uchar(self.cx) || arg_type.is_ushort(self.cx) || arg_type.is_uint(self.cx) {
1002-
// NOTE: we don't need to & 0xFF for uchar because the result is undefined on zero.
1003-
("__builtin_ctz", self.cx.uint_type)
1004-
}
1005-
else if arg_type.is_ulong(self.cx) {
1006-
("__builtin_ctzl", self.cx.ulong_type)
1007-
}
1008-
else if arg_type.is_ulonglong(self.cx) {
1009-
("__builtin_ctzll", self.cx.ulonglong_type)
1010-
}
1011-
else if arg_type.is_u128(self.cx) {
1012-
// arg is guaranteed to no be 0, so either its 64 high or 64 low bits are not 0
1013-
// __buildin_ctzll is UB when called with 0, so call it on the 64 low bits if they are not 0,
1014-
// else call it on the 64 high bits and add 64. In the else case, 64 high bits can't be 0
1015-
// because arg is not 0.
1016-
1017-
let result = self.current_func()
1018-
.new_local(None, result_type, "count_trailing_zeroes_results");
1019-
1020-
let ctlz_then_block = self.current_func().new_block("cttz_then");
1021-
let ctlz_else_block = self.current_func().new_block("cttz_else");
1022-
let ctlz_after_block = self.current_func().new_block("cttz_after");
1023-
let ctzll = self.context.get_builtin_function("__builtin_ctzll");
1024-
1025-
let low = self.gcc_int_cast(arg, self.u64_type);
1026-
let sixty_four = self.const_uint(arg_type, 64);
1027-
let shift = self.lshr(arg, sixty_four);
1028-
let high = self.gcc_int_cast(shift, self.u64_type);
1029-
let zero_low = self.const_uint(low.get_type(), 0);
1030-
let cond = self.gcc_icmp(IntPredicate::IntNE, low, zero_low);
1031-
self.llbb().end_with_conditional(self.location, cond, ctlz_then_block, ctlz_else_block);
1032-
self.switch_to_block(ctlz_then_block);
1033-
1034-
let result_128 =
1035-
self.gcc_int_cast(self.context.new_call(None, ctzll, &[low]), result_type);
1036-
1037-
ctlz_then_block.add_assignment(self.location, result, result_128);
1038-
ctlz_then_block.end_with_jump(self.location, ctlz_after_block);
1039-
1040-
self.switch_to_block(ctlz_else_block);
1041-
let high_trailing_zeroes =
1042-
self.gcc_int_cast(self.context.new_call(None, ctzll, &[high]), result_type);
1043-
1044-
let sixty_four_result_type = self.const_uint(result_type, 64);
1045-
let result_128 = self.add(high_trailing_zeroes, sixty_four_result_type);
1046-
ctlz_else_block.add_assignment(self.location, result, result_128);
1047-
ctlz_else_block.end_with_jump(self.location, ctlz_after_block);
1048-
self.switch_to_block(ctlz_after_block);
1049-
return result.to_rvalue();
1050-
}
1051-
else {
1052-
let count_trailing_zeroes = self.context.get_builtin_function("__builtin_ctzll");
1053-
let arg_size = arg_type.get_size();
1054-
let casted_arg = self.context.new_cast(self.location, arg, self.ulonglong_type);
1055-
let byte_diff = self.ulonglong_type.get_size() as i64 - arg_size as i64;
1056-
let diff = self.context.new_rvalue_from_long(self.int_type, byte_diff * 8);
1057-
let mask = self.context.new_rvalue_from_long(arg_type, -1); // To get the value with all bits set.
1058-
let masked = mask & self.context.new_unary_op(self.location, UnaryOp::BitwiseNegate, arg_type, arg);
1059-
let cond = self.context.new_comparison(self.location, ComparisonOp::Equals, masked, mask);
1060-
let diff = diff * self.context.new_cast(self.location, cond, self.int_type);
1061-
let res = self.context.new_call(self.location, count_trailing_zeroes, &[casted_arg]) - diff;
1062-
return self.context.new_cast(self.location, res, result_type);
1063-
};
1064-
let count_trailing_zeroes = self.context.get_builtin_function(count_trailing_zeroes);
1065-
let arg = if arg_type != expected_type {
1066-
self.context.new_cast(self.location, arg, expected_type)
1067-
} else {
1068-
arg
1069-
};
1070-
let res = self.context.new_call(self.location, count_trailing_zeroes, &[arg]);
1071-
self.context.new_cast(self.location, res, result_type)
1015+
fn count_trailing_zeroes_nonzero(&mut self, width: u64, arg: RValue<'gcc>) -> RValue<'gcc> {
1016+
self.count_zeroes_nonzero(width, arg, false)
10721017
}
10731018

10741019
fn pop_count(&mut self, value: RValue<'gcc>) -> RValue<'gcc> {

0 commit comments

Comments
 (0)