diff --git a/src/saturn.ml b/src/saturn.ml index a71a9903..caa0a91e 100644 --- a/src/saturn.ml +++ b/src/saturn.ml @@ -32,4 +32,5 @@ module Work_stealing_deque = Lockfree.Work_stealing_deque module Single_prod_single_cons_queue = Lockfree.Single_prod_single_cons_queue module Single_consumer_queue = Lockfree.Single_consumer_queue module Relaxed_queue = Mpmc_relaxed_queue +module Priority_queue = Lockfree.Priority_queue module Backoff = Lockfree.Backoff diff --git a/src/saturn.mli b/src/saturn.mli index 7a5aa1fa..417e976e 100644 --- a/src/saturn.mli +++ b/src/saturn.mli @@ -36,6 +36,7 @@ module Work_stealing_deque = Lockfree.Work_stealing_deque module Single_prod_single_cons_queue = Lockfree.Single_prod_single_cons_queue module Single_consumer_queue = Lockfree.Single_consumer_queue module Relaxed_queue = Mpmc_relaxed_queue +module Priority_queue = Lockfree.Priority_queue module Backoff = Lockfree.Backoff (** {2 Other} *) diff --git a/src_lockfree/lockfree.ml b/src_lockfree/lockfree.ml index ca2e063a..1d20a4ca 100644 --- a/src_lockfree/lockfree.ml +++ b/src_lockfree/lockfree.ml @@ -32,4 +32,5 @@ module Work_stealing_deque = Ws_deque module Single_prod_single_cons_queue = Spsc_queue module Single_consumer_queue = Mpsc_queue module Relaxed_queue = Mpmc_relaxed_queue +module Priority_queue = Priority_queue module Backoff = Backoff diff --git a/src_lockfree/lockfree.mli b/src_lockfree/lockfree.mli index d70939f5..cbf9fe59 100644 --- a/src_lockfree/lockfree.mli +++ b/src_lockfree/lockfree.mli @@ -36,6 +36,7 @@ module Work_stealing_deque = Ws_deque module Single_prod_single_cons_queue = Spsc_queue module Single_consumer_queue = Mpsc_queue module Relaxed_queue = Mpmc_relaxed_queue +module Priority_queue = Priority_queue (** {2 Other} *) diff --git a/src_lockfree/priority_queue.ml b/src_lockfree/priority_queue.ml new file mode 100644 index 00000000..b7f9b478 --- /dev/null +++ b/src_lockfree/priority_queue.ml @@ -0,0 +1,269 @@ +type markable_reference = { node : node; marked : bool } +(** markable reference: stores a reference to a node and has a field to specify if the original node is marked *) + +and node = { + key : int; + height : int; + logical_mark : bool Atomic.t; + next : markable_reference Atomic.t array; +} + +exception Failed_snip + +type t = { head : node; max_height : int } + +let null_node = + { + key = Int.max_int; + height = 0; + logical_mark = Atomic.make true; + next = [||]; + } + +(** create_new_node: creates a new node with some value and height *) +let create_new_node value height = + let next = + Array.init (height + 1) (fun _ -> + Atomic.make { node = null_node; marked = false }) + in + { key = value; height; logical_mark = Atomic.make true; next } + +(** create_dummy_node_array: Creates a new array with the different node for each index *) +let create_dummy_node_array sl = + let arr = Array.make (sl.max_height + 1) null_node in + arr + +(** Get a random level from 0 till max_height (both included), the node will be assigned this height *) +let get_random_level sl = + let rec count_level cur_level = + if cur_level == sl.max_height || Random.bool () then cur_level + else count_level (cur_level + 1) + in + count_level 0 + +(** Create a new skiplist *) +let create ?(max_height = 10) () = + let tail = create_new_node Int.max_int max_height in + let next = + Array.init (max_height + 1) (fun _ -> + Atomic.make { node = tail; marked = false }) + in + let head = + { + key = Int.min_int; + height = max_height; + logical_mark = Atomic.make false; + next; + } + in + { head; max_height } + +(** Compares old_node and old_mark with the atomic reference and if they are the same then + Replaces the value in the atomic with node and mark *) +let compare_and_set_mark_ref (atomic, old_node, old_mark, node, mark) = + let current = Atomic.get atomic in + let set_mark_ref () = + Atomic.compare_and_set atomic current { node; marked = mark } + in + let current_node = current.node in + current_node == old_node && current.marked = old_mark + && ((current_node == node && current.marked = mark) || set_mark_ref ()) + +(** Returns true if key is found within the skiplist else false; + Irrespective of return value, fills the preds and succs array with + the predecessors nodes with smaller key and successors nodes with greater than + or equal to key + *) +let find_in (key, preds, succs, sl, is_del) = + let head = sl.head in + let rec iterate (prev, curr, succ, mark, level) = + if mark then + (* need to delete curr if marked, so update prev next ptr to succ *) + let snip = + compare_and_set_mark_ref (prev.next.(level), curr, false, succ, false) + in + if not snip then raise Failed_snip + else + let { node = curr; marked = _ } = Atomic.get prev.next.(level) in + let { node = succ; marked = mark } = Atomic.get curr.next.(level) in + iterate (prev, curr, succ, mark, level) + else if (not is_del) && curr.key <= key then + (* keep traversing to get key greater than or equal *) + let { node = new_succ; marked = mark } = Atomic.get succ.next.(level) in + iterate (curr, succ, new_succ, mark, level) + else if is_del && curr.key < key then + (* keep traversing to get key greater than or equal *) + let { node = new_succ; marked = mark } = Atomic.get succ.next.(level) in + iterate (curr, succ, new_succ, mark, level) + else (prev, curr) + in + (* find pred and succ at that level *) + let rec update_arrays prev level = + let { node = curr; marked = _ } = Atomic.get prev.next.(level) in + let { node = succ; marked = mark } = Atomic.get curr.next.(level) in + try + let prev, curr = iterate (prev, curr, succ, mark, level) in + (* prev <= key < curr *) + preds.(level) <- prev; + succs.(level) <- curr; + if level > 0 then update_arrays prev (level - 1) else curr.key == key + with Failed_snip -> update_arrays head sl.max_height + in + update_arrays head sl.max_height + +(** Adds a new key to the skiplist sl. *) +let push sl key = + let top_level = get_random_level sl in + let preds = create_dummy_node_array sl in + let succs = create_dummy_node_array sl in + let rec repeat () = + (* check if key already exists and fill preds and succs *) + find_in (key, preds, succs, sl, false) |> ignore; + let new_node_next = + (* build next array based on succs *) + Array.map + (fun element -> + let mark_ref = { node = element; marked = false } in + Atomic.make mark_ref) + succs + in + let new_node = + { + key; + height = top_level; + logical_mark = Atomic.make false; + next = new_node_next; + } + in + let pred = preds.(0) in + let succ = succs.(0) in + (* insert at level 0 *) + if + not + (compare_and_set_mark_ref (pred.next.(0), succ, false, new_node, false)) + then repeat () + else + let rec update_levels level = + let rec set_next () = + let pred = preds.(level) in + let succ = succs.(level) in + if + compare_and_set_mark_ref + (pred.next.(level), succ, false, new_node, false) + then () + else ( + find_in (key, preds, succs, sl, false) |> ignore; + set_next ()) + in + set_next (); + if level < top_level then update_levels (level + 1) + in + if top_level > 0 then update_levels 1; + (* start updating from level 1 and then move upwards *) + () + in + repeat () + +(** Returns true if the key is within the skiplist, else returns false *) +let contains sl key = + let rec search (pred, curr, succ, mark, level) = + if mark then + (* to be deleted *) + let curr = succ in + let { node = succ; marked = mark } = Atomic.get curr.next.(level) in + search (pred, curr, succ, mark, level) + else if curr.key < key then + (* keep iterating to find correct position *) + let pred = curr in + let curr = succ in + let { node = succ; marked = mark } = Atomic.get curr.next.(level) in + search (pred, curr, succ, mark, level) + else if level > 0 then + (* found correct position, find exact level *) + let level = level - 1 in + let { node = curr; marked = _ } = Atomic.get pred.next.(level) in + let { node = succ; marked = mark } = Atomic.get curr.next.(level) in + search (pred, curr, succ, mark, level) + else + curr.key == key (* at the most accurate position, check if key exists *) + in + let pred = sl.head in + let { node = curr; marked = _ } = Atomic.get pred.next.(sl.max_height) in + let { node = succ; marked = mark } = Atomic.get curr.next.(sl.max_height) in + search (pred, curr, succ, mark, sl.max_height) + +(* find the minimum node on the bottom level and mark it as deleted, + important to refetch successor node because something could have changed in between *) +let find_mark_min sl = + let rec find_unmarked curr = + let { node = not_tail; marked = _ } = Atomic.get curr.next.(0) in + if not_tail != null_node then + if + (not (Atomic.get curr.logical_mark)) + && Atomic.compare_and_set curr.logical_mark false true + then curr + else + let { node = succ; marked = _ } = Atomic.get curr.next.(0) in + find_unmarked succ + else null_node + in + let { node = curr; marked = _ } = Atomic.get sl.head.next.(0) in + find_unmarked curr + +(** Removes given key from skiplist, unlinking the next pointers *) +let remove sl key = + let preds = create_dummy_node_array sl in + let succs = create_dummy_node_array sl in + let rec repeat () = + find_in (key, preds, succs, sl, true) |> ignore; + let nodeToRemove = succs.(0) in + (* expected node to remove based on given key *) + let nodeHeight = nodeToRemove.height in + let rec mark_levels succ level = + (* set node to marked *) + let _ = + compare_and_set_mark_ref + (nodeToRemove.next.(level), succ, false, succ, true) + in + let { node = succ; marked = mark } = + Atomic.get nodeToRemove.next.(level) + in + if not mark then + mark_levels succ level (* some update happened to next so retry *) + in + let rec update_upper_levels level = + (* from node height to 1 *) + let { node = succ; marked = mark } = + Atomic.get nodeToRemove.next.(level) + in + if not mark then mark_levels succ level; + if level > 1 then update_upper_levels (level - 1) + in + let rec update_bottom_level succ = + (* for bottom level only *) + let iMarkedIt = + compare_and_set_mark_ref (nodeToRemove.next.(0), succ, false, succ, true) + in + let { node = succ; marked = mark } = Atomic.get succs.(0).next.(0) in + if iMarkedIt then ( + (* update next links to remove marked node in all levels *) + find_in (key, preds, succs, sl, true) |> ignore;) + else if mark then repeat () (* some other thread deleted same key *) + else + update_bottom_level + succ (* retry because some update happened in between *) + in + if nodeHeight > 0 then update_upper_levels nodeHeight; + let { node = succ; marked = _ } = Atomic.get nodeToRemove.next.(0) in + update_bottom_level succ + in + repeat () + +(** remove smallest key from priority queue, first mark logically and then physical removal *) +let pop sl = + let num = find_mark_min sl in + if num != null_node then ( + remove sl num.key; + num.key) + else null_node.key + \ No newline at end of file diff --git a/src_lockfree/priority_queue.mli b/src_lockfree/priority_queue.mli new file mode 100644 index 00000000..36420b62 --- /dev/null +++ b/src_lockfree/priority_queue.mli @@ -0,0 +1,20 @@ +(** + Lockfree Priority Queue implementation based on skiplist. The references include + chapter 14 & 15 in art of multiprocessor programming as well as the following { + {:http://people.csail.mit.edu/shanir/publications/Priority_Queues.pdf} research paper} +*) + +type t +(** the type of priority queue *) + +val create : ?max_height:int -> unit -> t +(** create new pq with given height *) + +val push : t -> int -> unit +(** [push pq ele] adds [ele] to it's sorted position in [pq] *) + +val pop : t -> int +(** [pop pq] removes smallest elements from [pq] *) + +val contains : t -> int -> bool +(** [contains pq ele] checks if [ele] exists in [pq] *) \ No newline at end of file diff --git a/test/priority_queue/dscheck_priority_queue.ml b/test/priority_queue/dscheck_priority_queue.ml new file mode 100644 index 00000000..2d9dbf4f --- /dev/null +++ b/test/priority_queue/dscheck_priority_queue.ml @@ -0,0 +1,113 @@ +open Priority_queue + +let _two_mem () = + Atomic.trace (fun () -> + Random.init 0; + let sl = create ~max_height:2 () in + let found1 = ref false in + let found2 = ref false in + + Atomic.spawn (fun () -> + push sl 1; + found1 := contains sl 1); + + Atomic.spawn (fun () -> found2 := contains sl 2); + + Atomic.final (fun () -> Atomic.check (fun () -> !found1 && not !found2))) + +let _two_mem_same () = + Atomic.trace (fun () -> + Random.init 0; + let sl = create ~max_height:2 () in + let found1 = ref false in + let found2 = ref false in + + Atomic.spawn (fun () -> + push sl 1; + found1 := contains sl 1); + + Atomic.spawn (fun () -> + push sl 1; + found2 := contains sl 1); + + Atomic.final (fun () -> Atomic.check (fun () -> !found1 && !found2))) + +let _extra_remove () = + Atomic.trace (fun () -> + Random.init 0; + let sl = create ~max_height:2 () in + let removed1 = ref false in + let removed2 = ref false in + + Atomic.spawn (fun () -> + push sl 1; + removed1 := pop sl <> Int.max_int); + Atomic.spawn (fun () -> removed2 := pop sl <> Int.max_int); + + Atomic.final (fun () -> + Atomic.check (fun () -> + ((!removed1 && not !removed2) || ((not !removed1) && !removed2)) + && not (contains sl 1)))) + +let _two_remove () = + Atomic.trace (fun () -> + Random.init 0; + let sl = create ~max_height:1 () in + let removed1 = ref false in + let removed2 = ref false in + + Atomic.spawn (fun () -> + push sl 1; + removed1 := pop sl <> Int.max_int); + Atomic.spawn (fun () -> + push sl 2; + removed2 := pop sl <> Int.max_int); + + Atomic.final (fun () -> Atomic.check (fun () -> !removed1 && !removed2))) + +let _remove_mem () = + Atomic.trace (fun () -> + Random.init 0; + let sl = create ~max_height:1 () in + let removed1 = ref false in + + Atomic.spawn (fun () -> + push sl 1; + removed1 := pop sl <> Int.max_int); + Atomic.spawn (fun () -> push sl 1); + + Atomic.final (fun () -> + Atomic.check (fun () -> !removed1 && contains sl 1))) + +let _two_remove_same () = + Atomic.trace (fun () -> + Random.init 0; + let sl = create ~max_height:1 () in + let removed1 = ref false in + let removed2 = ref false in + + Atomic.spawn (fun () -> + push sl 1; + removed1 := pop sl = 1); + Atomic.spawn (fun () -> + push sl 1; + removed2 := pop sl = 1); + + Atomic.final (fun () -> + Atomic.check (fun () -> + (!removed1 && !removed2) && not (contains sl 1)))) + +let () = + let open Alcotest in + run "lockfree_pq_dscheck" + [ + ( "basic", + [ + test_case "2-mem" `Slow _two_mem; + test_case "2-mem-same" `Slow _two_mem_same; + test_case "extra-remove" `Slow _extra_remove; + test_case "2-remove" `Slow _two_remove; + test_case "remove-mem" `Slow _remove_mem; + test_case "2-remove-same" `Slow _two_remove_same; + ] ); + ] diff --git a/test/priority_queue/dune b/test/priority_queue/dune new file mode 100644 index 00000000..7a24c902 --- /dev/null +++ b/test/priority_queue/dune @@ -0,0 +1,12 @@ +(rule + (copy ../../src_lockfree/priority_queue.ml priority_queue.ml)) + +(test + (name qcheck_priority_queue) + (libraries saturn qcheck qcheck-alcotest) + (modules qcheck_priority_queue)) + +(test + (name dscheck_priority_queue) + (libraries atomic dscheck alcotest) + (modules priority_queue dscheck_priority_queue)) diff --git a/test/priority_queue/qcheck_priority_queue.ml b/test/priority_queue/qcheck_priority_queue.ml new file mode 100644 index 00000000..62b35b27 --- /dev/null +++ b/test/priority_queue/qcheck_priority_queue.ml @@ -0,0 +1,190 @@ +open Saturn + +let tests_sequential = + QCheck. + [ + (* TEST 1 - push, pop check order *) + Test.make ~count:1000 ~name:"push_pop_check_order" small_nat (fun len -> + assume (len <> 0); + (* Building a random queue *) + let lpush = List.init len (fun i -> i) in + let queue = Priority_queue.create ~max_height:20 () in + List.iter (fun ele -> ignore @@ Priority_queue.push queue ele) lpush; + + (* Popping until [is_empty q] is true *) + let out = ref [] in + let insert v = out := v :: !out in + let count = ref 0 in + while !count < len do + incr count; + let num = Priority_queue.pop queue in + insert num + done; + + (* Testing property *) + !out = List.rev lpush); + (* TEST 2 - push, pop check order random *) + Test.make ~count:1000 ~name:"push_pop_check_order_random" small_nat + (fun len -> + assume (len <> 0); + (* Building a random queue *) + Random.self_init (); + let queue = Priority_queue.create ~max_height:20 () in + let lpush = ref [] in + for _ = 1 to len do + let x = Random.int 100_000 in + Priority_queue.push queue x; + lpush := x :: !lpush + done; + (* Popping until [is_empty q] is true *) + let out = ref [] in + let insert v = out := v :: !out in + let count = ref 0 in + while !count < List.length !lpush do + incr count; + let num = Priority_queue.pop queue in + insert num + done; + + (* Testing property *) + List.rev !out = List.sort compare !lpush); + ] + +let tests_domains = + QCheck. + [ + (* TEST 1 - 2 producers, check pop order *) + Test.make ~count:100 ~name:"double_add_remove" small_nat (fun nlen -> + assume (nlen <> 0); + Random.self_init (); + let rlen = Random.int 1000 in + let len = nlen + rlen in + (* Creating a queue *) + let lpush1 = List.init len (fun _ -> Random.int 100) in + let lpush2 = List.init len (fun _ -> Random.int 100) in + let queue = Priority_queue.create ~max_height:20 () in + + let producer1 = + Domain.spawn (fun () -> + List.iter (fun ele -> Priority_queue.push queue ele) lpush1) + in + let producer2 = + Domain.spawn (fun () -> + List.iter (fun ele -> Priority_queue.push queue ele) lpush2) + in + Domain.join producer1; + Domain.join producer2; + + let out = ref [] in + let insert v = out := v :: !out in + let count = ref 0 in + while !count < 2 * len do + incr count; + let num = Priority_queue.pop queue in + insert num + done; + (* Testing property *) + List.rev !out = List.sort compare (lpush1 @ lpush2)); + (* TEST 2 - 2 consumers, check order *) + Test.make ~count:100 ~name:"add_double_remove" small_nat (fun nlen -> + assume (nlen <> 0); + Random.self_init (); + let len = nlen + 1000 in + let plen = 2 * len in + (* Creating a queue *) + let lpush = List.init plen (fun _ -> Random.int 100) in + let queue = Priority_queue.create ~max_height:20 () in + List.iter (fun ele -> Priority_queue.push queue ele) lpush; + let c1 = ref 0 in + let c2 = ref 0 in + let consumer1 = + Domain.spawn (fun () -> + while !c1 < len do + let num = Priority_queue.pop queue in + if num <> Int.max_int then incr c1 + done) + in + let consumer2 = + Domain.spawn (fun () -> + while !c2 < len do + let num = Priority_queue.pop queue in + if num <> Int.max_int then incr c2 + done) + in + Domain.join consumer1; + Domain.join consumer2; + !c1 + !c2 = plen); + (* TEST 3 - Same domain add remove *) + Test.make ~count:100 ~name:"parallel_add_remove" small_nat (fun slen -> + assume (slen <> 0); + Random.self_init (); + (* Creating a queue *) + let queue = Priority_queue.create ~max_height:20 () in + let len = slen + 5000 in + let c1 = ref 0 in + let c2 = ref 0 in + let c3 = ref 0 in + let c4 = ref 0 in + for _ = 1 to len do + let ele = Random.int 500 in + Priority_queue.push queue ele |> ignore + done; + let d1 = + Domain.spawn (fun () -> + for _ = 1 to len do + Priority_queue.push queue (Random.int 500); + incr c1; + let num = Priority_queue.pop queue in + if num <> Int.max_int then decr c1 + done) + in + let d2 = + Domain.spawn (fun () -> + for _ = 1 to len do + Priority_queue.push queue (Random.int 500); + incr c2; + let num = Priority_queue.pop queue in + if num <> Int.max_int then decr c2 + done) + in + let d3 = + Domain.spawn (fun () -> + for _ = len downto 1 do + Priority_queue.push queue (Random.int 500); + incr c3; + let num = Priority_queue.pop queue in + if num <> Int.max_int then decr c3 + done) + in + let d4 = + Domain.spawn (fun () -> + for _ = len downto 1 do + Priority_queue.push queue (Random.int 500); + incr c4; + let num = Priority_queue.pop queue in + if num <> Int.max_int then decr c4 + done) + in + Domain.join d1; + Domain.join d2; + Domain.join d3; + Domain.join d4; + let c5 = ref len in + for _ = 1 to len do + let num = Priority_queue.pop queue in + if num <> Int.max_int then decr c5 else Format.printf "%d\n%!" 42 + done; + !c1 = 0 && !c2 = 0 && !c3 = 0 && !c4 = 0 && !c5 = 0); + ] + +let main () = + (* QCheck_base_runner.set_seed 124752466; *) + let to_alcotest = List.map QCheck_alcotest.to_alcotest in + Alcotest.run "Priority_queue" + [ + ("test_sequential", to_alcotest tests_sequential); + ("test_domains", to_alcotest tests_domains); + ] +;; + +main ()