diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..9ac3804 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.11.5 diff --git a/examples/notebook/safe_string_table_lookup.ipynb b/examples/notebook/safe_string_table_lookup.ipynb new file mode 100644 index 0000000..254b8b9 --- /dev/null +++ b/examples/notebook/safe_string_table_lookup.ipynb @@ -0,0 +1,607 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Safe String Lookup\n", + "\n", + "Current heavylight.Table functionality (1.0.5) doesn't carry out validation of table keys before use, i.e. it uses the Band functionality.\n", + "\n", + "This works fine as long as the lookup values are all in the keys (and the keys are sorted in order), however it fails where keys don't exist.\n", + "\n", + "This notebook explores adding a np.isin() test before doing the lookup.\n", + "\n", + "This is likely to be relatively expensive, as the key check only really needs to happen at the start (string keys are from data)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import numba as nb" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "table_keys = np.array(['A', 'B', 'C', 'D'], dtype=' np.ndarray:\n", + " ret_arr = -9999 * np.ones(keys.shape[0], dtype=np.int32) # start with sentinel\n", + " for i in nb.prange(keys.shape[0]):\n", + " for j in nb.prange(key_array.shape[0]):\n", + " if keys[i] == key_array[j]:\n", + " ret_arr[i] = j\n", + " # break ## break is commented out as using parallelisation\n", + " return ret_arr" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 1, 2, 0, -9999, 3])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "find_exact(xs, table_keys)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "def find_exact_py(keys:np.ndarray, key_array:np.ndarray) -> np.ndarray:\n", + " ret_arr = -9999 * np.ones(keys.shape[0], dtype=np.int32) # start with sentinel\n", + " for i in range(keys.shape[0]):\n", + " for j in range(key_array.shape[0]):\n", + " if keys[i] == key_array[j]:\n", + " ret_arr[i] = j\n", + " break\n", + " return ret_arr" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 1, 2, 0, -9999, 3], dtype=int32)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "find_exact_py(xs, table_keys)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "8.19 ms ± 24.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + ] + } + ], + "source": [ + "%timeit find_exact(big_xs, table_keys)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "82 ms ± 141 µs per loop (mean ± std. dev. 
of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%timeit find_exact_py(big_xs, table_keys)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "def find_exact_ss(keys:np.ndarray, key_array:np.ndarray) -> np.ndarray:\n", + " return np.where(np.isin(keys, key_array), \n", + " np.searchsorted(key_array, keys),\n", + " -9999)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 1, 2, 0, -9999, 3])" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "find_exact_ss(xs, table_keys)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2.65 ms ± 33.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + ] + } + ], + "source": [ + "%timeit find_exact_ss(big_xs, table_keys)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "def find_exact_unsafe(keys: np.ndarray, key_array:np.ndarray) -> np.ndarray:\n", + " return np.searchsorted(key_array, keys)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.48 ms ± 8.44 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" + ] + } + ], + "source": [ + "%timeit find_exact_unsafe(big_xs, table_keys)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 1, -9999, -9999, ..., 0, 1, 3])" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "find_exact(big_xs, table_keys)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "UnicodeCharSeq(2)" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nb.types.UnicodeCharSeq(2)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "# nb.types.Array(nb.types.UnicodeCharSeq(2))" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "find_exact (Array(UnicodeCharSeq(2), 1, 'C', False, aligned=True), Array(UnicodeCharSeq(2), 1, 'C', False, aligned=True))\n", + "--------------------------------------------------------------------------------\n", + "# File: /var/folders/n4/gpw_j7653_l052l8phfj8lbr0000gn/T/ipykernel_48347/3325116556.py\n", + "# --- LINE 1 --- \n", + "# label 0\n", + "# keys = arg(0, name=keys) :: array([unichr x 2], 1d, C)\n", + "# keys_shape.0 = getattr(value=keys, attr=shape) :: UniTuple(int64 x 1)\n", + "# keys_size0.1 = static_getitem(value=keys_shape.0, index=0, index_var=None, fn=) :: int64\n", + "# del keys_shape.0\n", + "# key_array = arg(1, name=key_array) :: array([unichr x 2], 1d, C)\n", + "\n", + "@nb.njit(nogil=True, parallel=True)\n", + "\n", + "# --- LINE 2 --- \n", + "\n", + "def find_exact(keys:np.ndarray, key_array:np.ndarray) -> np.ndarray:\n", + "\n", + " # --- LINE 3 --- \n", + " # id=0[LoopNest(index_variable = parfor_index.9, range = (0, keys_size0.1, 1))]{315: }Var(parfor_index.9, 3325116556.py:3)\n", + " # del parfor_index.21\n", + " # del 
keys\n", + " # del key_array\n", + " # del j\n", + "\n", + " ret_arr = -9999 * np.ones(keys.shape[0], dtype=np.int32) # start with sentinel\n", + "\n", + " # --- LINE 4 --- \n", + "\n", + " for i in nb.prange(keys.shape[0]):\n", + "\n", + " # --- LINE 5 --- \n", + "\n", + " for j in nb.prange(key_array.shape[0]):\n", + "\n", + " # --- LINE 6 --- \n", + "\n", + " if keys[i] == key_array[j]:\n", + "\n", + " # --- LINE 7 --- \n", + "\n", + " ret_arr[i] = j\n", + "\n", + " # --- LINE 8 --- \n", + "\n", + " # break ## break is commented out as using parallelisation\n", + "\n", + " # --- LINE 9 --- \n", + " # $316return_value.1 = cast(value=ret_arr) :: array(int64, 1d, C)\n", + " # del ret_arr\n", + " # return $316return_value.1\n", + "\n", + " return ret_arr\n", + "\n", + "\n", + "================================================================================\n" + ] + } + ], + "source": [ + "find_exact.inspect_types()" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "@nb.njit(nogil=True, parallel=True)\n", + "def find_exact_typed(keys:nb.types.UnicodeCharSeq(2)[:], key_array:nb.types.UnicodeCharSeq(2)[:]) -> np.ndarray:\n", + " ret_arr = -9999 * np.ones(keys.shape[0], dtype=np.int32) # start with sentinel\n", + " for i in nb.prange(keys.shape[0]):\n", + " for j in nb.prange(key_array.shape[0]):\n", + " if keys[i] == key_array[j]:\n", + " ret_arr[i] = j\n", + " # break ## break is commented out as using parallelisation\n", + " return ret_arr" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array(UnicodeCharSeq(2), 1, 'A', False, aligned=True)" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nb.types.UnicodeCharSeq(2)[:]" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 1, -9999, -9999, ..., 0, 1, 3])" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "find_exact_typed(big_xs, table_keys)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "8.28 ms ± 44.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + ] + } + ], + "source": [ + "%timeit find_exact_typed(big_xs, table_keys)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "8.29 ms ± 31.4 µs per loop (mean ± std. dev. 
of 7 runs, 100 loops each)\n" + ] + } + ], + "source": [ + "%timeit find_exact(big_xs, table_keys)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Very little difference (I may not have configured it correctly)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ True, False, False, ..., True, True, True])" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "table_keys[np.searchsorted(table_keys, big_xs)] == big_xs" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array(['B', 'AB', 'AB', ..., 'A', 'B', 'D'], dtype=' 1\u001b[0m \u001b[43mtable\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mD\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\n", + "File \u001b[0;32m~/Dev/heavylight/src/heavylight/heavytables.py:202\u001b[0m, in \u001b[0;36mTable.__getitem__\u001b[0;34m(self, keys)\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(keys, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[1;32m 201\u001b[0m keys \u001b[38;5;241m=\u001b[39m keys, \u001b[38;5;66;03m#force to be a tuple\u001b[39;00m\n\u001b[0;32m--> 202\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkeys\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Dev/heavylight/src/heavylight/heavytables.py:195\u001b[0m, in \u001b[0;36mTable.get\u001b[0;34m(self, *keys)\u001b[0m\n\u001b[1;32m 193\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(keys) \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmappers)\n\u001b[1;32m 194\u001b[0m \u001b[38;5;66;03m# TODO: if just one key then this doesn't work? (needs fixed throughout)\u001b[39;00m\n\u001b[0;32m--> 195\u001b[0m int_keys \u001b[38;5;241m=\u001b[39m \u001b[43m[\u001b[49m\u001b[43mmapper\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmapper\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mzip\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mkeys\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmappers\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 196\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_int_key_table\u001b[38;5;241m.\u001b[39mget_value(\u001b[38;5;241m*\u001b[39mint_keys)\n", + "File \u001b[0;32m~/Dev/heavylight/src/heavylight/heavytables.py:195\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 193\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(keys) \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmappers)\n\u001b[1;32m 194\u001b[0m \u001b[38;5;66;03m# TODO: if just one key then this doesn't work? 
(needs fixed throughout)\u001b[39;00m\n\u001b[0;32m--> 195\u001b[0m int_keys \u001b[38;5;241m=\u001b[39m [\u001b[43mmapper\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m key, mapper \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(keys, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmappers)]\n\u001b[1;32m 196\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_int_key_table\u001b[38;5;241m.\u001b[39mget_value(\u001b[38;5;241m*\u001b[39mint_keys)\n", + "File \u001b[0;32m~/Dev/heavylight/src/heavylight/heavytables.py:45\u001b[0m, in \u001b[0;36mStringLookup.get\u001b[0;34m(self, keys)\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 44\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m keys \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstring_vals:\n\u001b[0;32m---> 45\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minvalid string key(s) passed into table lookup.\u001b[39m\u001b[38;5;124m\"\u001b[39m) \n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m np\u001b[38;5;241m.\u001b[39msearchsorted(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstring_vals, keys)\n", + "\u001b[0;31mKeyError\u001b[0m: 'invalid string key(s) passed into table lookup.'" + ] + } + ], + "source": [ + "table['D']" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ True, True])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "table[np.array(['A', 'C'])] == ['a', 'c']" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "ename": "KeyError", + "evalue": "'invalid string key(s) passed into table lookup.'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[10], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtable\u001b[49m\u001b[43m[\u001b[49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43marray\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mA\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mZ\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\n", + "File \u001b[0;32m~/Dev/heavylight/src/heavylight/heavytables.py:202\u001b[0m, in \u001b[0;36mTable.__getitem__\u001b[0;34m(self, keys)\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(keys, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[1;32m 201\u001b[0m keys \u001b[38;5;241m=\u001b[39m keys, \u001b[38;5;66;03m#force to be a tuple\u001b[39;00m\n\u001b[0;32m--> 202\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkeys\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Dev/heavylight/src/heavylight/heavytables.py:195\u001b[0m, in \u001b[0;36mTable.get\u001b[0;34m(self, *keys)\u001b[0m\n\u001b[1;32m 193\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(keys) \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmappers)\n\u001b[1;32m 194\u001b[0m \u001b[38;5;66;03m# TODO: if just one key then this doesn't work? (needs fixed throughout)\u001b[39;00m\n\u001b[0;32m--> 195\u001b[0m int_keys \u001b[38;5;241m=\u001b[39m \u001b[43m[\u001b[49m\u001b[43mmapper\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmapper\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mzip\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mkeys\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmappers\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 196\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_int_key_table\u001b[38;5;241m.\u001b[39mget_value(\u001b[38;5;241m*\u001b[39mint_keys)\n", + "File \u001b[0;32m~/Dev/heavylight/src/heavylight/heavytables.py:195\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 193\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(keys) \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmappers)\n\u001b[1;32m 194\u001b[0m \u001b[38;5;66;03m# TODO: if just one key then this doesn't work? 
(needs fixed throughout)\u001b[39;00m\n\u001b[0;32m--> 195\u001b[0m int_keys \u001b[38;5;241m=\u001b[39m [\u001b[43mmapper\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m key, mapper \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(keys, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmappers)]\n\u001b[1;32m 196\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_int_key_table\u001b[38;5;241m.\u001b[39mget_value(\u001b[38;5;241m*\u001b[39mint_keys)\n", + "File \u001b[0;32m~/Dev/heavylight/src/heavylight/heavytables.py:42\u001b[0m, in \u001b[0;36mStringLookup.get\u001b[0;34m(self, keys)\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(keys, np\u001b[38;5;241m.\u001b[39mndarray):\n\u001b[1;32m 41\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mall\u001b[39m(np\u001b[38;5;241m.\u001b[39misin(keys, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstring_vals)):\n\u001b[0;32m---> 42\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minvalid string key(s) passed into table lookup.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 43\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 44\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m keys \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstring_vals:\n", + "\u001b[0;31mKeyError\u001b[0m: 'invalid string key(s) passed into table lookup.'" + ] + } + ], + "source": [ + "table[np.array(['A', 'Z'])]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "heavylight", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml index fe63a0d..119faf7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,12 +2,20 @@ requires = ["hatchling"] build-backend = "hatchling.build" +[tool] +rye = { dev-dependencies = [ + "pytest>=8.1.2", + "pytest-cov>=5.0.0", + "pytest-timeout>=2.3.1", + "numpy>=1.26.4", +] } + [tool.hatch.build.targets.wheel] src-dir = "src" [project] name = "heavylight" -version = "1.0.6" +version = "1.0.7" authors = [ { name="Lewis Fogden", email="lewisfogden@gmail.com" }, { name="Matthew Caseres", email="matthewcaseres@outlook.com"} diff --git a/requirements-dev.lock b/requirements-dev.lock new file mode 100644 index 0000000..86cb26f --- /dev/null +++ b/requirements-dev.lock @@ -0,0 +1,35 @@ +# generated by rye +# use `rye lock` or `rye sync` to update this lockfile +# +# last locked with the following flags: +# pre: false +# features: [] +# all-features: false +# with-sources: false + +-e file:. 
+coverage==7.5.0 + # via pytest-cov +iniconfig==2.0.0 + # via pytest +numpy==1.26.4 + # via pandas +packaging==24.0 + # via pytest +pandas==2.0.3 + # via heavylight +pluggy==1.5.0 + # via pytest +pytest==8.1.2 + # via pytest-cov + # via pytest-timeout +pytest-cov==5.0.0 +pytest-timeout==2.3.1 +python-dateutil==2.9.0.post0 + # via pandas +pytz==2024.1 + # via pandas +six==1.16.0 + # via python-dateutil +tzdata==2024.1 + # via pandas diff --git a/requirements.lock b/requirements.lock new file mode 100644 index 0000000..58fd6b2 --- /dev/null +++ b/requirements.lock @@ -0,0 +1,22 @@ +# generated by rye +# use `rye lock` or `rye sync` to update this lockfile +# +# last locked with the following flags: +# pre: false +# features: [] +# all-features: false +# with-sources: false + +-e file:. +numpy==1.24.4 + # via pandas +pandas==2.0.3 + # via heavylight +python-dateutil==2.9.0.post0 + # via pandas +pytz==2024.1 + # via pandas +six==1.16.0 + # via python-dateutil +tzdata==2024.1 + # via pandas diff --git a/src/heavylight/__init__.py b/src/heavylight/__init__.py index e4615d1..1f4bedd 100644 --- a/src/heavylight/__init__.py +++ b/src/heavylight/__init__.py @@ -4,4 +4,4 @@ from .memory_optimized_cache import CacheGraph from .make_examples import make_example -__version__ = '1.0.6' +__version__ = '1.0.7' diff --git a/src/heavylight/examples/protection/protection_model_mo.py b/src/heavylight/examples/protection/protection_model_mo.py new file mode 100644 index 0000000..4e7ed66 --- /dev/null +++ b/src/heavylight/examples/protection/protection_model_mo.py @@ -0,0 +1,106 @@ +import heavylight +import numpy as np + +class TermAssurance(heavylight.LightModel): + def t(self, t): + return t + + def net_cf(self, t): + return self.premiums(t) - self.claims(t) - self.expenses(t) + + def premium_pp(self, t): + """monthly premium""" + return self.data["annual_premium"] / 12 + + def sum_assured(self, t): + return np.where( + self.data["shape"] == 'level', + self.data["sum_assured"], + self.decreasing_sa(t) + ) + + def decreasing_sa(self, t): + """sum assured if decreasing, tracking mortgage - interest rate of 7% assumed (could be parameterised)""" + r = (1 + 0.07)**(1/12) - 1 + S = self.data["sum_assured"] + T = self.data["term_y"] * 12 + outstanding = S * ((1 + r)**T - (1 + r)**t)/((1 + r)**T - 1) + return outstanding + + def claim_pp(self, t): + return np.where( + t > self.data["term_y"] * 12, + 0, + self.sum_assured(t) + ) + + def inflation_factor(self, t): + """annual""" + return (1 + self.basis["cost_inflation_pa"])**(t // 12) + + def v(self, t): + """present value of 1 discounted from time t to time 0""" + if t == 0: + return 1 + else: + return self.v(t - 1) / (1 + self.basis["forward_rates"][t]) + + def premiums(self, t): + return self.premium_pp(t) * self.num_pols_if(t) + + def pv_premiums(self, t): + return self.v(t) * self.premiums(t) + + def duration(self, t): + """duration in force in years""" + return t // 12 + + def claims(self, t): + return self.claim_pp(t) * self.num_deaths(t) + + def pv_claims(self, t): + return self.v(t) * self.claims(t) + + def expenses(self, t): + if t == 0: + return self.basis["initial_expense"] + else: + return self.num_pols_if(t) * self.basis["expense_pp"] / 12 * self.inflation_factor(t) + + def pv_expenses(self, t): + return self.v(t) * self.expenses(t) + + def num_pols_if(self, t): + """number of policies in force""" + if t == 0: + return self.data["init_pols_if"] + # if statements need to be vectorised, using np.where (same logic as excel IF()) + else: + return np.where( 
+ t > self.data["term_y"] * 12, + 0, + self.num_pols_if(t - 1) - self.num_exits(t - 1) - self.num_deaths(t - 1) + ) + + def num_exits(self, t): + """exits occurring at time t""" + return self.num_pols_if(t) * (1 - (1 - self.basis["lapse_rate_pa"])**(1/12)) + + def num_deaths(self, t): + """deaths occurring at time t""" + return self.num_pols_if(t) * self.q_x_12(t) + + def age(self, t): + return self.data["age_at_entry"] + t // 12 + + def q_x_12(self, t): + return 1 - (1 - self.q_x_rated(t))**(1/12) + + def q_x(self, t): + return self.basis["mort_table"][self.age(t), self.duration(t), self.data["smoker_status"]] + + def q_x_rated(self, t): + return np.clip(self.q_x(t) * (1 + self.data["extra_mortality"]), 0, 1) + + def commission(self, t): + return 0 diff --git a/src/heavylight/examples/protection/run_model_mo.py b/src/heavylight/examples/protection/run_model_mo.py new file mode 100644 index 0000000..fa0b10e --- /dev/null +++ b/src/heavylight/examples/protection/run_model_mo.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# required libraries: +# heavylight +# pip install heavylight + +# %% +import heavylight +from heavylight import Table +from protection_model_mo import TermAssurance +import numpy as np +import pandas as pd + + +# %% + +basis = { + "cost_inflation_pa": 0.02, + "initial_expense": 500, + "expense_pp": 10, + "lapse_rate_pa": 0.1, + "mort_table": Table.read_csv(r"tables/q_x_generic.csv"), + "forward_rates": Table.read_csv(r"tables/forward_rates.csv"), +} + +def create_data(pols, seed=42): + + rng = np.random.default_rng(seed) + + # override the single datapoint `pols` + data = dict( + sum_assured = rng.integers(10_000, 250_000, pols), + age_at_entry = rng.integers(20, 50, pols), + term_y = rng.integers(10, 30, pols), + smoker_status = rng.choice(['S', 'N'], pols), + shape = rng.choice(['level', 'decreasing'], pols), + annual_premium = np.ones(pols), + init_pols_if = np.ones(pols), + extra_mortality = np.zeros(pols), + sex = rng.choice(['F', 'M'], pols), + ) + return data + +## with LightModel we can run unoptimised, then run optimised. 
+ +opt_data = create_data(5) # small data set to run the optimiser +data = create_data(100_000) # full data set + +proj = TermAssurance() +proj.data = opt_data +proj.basis = basis +proj.RunModel(240) # run unoptimised +proj.data = data # replace data +proj.RunOptimized() # run optimised +proj.df_agg +# %% + diff --git a/src/heavylight/heavytables.py b/src/heavylight/heavytables.py index 2864206..a90c0cd 100644 --- a/src/heavylight/heavytables.py +++ b/src/heavylight/heavytables.py @@ -32,6 +32,19 @@ def __init__(self, lower, upper): def get(self, numpy_array): return np.clip(numpy_array, self.lower, self.upper) +class StringLookup: + def __init__(self, string_vals): + self.string_vals = np.array(string_vals) + + def get(self, keys): + if isinstance(keys, np.ndarray): + if not all(np.isin(keys, self.string_vals)): + raise KeyError("invalid string key(s) passed into table lookup.") + else: + if not keys in self.string_vals: + raise KeyError("invalid string key(s) passed into table lookup.") + return np.searchsorted(self.string_vals, keys) + class BandLookup: def __init__(self, upper_bounds, labels): """Inputs must be sorted""" @@ -152,7 +165,13 @@ def __init__(self, df:pd.DataFrame, rectify=False, safe=True): lower = df[col].min() upper = df[col].max() self.mappers.append(BoundIntLookup(lower=lower, upper=upper)) - elif col_type in ["str", "band"]: + elif col_type == "str": + cols = df[col].unique() + string_mapper = StringLookup(cols) + self.mappers.append(string_mapper) + df_int_keys[col] = string_mapper.get(df_int_keys[col]) + + elif col_type in ["str_unsafe", "band"]: df_col = pd.DataFrame(df[col].unique(), columns=["band_name"]).reset_index().sort_values("band_name") # add a nan on the end so we get errors if the lookup fails diff --git a/tests/test_heavytables.py b/tests/test_heavytables.py index 7874628..415eb30 100644 --- a/tests/test_heavytables.py +++ b/tests/test_heavytables.py @@ -1,6 +1,7 @@ import numpy as np import pandas as pd import heavylight.heavytables as ht +import pytest def test_integer_lookup(): lookup = ht.IntegerLookup() @@ -15,4 +16,15 @@ def test_bound_integer_lookup(): def test_rectify(): df1 = pd.DataFrame({'a': [1, 2], 'b': [1, 2], 'c': [1, 2]}) expected_rectified_df1 = pd.DataFrame({'a': [1, 1, 2, 2], 'b': [1, 2, 1, 2], 'c': [1, np.nan, np.nan, 2]}) - assert ht.Table.rectify(df1).equals(expected_rectified_df1) \ No newline at end of file + assert ht.Table.rectify(df1).equals(expected_rectified_df1) + +def test_string(): + df = pd.DataFrame({'key|str': ['A', 'B', 'C'], 'val|str': ['a', 'b', 'c']}) + table = ht.Table(df) + assert table['A'] == 'a' + assert table['C'] == 'c' + with pytest.raises(KeyError, match=r"'invalid string key\(s\) passed into table lookup.'"): + table['AB'] + assert list(table[np.array(['A', 'C'])]) == ['a', 'c'] + with pytest.raises(KeyError, match=r"'invalid string key\(s\) passed into table lookup.'"): + table[np.array(['A', 'AB'])]
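
---

Note on the change: the heart of this diff is the validated string lookup pattern that the notebook benchmarks and that `StringLookup` in `heavytables.py` implements, i.e. a vectorised `np.isin` membership check before mapping string keys to integer positions with `np.searchsorted`. The sketch below is a minimal standalone illustration of that pattern only; the function name and sample data are illustrative and are not part of the heavylight API (the real `StringLookup.get` also handles scalar keys and builds its key array from the table's key column).

```python
# Minimal sketch of the np.isin + np.searchsorted pattern used by StringLookup.
# Illustrative only: safe_string_lookup is not a heavylight function.
import numpy as np

def safe_string_lookup(keys: np.ndarray, key_array: np.ndarray) -> np.ndarray:
    """Map string `keys` to their positions in `key_array`, raising on unknown keys."""
    key_array = np.sort(np.asarray(key_array))   # searchsorted requires a sorted key array
    keys = np.asarray(keys)
    if not np.all(np.isin(keys, key_array)):     # vectorised membership test before lookup
        raise KeyError("invalid string key(s) passed into table lookup.")
    return np.searchsorted(key_array, keys)

# Example usage with made-up data:
table_keys = np.array(["A", "B", "C", "D"])
values = np.array(["a", "b", "c", "d"])
idx = safe_string_lookup(np.array(["B", "D", "A"]), table_keys)
print(values[idx])  # ['b' 'd' 'a']
# safe_string_lookup(np.array(["A", "Z"]), table_keys) raises KeyError
```

Per the notebook's timings, the extra `np.isin` check roughly doubles the cost of the raw `np.searchsorted` lookup (about 2.65 ms vs 1.48 ms on the 1,000,000-row sample) but is still several times faster than the numba loop variants, which is why the numpy formulation is the one carried into `heavytables.py`.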