diff --git a/src/main/scala/Search/BinarySearch.scala b/src/main/scala/Search/BinarySearch.scala index 16458c2..47e3a40 100644 --- a/src/main/scala/Search/BinarySearch.scala +++ b/src/main/scala/Search/BinarySearch.scala @@ -15,7 +15,7 @@ object BinarySearch { */ def binarySearch(arr: List[Int], elem: Int): Int = { - binarySearch(arr,elem,0,arr.length) + binarySearch(arr, elem, 0, arr.length) } /** @@ -23,15 +23,25 @@ object BinarySearch { * @param elem - a integer to search for in the @args * @param fromIndex - the index of the first element (inclusive) to be searched * @param toIndex - toIndex the index of the last element (exclusive) to be searched - * @return - index of the @elem otherwise -1 + * @param returnInsertIdx - if `true`, returns info about the insertion index + * @return - index of the @elem otherwise -1. If `returnInsertIdx` is true, returns + * (-insertion_idx - 1) where insertion_idx is the 1st index where `elem` can be + * inserted into `arr` and `arr` is still sorted. */ - def binarySearch(arr: List[Int], elem: Int, fromIndex: Int, toIndex: Int): Int = { + def binarySearch( + arr: List[Int], + elem: Int, + fromIndex: Int, + toIndex: Int, + returnInsertIdx: Boolean = false + ): Int = { @tailrec def SearchImpl(lo: Int, hi: Int): Int = { - if (lo > hi) - -1 + if (lo > hi) { + if(returnInsertIdx) -lo - 1 else -1 + } else { val mid: Int = lo + (hi - lo) / 2 arr(mid) match { diff --git a/src/test/scala/Search/BinarySearchSpec.scala b/src/test/scala/Search/BinarySearchSpec.scala index 2c86ab8..c3cb604 100644 --- a/src/test/scala/Search/BinarySearchSpec.scala +++ b/src/test/scala/Search/BinarySearchSpec.scala @@ -32,4 +32,17 @@ class BinarySearchSpec extends FlatSpec { assert(BinarySearch.binarySearch(l,7,0,4) === -1) assert(BinarySearch.binarySearch(l,7,1,3) === -1) } + + it should "return insertion index if the element is not found" in { + def search(l: List[Int], elem: Int) = { + val rs = BinarySearch.binarySearch(l, elem, 0, l.length, returnInsertIdx = true) + -rs - 1 + } + + val l = List(-5, 10, 15) + assert(search(l, 0) === 1) + assert(search(l, 1) === 1) + assert(search(l, 12) === 2) + assert(search(l, 22) === 3) + } }