1
1
use super :: build:: BuilderExt ;
2
2
use crate :: llvm:: build:: Env ;
3
+ use crate :: llvm:: bitcode:: call_bitcode_fn;
3
4
use inkwell:: values:: { BasicValueEnum , IntValue } ;
4
5
use inkwell:: { IntPredicate , FloatPredicate } ;
5
- use roc_builtins:: bitcode:: { IntWidth , FloatWidth } ;
6
+ use roc_builtins:: bitcode:: { IntWidth , FloatWidth , NUM_LESS_THAN , NUM_GREATER_THAN } ;
6
7
use roc_mono:: layout:: {
7
8
Builtin , InLayout , LayoutIds , LayoutInterner , LayoutRepr , STLayoutInterner ,
8
9
} ;
@@ -27,7 +28,9 @@ pub fn generic_compare<'a, 'ctx>(
27
28
LayoutRepr :: Builtin ( Builtin :: Bool ) => {
28
29
bool_compare ( env, lhs_val, rhs_val)
29
30
}
30
- LayoutRepr :: Builtin ( Builtin :: Decimal ) => todo ! ( ) ,
31
+ LayoutRepr :: Builtin ( Builtin :: Decimal ) => {
32
+ dec_compare ( env, lhs_val, rhs_val)
33
+ }
31
34
LayoutRepr :: Builtin ( Builtin :: Str ) => todo ! ( ) ,
32
35
LayoutRepr :: Builtin ( Builtin :: List ( _) ) => todo ! ( ) ,
33
36
LayoutRepr :: Struct ( _) => todo ! ( ) ,
@@ -237,7 +240,7 @@ fn float_cmp<'ctx>(
237
240
let two = env. context . i8_type ( ) . const_int ( 2 , false ) ;
238
241
let three = env. context . i8_type ( ) . const_int ( 3 , false ) ;
239
242
240
- let lt_test = make_cmp ( FloatPredicate :: OLT , lhs_val, rhs_val, "rhs_lt_lhs " ) ;
243
+ let lt_test = make_cmp ( FloatPredicate :: OLT , lhs_val, rhs_val, "lhs_lt_rhs " ) ;
241
244
let gt_test = make_cmp ( FloatPredicate :: OGT , lhs_val, rhs_val, "lhs_gt_rhs" ) ;
242
245
let eq_test = make_cmp ( FloatPredicate :: OEQ , lhs_val, rhs_val, "lhs_eq_rhs" ) ;
243
246
let lhs_not_nan_test = make_cmp ( FloatPredicate :: OEQ , lhs_val, lhs_val, "lhs_not_NaN" ) ;
@@ -263,18 +266,45 @@ fn bool_compare<'ctx>(
263
266
rhs_val : BasicValueEnum < ' ctx > ,
264
267
) -> IntValue < ' ctx > {
265
268
269
+ // Cast the input bools to ints because int comparison of bools does the opposite of what one would expect.
270
+ // I could just swap the arguments, but I do not want to rely on behavior which seems wrong
271
+ let lhs_byte = env. builder . new_build_int_cast_sign_flag ( lhs_val. into_int_value ( ) , env. context . i8_type ( ) , false , "lhs_byte" ) ;
272
+ let rhs_byte = env. builder . new_build_int_cast_sign_flag ( rhs_val. into_int_value ( ) , env. context . i8_type ( ) , false , "rhs_byte" ) ;
273
+
266
274
// (a < b)
267
- let lhs_lt_rhs = env. builder . new_build_int_compare ( IntPredicate :: SLT , lhs_val . into_int_value ( ) , rhs_val . into_int_value ( ) , "lhs_lt_rhs_bool" ) ;
275
+ let lhs_lt_rhs = env. builder . new_build_int_compare ( IntPredicate :: SLT , lhs_byte , rhs_byte , "lhs_lt_rhs_bool" ) ;
268
276
let lhs_lt_rhs_byte = env. builder . new_build_int_cast_sign_flag ( lhs_lt_rhs, env. context . i8_type ( ) , false , "lhs_lt_rhs_byte" ) ;
269
277
270
278
// (a > b)
271
- let lhs_gt_rhs = env. builder . new_build_int_compare ( IntPredicate :: SGT , lhs_val . into_int_value ( ) , rhs_val . into_int_value ( ) , "lhs_gt_rhs_bool" ) ;
279
+ let lhs_gt_rhs = env. builder . new_build_int_compare ( IntPredicate :: SGT , lhs_byte , rhs_byte , "lhs_gt_rhs_bool" ) ;
272
280
let lhs_gt_rhs_byte = env. builder . new_build_int_cast_sign_flag ( lhs_gt_rhs, env. context . i8_type ( ) , false , "lhs_gt_rhs_byte" ) ;
273
281
274
- // (a > b) * 2
282
+ // (a < b) * 2
283
+ let two = env. context . i8_type ( ) . const_int ( 2 , false ) ;
284
+ let lhs_lt_rhs_times_two = env. builder . new_build_int_mul ( lhs_lt_rhs_byte, two, "lhs_lt_rhs_times_two" ) ;
285
+
286
+ // (a > b) + (a < b) * 2
287
+ env. builder . new_build_int_add ( lhs_gt_rhs_byte, lhs_lt_rhs_times_two, "bool_compare" )
288
+ }
289
+
290
+ fn dec_compare < ' ctx > (
291
+ env : & Env < ' _ , ' ctx , ' _ > ,
292
+ lhs_val : BasicValueEnum < ' ctx > ,
293
+ rhs_val : BasicValueEnum < ' ctx > ,
294
+ ) -> IntValue < ' ctx > {
295
+
296
+ // (a > b)
297
+ let lhs_gt_rhs = call_bitcode_fn ( env, & [ lhs_val, rhs_val] , & NUM_GREATER_THAN [ IntWidth :: I128 ] ) . into_int_value ( ) ;
298
+ let lhs_gt_rhs_byte = env. builder . new_build_int_cast_sign_flag ( lhs_gt_rhs, env. context . i8_type ( ) , false , "lhs_gt_rhs_byte" ) ;
299
+
300
+ // (a < b)
301
+ let lhs_lt_rhs = call_bitcode_fn ( env, & [ lhs_val, rhs_val] , & NUM_LESS_THAN [ IntWidth :: I128 ] ) . into_int_value ( ) ;
302
+ let lhs_lt_rhs_byte = env. builder . new_build_int_cast_sign_flag ( lhs_lt_rhs, env. context . i8_type ( ) , false , "lhs_lt_rhs_byte" ) ;
303
+
304
+ // (a < b) * 2
275
305
let two = env. context . i8_type ( ) . const_int ( 2 , false ) ;
276
- let lhs_gt_rhs_times_two = env. builder . new_build_int_mul ( lhs_gt_rhs_byte , two, "lhs_gt_rhs_times_two" ) ;
306
+ let lhs_lt_rhs_times_two = env. builder . new_build_int_mul ( lhs_lt_rhs_byte , two, "lhs_gt_rhs_times_two" ) ;
277
307
278
- // (a < b) + (a > b) * 2
279
- env. builder . new_build_int_add ( lhs_lt_rhs_byte , lhs_gt_rhs_times_two , "bool_compare" )
308
+ // (a > b) + (a < b) * 2
309
+ env. builder . new_build_int_add ( lhs_gt_rhs_byte , lhs_lt_rhs_times_two , "bool_compare" )
280
310
}
0 commit comments