Skip to content

Commit c64a5b8

Browse files
committed
Implemented sort compare for lists
1 parent 43104c3 commit c64a5b8

File tree

3 files changed

+432
-5
lines changed

3 files changed

+432
-5
lines changed

crates/compiler/gen_llvm/src/llvm/sort.rs

+253-5
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
use super::build::BuilderExt;
22
use crate::llvm::bitcode::call_bitcode_fn;
3-
use crate::llvm::build::Env;
4-
use inkwell::values::{BasicValueEnum, IntValue};
5-
use inkwell::{FloatPredicate, IntPredicate};
3+
use crate::llvm::build::{load_roc_value, Env, FAST_CALL_CONV};
4+
use crate::llvm::build_list::{list_len_usize, load_list_ptr};
5+
use crate::llvm::convert::{basic_type_from_layout, zig_list_type};
6+
use inkwell::types::BasicType;
7+
use inkwell::values::{BasicValueEnum, FunctionValue, IntValue, StructValue};
8+
use inkwell::{AddressSpace, FloatPredicate, IntPredicate};
69
use roc_builtins::bitcode::{FloatWidth, IntWidth, NUM_GREATER_THAN, NUM_LESS_THAN};
10+
use roc_module::symbol::Symbol;
711
use roc_mono::layout::{
812
Builtin, InLayout, LayoutIds, LayoutInterner, LayoutRepr, STLayoutInterner,
913
};
1014

1115
pub fn generic_compare<'a, 'ctx>(
1216
env: &Env<'a, 'ctx, '_>,
1317
layout_interner: &STLayoutInterner<'a>,
14-
_layout_ids: &mut LayoutIds<'a>,
18+
layout_ids: &mut LayoutIds<'a>,
1519
lhs_val: BasicValueEnum<'ctx>,
1620
rhs_val: BasicValueEnum<'ctx>,
1721
lhs_layout: InLayout<'a>,
@@ -28,7 +32,15 @@ pub fn generic_compare<'a, 'ctx>(
2832
LayoutRepr::Builtin(Builtin::Bool) => bool_compare(env, lhs_val, rhs_val),
2933
LayoutRepr::Builtin(Builtin::Decimal) => dec_compare(env, lhs_val, rhs_val),
3034
LayoutRepr::Builtin(Builtin::Str) => todo!(),
31-
LayoutRepr::Builtin(Builtin::List(_)) => todo!(),
35+
LayoutRepr::Builtin(Builtin::List(elem)) => list_compare(
36+
env,
37+
layout_interner,
38+
layout_ids,
39+
elem,
40+
layout_interner.get_repr(elem),
41+
lhs_val.into_struct_value(),
42+
rhs_val.into_struct_value(),
43+
),
3244
LayoutRepr::Struct(_) => todo!(),
3345
LayoutRepr::LambdaSet(_) => unreachable!("cannot compare closures"),
3446
LayoutRepr::FunctionPointer(_) => unreachable!("cannot compare function pointers"),
@@ -357,3 +369,239 @@ fn dec_compare<'ctx>(
357369
env.builder
358370
.new_build_int_add(lhs_gt_rhs_byte, lhs_lt_rhs_times_two, "bool_compare")
359371
}
372+
373+
fn list_compare<'a, 'ctx>(
374+
env: &Env<'a, 'ctx, '_>,
375+
layout_interner: &STLayoutInterner<'a>,
376+
layout_ids: &mut LayoutIds<'a>,
377+
elem_in_layout: InLayout<'a>,
378+
element_layout: LayoutRepr<'a>,
379+
list1: StructValue<'ctx>,
380+
list2: StructValue<'ctx>,
381+
) -> IntValue<'ctx> {
382+
let block = env.builder.get_insert_block().expect("to be in a function");
383+
let di_location = env.builder.get_current_debug_location().unwrap();
384+
385+
let symbol = Symbol::LIST_COMPARE;
386+
let element_layout = if let LayoutRepr::RecursivePointer(rec) = element_layout {
387+
layout_interner.get_repr(rec)
388+
} else {
389+
element_layout
390+
};
391+
let fn_name = layout_ids
392+
.get(symbol, &element_layout)
393+
.to_symbol_string(symbol, &env.interns);
394+
395+
let function = match env.module.get_function(fn_name.as_str()) {
396+
Some(function_value) => function_value,
397+
None => {
398+
let arg_type = zig_list_type(env).into();
399+
400+
let function_value = crate::llvm::refcounting::build_header_help(
401+
env,
402+
&fn_name,
403+
env.context.i8_type().into(),
404+
&[arg_type, arg_type],
405+
);
406+
407+
list_compare_help(
408+
env,
409+
layout_interner,
410+
layout_ids,
411+
function_value,
412+
elem_in_layout,
413+
element_layout,
414+
);
415+
416+
function_value
417+
}
418+
};
419+
420+
env.builder.position_at_end(block);
421+
env.builder.set_current_debug_location(di_location);
422+
let call = env
423+
.builder
424+
.new_build_call(function, &[list1.into(), list2.into()], "list_cmp");
425+
426+
call.set_call_convention(FAST_CALL_CONV);
427+
428+
call.try_as_basic_value().left().unwrap().into_int_value()
429+
}
430+
431+
fn list_compare_help<'a, 'ctx>(
432+
env: &Env<'a, 'ctx, '_>,
433+
layout_interner: &STLayoutInterner<'a>,
434+
layout_ids: &mut LayoutIds<'a>,
435+
parent: FunctionValue<'ctx>,
436+
elem_in_layout: InLayout<'a>,
437+
element_layout: LayoutRepr<'a>,
438+
) {
439+
let ctx = env.context;
440+
let builder = env.builder;
441+
442+
{
443+
use inkwell::debug_info::AsDIScope;
444+
445+
let func_scope = parent.get_subprogram().unwrap();
446+
let lexical_block = env.dibuilder.create_lexical_block(
447+
/* scope */ func_scope.as_debug_info_scope(),
448+
/* file */ env.compile_unit.get_file(),
449+
/* line_no */ 0,
450+
/* column_no */ 0,
451+
);
452+
453+
let loc = env.dibuilder.create_debug_location(
454+
ctx,
455+
/* line */ 0,
456+
/* column */ 0,
457+
/* current_scope */ lexical_block.as_debug_info_scope(),
458+
/* inlined_at */ None,
459+
);
460+
builder.set_current_debug_location(loc);
461+
}
462+
463+
// Add args to scope
464+
let mut it = parent.get_param_iter();
465+
let list1 = it.next().unwrap().into_struct_value();
466+
let list2 = it.next().unwrap().into_struct_value();
467+
468+
list1.set_name(Symbol::ARG_1.as_str(&env.interns));
469+
list2.set_name(Symbol::ARG_2.as_str(&env.interns));
470+
471+
let entry = ctx.append_basic_block(parent, "entry");
472+
let loop_bb = ctx.append_basic_block(parent, "loop_bb");
473+
let end_l1_bb = ctx.append_basic_block(parent, "end_l1_bb");
474+
let in_l1_bb = ctx.append_basic_block(parent, "in_l1_bb");
475+
let elem_compare_bb = ctx.append_basic_block(parent, "increment_bb");
476+
let not_eq_elems_bb = ctx.append_basic_block(parent, "not_eq_elems_bb");
477+
let increment_bb = ctx.append_basic_block(parent, "increment_bb");
478+
let return_eq = ctx.append_basic_block(parent, "return_eq");
479+
let return_gt = ctx.append_basic_block(parent, "return_gt");
480+
let return_lt = ctx.append_basic_block(parent, "return_lt");
481+
482+
builder.position_at_end(entry);
483+
let len1 = list_len_usize(builder, list1);
484+
let len2 = list_len_usize(builder, list2);
485+
486+
// allocate a stack slot for the current index
487+
let index_alloca = builder.new_build_alloca(env.ptr_int(), "index");
488+
builder.new_build_store(index_alloca, env.ptr_int().const_zero());
489+
490+
builder.new_build_unconditional_branch(loop_bb);
491+
492+
builder.position_at_end(loop_bb);
493+
494+
// load the current index
495+
let index = builder
496+
.new_build_load(env.ptr_int(), index_alloca, "index")
497+
.into_int_value();
498+
499+
// true if there are no more elements in list 1
500+
let end_l1_cond = builder.new_build_int_compare(IntPredicate::EQ, len1, index, "end_l1_cond");
501+
502+
builder.new_build_conditional_branch(end_l1_cond, end_l1_bb, in_l1_bb);
503+
504+
{
505+
builder.position_at_end(end_l1_bb);
506+
507+
// true if there are no more elements in list 2
508+
let eq_cond = builder.new_build_int_compare(IntPredicate::EQ, len2, index, "eq_cond");
509+
510+
// if both list have no more elements, eq
511+
// else, list 2 still has more elements, so lt
512+
builder.new_build_conditional_branch(eq_cond, return_eq, return_lt);
513+
}
514+
515+
{
516+
builder.position_at_end(in_l1_bb);
517+
518+
// list 2 has no more elements
519+
let gt_cond = builder.new_build_int_compare(IntPredicate::EQ, len2, index, "gt_cond");
520+
521+
// if list 2 has no more elements, since list 1 still has more, gt
522+
// else, compare the elements at the current index
523+
builder.new_build_conditional_branch(gt_cond, return_gt, elem_compare_bb);
524+
}
525+
526+
{
527+
builder.position_at_end(elem_compare_bb);
528+
529+
let element_type = basic_type_from_layout(env, layout_interner, element_layout);
530+
let ptr_type = element_type.ptr_type(AddressSpace::default());
531+
let ptr1 = load_list_ptr(builder, list1, ptr_type);
532+
let ptr2 = load_list_ptr(builder, list2, ptr_type);
533+
534+
let elem1 = {
535+
let elem_ptr = unsafe {
536+
builder.new_build_in_bounds_gep(element_type, ptr1, &[index], "load_index")
537+
};
538+
load_roc_value(env, layout_interner, element_layout, elem_ptr, "get_elem")
539+
};
540+
541+
let elem2 = {
542+
let elem_ptr = unsafe {
543+
builder.new_build_in_bounds_gep(element_type, ptr2, &[index], "load_index")
544+
};
545+
load_roc_value(env, layout_interner, element_layout, elem_ptr, "get_elem")
546+
};
547+
548+
let elem_cmp = generic_compare(
549+
env,
550+
layout_interner,
551+
layout_ids,
552+
elem1,
553+
elem2,
554+
elem_in_layout,
555+
elem_in_layout,
556+
)
557+
.into_int_value();
558+
559+
// true if elements are equal
560+
let increment_cond = builder.new_build_int_compare(
561+
IntPredicate::EQ,
562+
elem_cmp,
563+
ctx.i8_type().const_int(0, false),
564+
"increment_cond",
565+
);
566+
567+
// if elements are equal, increment the pointers
568+
// else, return gt or lt
569+
builder.new_build_conditional_branch(increment_cond, increment_bb, not_eq_elems_bb);
570+
571+
{
572+
builder.position_at_end(not_eq_elems_bb);
573+
574+
// When elements compare not equal, we return the element comparison
575+
builder.new_build_return(Some(&elem_cmp));
576+
}
577+
}
578+
579+
{
580+
builder.position_at_end(increment_bb);
581+
582+
let one = env.ptr_int().const_int(1, false);
583+
584+
// increment the index
585+
let next_index = builder.new_build_int_add(index, one, "nextindex");
586+
587+
builder.new_build_store(index_alloca, next_index);
588+
589+
// jump back to the top of the loop
590+
builder.new_build_unconditional_branch(loop_bb);
591+
}
592+
593+
{
594+
builder.position_at_end(return_eq);
595+
builder.new_build_return(Some(&ctx.i8_type().const_int(0, false)));
596+
}
597+
598+
{
599+
builder.position_at_end(return_gt);
600+
builder.new_build_return(Some(&ctx.i8_type().const_int(1, false)));
601+
}
602+
603+
{
604+
builder.position_at_end(return_lt);
605+
builder.new_build_return(Some(&ctx.i8_type().const_int(2, false)));
606+
}
607+
}

crates/compiler/solve/src/ability.rs

+12
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,18 @@ impl DerivableVisitor for DeriveSort {
872872
true
873873
}
874874

875+
#[inline(always)]
876+
fn visit_apply(var: Variable, symbol: Symbol) -> Result<Descend, NotDerivable> {
877+
if matches!(symbol, Symbol::LIST_LIST,) {
878+
Ok(Descend(true))
879+
} else {
880+
Err(NotDerivable {
881+
var,
882+
context: NotDerivableContext::NoContext,
883+
})
884+
}
885+
}
886+
875887
#[inline(always)]
876888
fn visit_floating_point_content(
877889
_var: Variable,

0 commit comments

Comments
 (0)