diff --git a/examples/product-recommender/lancedb_cloud/README.md b/examples/product-recommender/lancedb_cloud/README.md
deleted file mode 100644
index 368725cc..00000000
--- a/examples/product-recommender/lancedb_cloud/README.md
+++ /dev/null
@@ -1,42 +0,0 @@
-# Product Recommender using Collaborative Filtering and LanceDB
-
-Use LanceDB and collaborative filtering to recommend products based on a user's past buying history. We used the Instacart dataset as our data for this example.
-Colab walkthrough -
-
-### Get dataset
-To run this example, please download the dataset from our s3 bucket: http://vectordb-recipes.s3.us-west-2.amazonaws.com/product-recommender.zip
-!!!This example needs to be run on GPU otherwise it will be very slow.
-It covers how to create a LanceDB table remotely, how to create an index on the vector column to accelerate search, followed by search on the remote table where results are saved as a pandas Dataframe.
-
-```
-wget http://vectordb-recipes.s3.us-west-2.amazonaws.com/product-recommender.zip
-unzip product-recommender.zip
-cp product-recommender/*.zip .
-rm -fr product-recommender
-```
-
-### Set credentials
-if you would like to set api key through an environment variable:
-```
-export LANCEDB_API_KEY="sk_..."
-```
-
-replace the following lines in main.py with your project slug and api key"
-```
-db_url = "db://your-project-name"
- api_key="sk_..."
-```
-
-Run the script
-```python
-python main.py
-```
-
-| Argument | Default Value | Description |
-|---|---|---|
-| factors | 100 | dimension of latent factor vectors |
-| regularization | 0.05 | strength of penalty term |
-| iterations | 50 | number of iterations to update |
-| num-threads | 1 | amount of parallelization |
-| num-partitions | 256 | number of partitions of the index |
-| num-sub-vectors | 16 | number of sub-vectors (M) that will be created during Product Quantization (PQ) |
diff --git a/examples/product-recommender/lancedb_cloud/main.ipynb b/examples/product-recommender/lancedb_cloud/main.ipynb
deleted file mode 100644
index 822ef0da..00000000
--- a/examples/product-recommender/lancedb_cloud/main.ipynb
+++ /dev/null
@@ -1,3333 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "YmdWGrw4t5G2"
- },
- "source": [
- "# Product Recommender using Collaborative Filtering and LanceDB\n",
- "\n",
- "We are going to use **LanceDB** and **Collaborative Filtering** to recommend products based on a user's past buying history. We used the **Instacart dataset** as our data for this example.\n",
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Credentials\n",
- "\n",
- "Copy and paste the project name and the api key from your project page.\n",
- "These will be used later to [connect to LanceDB Cloud](#scroll-to=5q8m6GMD7sGu)"
- ],
- "metadata": {
- "id": "sCtHNvkbzSot"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "project_slug = \"your-project-slug\" # @param {type:\"string\"}"
- ],
- "metadata": {
- "id": "zpPM2T8zzZkw"
- },
- "execution_count": 2,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "api_key = \"sk_...\" # @param {type:\"string\"}"
- ],
- "metadata": {
- "id": "xgCqtc99zwUQ"
- },
- "execution_count": 3,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "You can also set the LANCEDB_API_KEY as an environment variable with one of the options below"
- ],
- "metadata": {
- "id": "eEITDnEczz7G"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "!export LANCEDB_API_KEY=\"sk_...\""
- ],
- "metadata": {
- "id": "Md5kS8s7z0-j"
- },
- "execution_count": 3,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "import os\n",
- "import getpass\n",
- "os.environ[\"LANCEDB_API_KEY\"] = getpass.getpass(\"Enter Your LANCEDB API Key:\")"
- ],
- "metadata": {
- "id": "d7gq19Wez3JZ"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "9-fnXVuO8XQ0"
- },
- "source": [
- "## Get dataset\n",
- "Download and unzip the dataset from LanceDB s3 bucket."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {
- "id": "3jXSVspr7sGe",
- "vscode": {
- "languageId": "shellscript"
- },
- "outputId": "4c09916d-85de-46d6-9c16-ed6746ac4e19",
- "colab": {
- "base_uri": "https://localhost:8080/"
- }
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "--2024-01-23 03:30:37-- http://vectordb-recipes.s3.us-west-2.amazonaws.com/product-recommender.zip\n",
- "Resolving vectordb-recipes.s3.us-west-2.amazonaws.com (vectordb-recipes.s3.us-west-2.amazonaws.com)... 3.5.84.12, 3.5.84.155, 3.5.84.131, ...\n",
- "Connecting to vectordb-recipes.s3.us-west-2.amazonaws.com (vectordb-recipes.s3.us-west-2.amazonaws.com)|3.5.84.12|:80... connected.\n",
- "HTTP request sent, awaiting response... 200 OK\n",
- "Length: 411510857 (392M) [application/zip]\n",
- "Saving to: ‘product-recommender.zip’\n",
- "\n",
- "product-recommender 100%[===================>] 392.45M 22.5MB/s in 19s \n",
- "\n",
- "2024-01-23 03:30:56 (20.8 MB/s) - ‘product-recommender.zip’ saved [411510857/411510857]\n",
- "\n",
- "Archive: product-recommender.zip\n",
- " creating: product-recommender/\n",
- " inflating: __MACOSX/._product-recommender \n",
- " inflating: product-recommender/order_products__prior.csv.zip \n",
- " inflating: __MACOSX/product-recommender/._order_products__prior.csv.zip \n",
- " inflating: product-recommender/order_products__train.csv.zip \n",
- " inflating: __MACOSX/product-recommender/._order_products__train.csv.zip \n",
- " inflating: product-recommender/orders.csv.zip \n",
- " inflating: __MACOSX/product-recommender/._orders.csv.zip \n",
- " inflating: product-recommender/products.csv.zip \n",
- " inflating: __MACOSX/product-recommender/._products.csv.zip \n",
- " inflating: product-recommender/instacart-market-basket-analysis.zip \n",
- " inflating: __MACOSX/product-recommender/._instacart-market-basket-analysis.zip \n"
- ]
- }
- ],
- "source": [
- "!wget http://vectordb-recipes.s3.us-west-2.amazonaws.com/product-recommender.zip\n",
- "!unzip product-recommender.zip\n",
- "!cp product-recommender/*.zip .\n",
- "!rm -fr product-recommender"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "xVLHZB8BzJQG"
- },
- "source": [
- "Install dependencies:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "R3_Hq2VC4_zT",
- "outputId": "fc920fc5-ac48-48e6-a2b2-0f84d4436ef7"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (1.23.5)\n",
- "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (1.5.3)\n",
- "Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (1.11.4)\n",
- "Requirement already satisfied: kaggle in /usr/local/lib/python3.10/dist-packages (1.5.16)\n",
- "Collecting implicit\n",
- " Downloading implicit-0.7.2-cp310-cp310-manylinux2014_x86_64.whl (8.9 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.9/8.9 MB\u001b[0m \u001b[31m4.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hRequirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.1.0+cu121)\n",
- "Collecting lancedb\n",
- " Downloading lancedb-0.5.0-py3-none-any.whl (87 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m87.4/87.4 kB\u001b[0m \u001b[31m10.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas) (2.8.2)\n",
- "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas) (2023.3.post1)\n",
- "Requirement already satisfied: six>=1.10 in /usr/local/lib/python3.10/dist-packages (from kaggle) (1.16.0)\n",
- "Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from kaggle) (2023.11.17)\n",
- "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.31.0)\n",
- "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from kaggle) (4.66.1)\n",
- "Requirement already satisfied: python-slugify in /usr/local/lib/python3.10/dist-packages (from kaggle) (8.0.1)\n",
- "Requirement already satisfied: urllib3 in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.0.7)\n",
- "Requirement already satisfied: bleach in /usr/local/lib/python3.10/dist-packages (from kaggle) (6.1.0)\n",
- "Requirement already satisfied: threadpoolctl in /usr/local/lib/python3.10/dist-packages (from implicit) (3.2.0)\n",
- "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.13.1)\n",
- "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch) (4.5.0)\n",
- "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.12)\n",
- "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.2.1)\n",
- "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.3)\n",
- "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2023.6.0)\n",
- "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch) (2.1.0)\n",
- "Collecting deprecation (from lancedb)\n",
- " Downloading deprecation-2.1.0-py2.py3-none-any.whl (11 kB)\n",
- "Collecting pylance==0.9.6 (from lancedb)\n",
- " Downloading pylance-0.9.6-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.6 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m18.6/18.6 MB\u001b[0m \u001b[31m14.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting ratelimiter~=1.0 (from lancedb)\n",
- " Downloading ratelimiter-1.2.0.post0-py3-none-any.whl (6.6 kB)\n",
- "Collecting retry>=0.9.2 (from lancedb)\n",
- " Downloading retry-0.9.2-py2.py3-none-any.whl (8.0 kB)\n",
- "Requirement already satisfied: pydantic>=1.10 in /usr/local/lib/python3.10/dist-packages (from lancedb) (1.10.13)\n",
- "Requirement already satisfied: attrs>=21.3.0 in /usr/local/lib/python3.10/dist-packages (from lancedb) (23.2.0)\n",
- "Collecting semver>=3.0 (from lancedb)\n",
- " Downloading semver-3.0.2-py3-none-any.whl (17 kB)\n",
- "Requirement already satisfied: cachetools in /usr/local/lib/python3.10/dist-packages (from lancedb) (5.3.2)\n",
- "Requirement already satisfied: pyyaml>=6.0 in /usr/local/lib/python3.10/dist-packages (from lancedb) (6.0.1)\n",
- "Requirement already satisfied: click>=8.1.7 in /usr/local/lib/python3.10/dist-packages (from lancedb) (8.1.7)\n",
- "Collecting overrides>=0.7 (from lancedb)\n",
- " Downloading overrides-7.6.0-py3-none-any.whl (17 kB)\n",
- "Collecting pyarrow>=12 (from pylance==0.9.6->lancedb)\n",
- " Downloading pyarrow-15.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (38.3 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m38.3/38.3 MB\u001b[0m \u001b[31m8.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hRequirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle) (3.3.2)\n",
- "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle) (3.6)\n",
- "Requirement already satisfied: decorator>=3.4.2 in /usr/local/lib/python3.10/dist-packages (from retry>=0.9.2->lancedb) (4.4.2)\n",
- "Collecting py<2.0.0,>=1.4.26 (from retry>=0.9.2->lancedb)\n",
- " Downloading py-1.11.0-py2.py3-none-any.whl (98 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m98.7/98.7 kB\u001b[0m \u001b[31m13.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hRequirement already satisfied: webencodings in /usr/local/lib/python3.10/dist-packages (from bleach->kaggle) (0.5.1)\n",
- "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from deprecation->lancedb) (23.2)\n",
- "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.3)\n",
- "Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.10/dist-packages (from python-slugify->kaggle) (1.3)\n",
- "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n",
- "Installing collected packages: ratelimiter, semver, pyarrow, py, overrides, deprecation, retry, pylance, implicit, lancedb\n",
- " Attempting uninstall: pyarrow\n",
- " Found existing installation: pyarrow 10.0.1\n",
- " Uninstalling pyarrow-10.0.1:\n",
- " Successfully uninstalled pyarrow-10.0.1\n",
- "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
- "ibis-framework 7.1.0 requires pyarrow<15,>=2, but you have pyarrow 15.0.0 which is incompatible.\u001b[0m\u001b[31m\n",
- "\u001b[0mSuccessfully installed deprecation-2.1.0 implicit-0.7.2 lancedb-0.5.0 overrides-7.6.0 py-1.11.0 pyarrow-15.0.0 pylance-0.9.6 ratelimiter-1.2.0.post0 retry-0.9.2 semver-3.0.2\n"
- ]
- }
- ],
- "source": [
- "!pip install numpy pandas scipy kaggle implicit torch lancedb"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "i_eatRhaIGIz"
- },
- "source": [
- "First, let's import all the required modules for this example."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {
- "id": "emp_MSXZt5G8"
- },
- "outputs": [],
- "source": [
- "import zipfile\n",
- "import numpy as np\n",
- "import pandas as pd\n",
- "import scipy.sparse\n",
- "import torch\n",
- "import implicit\n",
- "from implicit import evaluation\n",
- "import pydantic\n",
- "import lancedb\n",
- "from lancedb.pydantic import pydantic_to_schema, vector"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "K4Q4cOX-4_zY"
- },
- "source": [
- "We must now extract the zip files."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {
- "id": "f3g296nL4_zZ"
- },
- "outputs": [],
- "source": [
- "files = [\n",
- " 'instacart-market-basket-analysis.zip',\n",
- " 'order_products__train.csv.zip',\n",
- " 'order_products__prior.csv.zip',\n",
- " 'products.csv.zip',\n",
- " 'orders.csv.zip'\n",
- "]\n",
- "\n",
- "for filename in files:\n",
- " with zipfile.ZipFile(filename, 'r') as zip_ref:\n",
- " zip_ref.extractall('./')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "oLgkRIfq4_zZ"
- },
- "source": [
- "Now we can move on to loading the dataset. We'll first read the csv files and create dataframes."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {
- "id": "cBbbR7Rut5G_"
- },
- "outputs": [],
- "source": [
- "products = pd.read_csv('products.csv')\n",
- "orders = pd.read_csv('orders.csv')\n",
- "order_products = pd.concat([pd.read_csv('order_products__train.csv'), pd.read_csv('order_products__prior.csv')])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "5FV_GGjst5HA"
- },
- "source": [
- "Since there isn't a user rating attribute, we'll gather \"confidence\" data by looking at the frequency of each item purchased by a user, and store this in the `data` dataframe."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {
- "id": "ZjRh7RYpt5HB"
- },
- "outputs": [],
- "source": [
- "customer_order_products = pd.merge(orders, order_products, how='inner',on='order_id')\n",
- "\n",
- "# create confidence table\n",
- "data = customer_order_products.groupby(['user_id', 'product_id'])[['order_id']].count().reset_index()\n",
- "data.columns=[\"user_id\", \"product_id\", \"total_orders\"]\n",
- "data.product_id = data.product_id.astype('int64')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "77lvwm0St5HC"
- },
- "source": [
- "Let's create a couple of test users to examine the recommendations later:\n",
- "- 1st test user: buys 50 sodas: **Zero Calorie Cola**\n",
- "- 2nd test user: buys organic produce: **Organic Whole Milk** and **Organic Blackberries**"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 224
- },
- "id": "A06EfAf-t5HC",
- "outputId": "af9c06f5-1cbd-4ee1-9876-c62591fe95bd"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "13863749\n"
- ]
- },
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": [
- " user_id product_id total_orders\n",
- "13863744 206209 48697 1\n",
- "13863745 206209 48742 2\n",
- "13863746 206210 46149 50\n",
- "13863747 206211 27845 49\n",
- "13863748 206211 26604 32"
- ],
- "text/html": [
- "\n",
- "
\n",
- "
\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " user_id | \n",
- " product_id | \n",
- " total_orders | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 13863744 | \n",
- " 206209 | \n",
- " 48697 | \n",
- " 1 | \n",
- "
\n",
- " \n",
- " 13863745 | \n",
- " 206209 | \n",
- " 48742 | \n",
- " 2 | \n",
- "
\n",
- " \n",
- " 13863746 | \n",
- " 206210 | \n",
- " 46149 | \n",
- " 50 | \n",
- "
\n",
- " \n",
- " 13863747 | \n",
- " 206211 | \n",
- " 27845 | \n",
- " 49 | \n",
- "
\n",
- " \n",
- " 13863748 | \n",
- " 206211 | \n",
- " 26604 | \n",
- " 32 | \n",
- "
\n",
- " \n",
- "
\n",
- "
\n",
- "
\n",
- "
\n"
- ]
- },
- "metadata": {},
- "execution_count": 11
- }
- ],
- "source": [
- "data_new = pd.DataFrame([[data.user_id.max() + 1, 46149, 50],\n",
- " [data.user_id.max() + 2, 27845, 49],\n",
- " [data.user_id.max() + 2, 26604, 32]\n",
- " ], columns=['user_id', 'product_id', 'total_orders'])\n",
- "data = pd.concat([data, data_new]).reset_index(drop = True)\n",
- "data.tail()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "xBC-8PFTt5HD"
- },
- "source": [
- "In the next step, we will extract user and product unique ids, in order to create a CSR (Compressed Sparse Row) matrix. This will allow us to perform collaborative filtering.\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {
- "id": "v2_2R7zmt5HE"
- },
- "outputs": [],
- "source": [
- "# extract unique user and product ids\n",
- "unique_users = list(np.sort(data.user_id.unique()))\n",
- "unique_products = list(np.sort(products.product_id.unique()))\n",
- "purchases = list(data.total_orders)\n",
- "\n",
- "# create zero-based index position <-> user/item ID mappings\n",
- "index_to_user = pd.Series(unique_users)\n",
- "\n",
- "# create reverse mappings from user/item ID to index positions\n",
- "user_to_index = pd.Series(data=index_to_user.index + 1, index=index_to_user.values)\n",
- "\n",
- "# create row and column for user and product ids\n",
- "users_rows = data.user_id.astype(int)\n",
- "products_cols = data.product_id.astype(int)\n",
- "\n",
- "# create CSR matrix\n",
- "matrix = scipy.sparse.csr_matrix((purchases, (users_rows, products_cols)), shape=(len(unique_users) + 1, len(unique_products) + 1))\n",
- "matrix.data = np.nan_to_num(matrix.data, copy=False)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "II6wOH96t5HF"
- },
- "source": [
- "Let's now create a recommender model using the **implicit** library. The recommendation model is based off the algorithms described in the paper [Collaborative Filtering for Implicit Feedback Datasets](https://www.researchgate.net/publication/220765111_Collaborative_Filtering_for_Implicit_Feedback_Datasets) with performance optimizations described in [Applications of the Conjugate Gradient Method for Implicit Feedback Collaborative Filtering](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.379.6473&rep=rep1&type=pdf).\n",
- "\n",
- "Note: this step will take about 17 minutes with the current parameter setup."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 105,
- "referenced_widgets": [
- "2c0101b0a3574a14b2a37fc431eb2908",
- "31c3c90fa42f489796fba11d57799089",
- "e13993dda2da40ff806d6e31a6e987d3",
- "0bff70b647f3404fa15690ec9f3d0c78",
- "674cf2d29d044cada59480813e0e8e58",
- "bfd4ff099ed14ab1bd79233beea7f402",
- "000f9e8fd1db4bc0a7aceeb822ca2b2e",
- "75b270d981de425ba1fd9a790b2a68ff",
- "baafe1d810594384af1a5ffa4f2f5cb4",
- "bf95fd811f79425bb2248525aeab7da0",
- "46fb5083adf24ce4ae3fd4ea9aa4772e"
- ]
- },
- "id": "k0GW99kxt5HF",
- "outputId": "d3e22ae9-ff96-4d89-f0aa-c3b5cd47d354"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stderr",
- "text": [
- "/usr/local/lib/python3.10/dist-packages/implicit/cpu/als.py:95: RuntimeWarning: OpenBLAS is configured to use 2 threads. It is highly recommended to disable its internal threadpool by setting the environment variable 'OPENBLAS_NUM_THREADS=1' or by calling 'threadpoolctl.threadpool_limits(1, \"blas\")'. Having OpenBLAS use a threadpool can lead to severe performance issues here.\n",
- " check_blas_config()\n"
- ]
- },
- {
- "output_type": "display_data",
- "data": {
- "text/plain": [
- " 0%| | 0/50 [00:00, ?it/s]"
- ],
- "application/vnd.jupyter.widget-view+json": {
- "version_major": 2,
- "version_minor": 0,
- "model_id": "2c0101b0a3574a14b2a37fc431eb2908"
- }
- },
- "metadata": {}
- }
- ],
- "source": [
- "#split data into train and test splits\n",
- "train, test = evaluation.train_test_split(matrix, train_percentage=0.9)\n",
- "\n",
- "# initialize the recommender model\n",
- "model = implicit.als.AlternatingLeastSquares(factors=128,\n",
- " regularization=0.05,\n",
- " iterations=50,\n",
- " num_threads=1)\n",
- "\n",
- "alpha = 15\n",
- "train = (train * alpha).astype('double')\n",
- "\n",
- "# train the model on CSR matrix\n",
- "model.fit(train, show_progress = True)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "yN80hSojt5HF"
- },
- "source": [
- "## Let's now evaluate the model."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 120,
- "referenced_widgets": [
- "5b98b7b242994c999064688c9210c61b",
- "d5b1eb34ddc949aebd25b3744b93b726",
- "752d37b9a68b42d284493645962f3782",
- "f0def002c7ca41f6a70e9dba1bc605c7",
- "4b0298a9ecf84b509fbf379d43339b9c",
- "a37be209d5bb44e18f32c0259073d2c8",
- "b35984b48d8847eea119ee5eda049b9d",
- "4b20ad4b356645bbbfb94929160943f2",
- "63b8646c732246988f566d0442a070e8",
- "ae8581ec76314304b2078759e1dbdd7e",
- "d0e90066f1ec42afa5f1c02551d3889e"
- ]
- },
- "id": "BbD8of_nt5HG",
- "outputId": "547c4171-d89f-4d3f-87f6-3e99cb22586f"
- },
- "outputs": [
- {
- "output_type": "display_data",
- "data": {
- "text/plain": [
- " 0%| | 0/192802 [00:00, ?it/s]"
- ],
- "application/vnd.jupyter.widget-view+json": {
- "version_major": 2,
- "version_minor": 0,
- "model_id": "5b98b7b242994c999064688c9210c61b"
- }
- },
- "metadata": {}
- },
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": [
- "{'precision': 0.2742377453615933,\n",
- " 'map': 0.04506404325620732,\n",
- " 'ndcg': 0.1449554399501384,\n",
- " 'auc': 0.6549935260418878}"
- ]
- },
- "metadata": {},
- "execution_count": 15
- }
- ],
- "source": [
- "test = (test * alpha).astype('double')\n",
- "evaluation.ranking_metrics_at_k(model, train, test, K=100,\n",
- " show_progress=True, num_threads=1)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "LNmva3Dlt5HG"
- },
- "source": [
- "From the model, we'll be able to retrieve item and user factors, which we can use later on to store in LanceDB as vector embeddings."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "JUtCROQKt5HG",
- "outputId": "25c417a4-30e3-4923-da78-c372e70d28c5"
- },
- "outputs": [
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": [
- "array([[ 4.18832153e-03, 3.25558195e-03, -1.20758591e-02,\n",
- " 1.40742492e-03, -9.09519568e-03, 3.18243494e-03,\n",
- " 2.07483694e-02, -3.95777356e-03, -7.84489443e-04,\n",
- " 1.28329173e-03, 4.66100639e-03, 1.26599418e-02,\n",
- " 1.69202778e-02, -3.54033429e-03, -1.87805621e-04,\n",
- " -8.05972423e-03, 4.04613744e-03, 7.47162709e-03,\n",
- " 4.05248860e-03, 1.68309249e-02, -1.78848747e-02,\n",
- " -9.86590981e-03, 8.46584328e-03, -1.20693864e-02,\n",
- " 7.22488947e-03, 3.90211469e-03, 6.32435898e-04,\n",
- " 3.13967327e-03, 9.04218480e-03, 2.50183023e-03,\n",
- " 1.39820874e-02, 7.54051283e-03, 1.57470535e-02,\n",
- " 4.96101473e-03, 1.74571313e-02, 4.82573919e-03,\n",
- " 1.31175248e-02, 2.78141089e-02, 2.54594497e-02,\n",
- " 1.70677726e-04, 6.35464117e-03, -3.27711529e-03,\n",
- " 8.61203857e-03, 1.61729436e-02, -7.27234699e-04,\n",
- " 7.29484204e-03, -6.27670763e-03, 2.42914446e-02,\n",
- " 9.70306620e-03, 9.60955396e-03, 1.76130934e-03,\n",
- " 1.24175642e-02, 1.61149055e-02, -6.19298825e-03,\n",
- " 1.43120736e-02, 8.98846332e-03, -4.45187604e-03,\n",
- " -1.01331789e-02, 1.13288751e-02, 5.21639129e-03,\n",
- " -2.32453570e-02, -9.21340834e-04, 1.41203729e-02,\n",
- " 1.15836377e-03, 9.21401940e-03, 1.86691377e-02,\n",
- " -1.45641970e-03, 3.42004225e-02, 4.21455083e-03,\n",
- " 1.72144044e-02, 6.25161314e-03, 1.53229507e-02,\n",
- " 1.02525502e-02, 3.70174204e-03, -3.06739035e-04,\n",
- " 4.36588563e-03, 9.17611178e-03, 2.26073209e-02,\n",
- " 4.50356351e-03, 7.92219583e-03, 9.34277428e-04,\n",
- " 1.91239640e-02, -1.67676080e-02, 4.76368004e-03,\n",
- " 6.63227355e-03, -5.15057752e-03, 1.04246605e-02,\n",
- " 1.05045931e-02, 2.13206583e-03, 8.84506665e-03,\n",
- " -3.37255420e-03, -6.84900908e-03, -4.62881243e-03,\n",
- " 8.68821703e-03, 5.13017131e-03, 5.22500556e-03,\n",
- " -9.12018027e-03, -6.31605508e-03, 6.93989592e-03,\n",
- " 2.04393896e-03, -1.66683702e-03, 7.34541751e-03,\n",
- " 1.54855782e-02, -2.50343612e-04, 3.87350516e-03,\n",
- " 1.11501506e-02, 1.94554869e-02, 3.02761160e-02,\n",
- " 5.73130697e-03, -3.03466641e-03, 8.57606344e-03,\n",
- " 9.56064463e-03, 9.24304873e-03, -1.49936741e-02,\n",
- " -6.85681123e-03, 1.99363139e-02, -4.29221604e-04,\n",
- " -5.85102988e-03, -2.01355782e-03, 1.39436489e-02,\n",
- " -5.09022153e-04, 7.93045852e-03, -2.93425820e-03,\n",
- " 1.70512926e-02, 3.72680346e-03, 4.26774239e-03,\n",
- " 1.29361469e-02, 3.41003831e-03],\n",
- " [ 4.08880366e-03, 1.89150311e-03, 3.25225573e-03,\n",
- " 5.50956652e-03, 4.17970167e-03, 1.52355502e-03,\n",
- " 3.83031485e-03, 3.52009456e-03, 2.86640553e-03,\n",
- " 4.81489720e-03, 3.90547770e-03, 5.25039481e-03,\n",
- " 8.52285326e-03, 2.83156661e-03, 7.00753042e-03,\n",
- " 4.67074849e-03, 5.77870058e-03, 3.62071581e-03,\n",
- " 4.98738885e-03, 1.30909227e-03, 6.40545553e-03,\n",
- " 5.35790483e-03, 7.04027340e-03, 4.54069860e-03,\n",
- " 4.93164733e-03, 2.20916839e-03, 4.92953369e-03,\n",
- " 5.04408404e-03, 2.08156300e-03, 5.32587618e-03,\n",
- " 4.29942692e-03, 5.37325954e-03, 3.32720438e-03,\n",
- " 7.78398663e-03, 2.72745849e-03, 5.18748770e-03,\n",
- " 6.30498864e-03, 5.85784856e-03, 4.62009897e-03,\n",
- " 6.24990417e-03, 4.08851821e-03, 4.49793646e-03,\n",
- " 7.78977934e-04, 2.64118239e-03, 2.32547079e-03,\n",
- " 5.02325455e-03, 6.91512600e-03, 4.60041454e-03,\n",
- " 6.66597480e-05, 5.87717863e-03, 4.27115988e-03,\n",
- " 4.28729318e-03, 1.13794568e-03, 7.68032717e-03,\n",
- " 5.33338822e-03, 6.90902770e-03, 5.38264960e-03,\n",
- " 5.93157578e-03, 4.84365830e-03, 4.92752390e-03,\n",
- " 1.62087195e-03, 7.48377480e-03, 3.89479683e-03,\n",
- " -5.76462335e-05, 1.03033381e-02, 3.63176106e-03,\n",
- " 4.49880911e-03, 4.64092754e-03, 1.38480240e-03,\n",
- " 4.81152860e-03, 5.39690442e-03, 4.84804343e-03,\n",
- " 3.47388530e-04, 7.04673876e-04, 6.95901597e-03,\n",
- " 7.98352994e-03, 2.47756205e-03, 1.70948007e-03,\n",
- " 5.22315735e-03, 2.06266297e-03, 1.11589418e-03,\n",
- " 1.01095904e-03, 2.19165138e-03, -9.10140574e-04,\n",
- " 7.64639908e-03, 5.72459772e-03, 4.89675207e-03,\n",
- " 1.48792891e-03, 2.68044509e-03, 6.07493240e-03,\n",
- " 5.42714074e-03, 7.35473679e-03, 3.19598289e-03,\n",
- " 3.64008965e-03, 1.87583105e-03, 4.48295055e-03,\n",
- " 2.47131498e-03, 3.09168128e-03, 4.25936468e-03,\n",
- " 2.27378379e-03, 2.08440656e-03, 6.94426883e-04,\n",
- " 2.01272778e-03, 2.77051283e-03, 5.01386821e-03,\n",
- " 5.31353708e-03, 1.90395059e-03, 2.16349540e-03,\n",
- " 4.04190738e-03, 4.96644387e-03, 1.97983976e-03,\n",
- " 9.15821642e-04, 3.11542186e-03, 3.71921458e-03,\n",
- " 2.56881723e-03, 5.01005258e-03, 4.94958553e-03,\n",
- " 2.06254027e-03, 4.21693781e-03, 6.14025909e-03,\n",
- " 5.64814592e-03, 1.09314881e-02, 4.46141372e-03,\n",
- " 3.37589253e-03, 7.11428293e-04, 3.79333482e-03,\n",
- " 3.88169941e-03, 4.75861132e-03]], dtype=float32)"
- ]
- },
- "metadata": {},
- "execution_count": 17
- }
- ],
- "source": [
- "model.item_factors[1:3]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "O3onbJmnt5HG",
- "outputId": "13740e14-6dd6-498a-e307-5b3eed4d1eb1"
- },
- "outputs": [
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": [
- "array([[-0.48312342, -0.16332878, -0.27058715, -0.68734646, 0.55745304,\n",
- " -0.76024646, 1.3025886 , -1.1410682 , 0.19876784, 0.322232 ,\n",
- " 1.418613 , -0.35110232, -0.20965634, 0.06050462, -1.2792661 ,\n",
- " -1.0213155 , 0.4870829 , 0.1747867 , -0.56089026, 1.9309798 ,\n",
- " -1.1751343 , -1.7791682 , -1.1694795 , 0.05588444, 1.1789317 ,\n",
- " 0.46748516, -1.4641706 , -0.34146857, 0.38970897, 0.8604016 ,\n",
- " 0.3465701 , 1.1880745 , 0.06135967, -1.3244237 , 0.3275966 ,\n",
- " -1.1865908 , -0.01917509, 2.7532892 , 2.7307365 , 0.44283357,\n",
- " 0.5644037 , -0.697197 , -1.8847649 , 0.10031813, 0.3599322 ,\n",
- " -0.83181113, -1.9561976 , 0.8480924 , 0.910125 , -0.35006854,\n",
- " 0.45438412, 1.1324192 , 0.02506897, 0.7978778 , -1.0787288 ,\n",
- " 0.41879764, -1.0015563 , -0.11314881, -1.512127 , -0.37960863,\n",
- " -0.5743517 , -1.0606588 , 0.9415234 , 0.1189226 , -0.10419434,\n",
- " 1.4429063 , -0.35251117, 0.59351844, 0.5283425 , -0.24646994,\n",
- " -0.48999467, 1.0533476 , 0.28534362, 0.74745566, 0.26966977,\n",
- " 0.01470857, 0.5190429 , 0.85178673, -0.62364656, -0.44840345,\n",
- " -0.6985944 , 1.7859677 , -0.9912727 , 0.88918775, 0.61314136,\n",
- " 1.3294568 , 1.7689328 , -0.42922932, -0.27359295, 1.8145771 ,\n",
- " -0.05140882, -0.72702384, -0.11391591, -0.1860256 , 0.7310641 ,\n",
- " -0.7768954 , -0.3302253 , 0.150209 , -0.60365665, 0.24954513,\n",
- " -0.2766658 , 0.01893546, 0.3570815 , 0.18330622, -0.89038587,\n",
- " 0.50650024, 1.0074087 , 1.7643334 , 1.5506059 , -0.38804454,\n",
- " -0.45902696, -0.3882332 , -0.58766186, 0.30682987, -0.45430216,\n",
- " 0.17607969, 0.6972072 , -0.3375235 , -1.6623874 , 0.05010271,\n",
- " -1.246921 , 1.4658022 , -1.158234 , -0.42433274, 0.49941427,\n",
- " -1.1462147 , 1.3886684 , 1.3426281 ],\n",
- " [-0.48055026, -1.076108 , 1.2871186 , 0.73388743, 1.1587979 ,\n",
- " -0.61240053, -1.1271679 , 1.5407826 , -1.0408585 , 0.6814867 ,\n",
- " -0.05775254, 0.36426723, -1.6217808 , 0.3340878 , -1.076462 ,\n",
- " -0.44586924, 1.0720152 , 0.8573093 , -0.81757593, -1.3212438 ,\n",
- " -1.4259018 , 0.8028897 , 0.727854 , -0.72402936, -0.26787922,\n",
- " 0.4334872 , 3.0854182 , -0.903931 , 0.3117463 , 1.932017 ,\n",
- " 1.743012 , -0.08208363, -1.1798037 , -1.4148307 , -0.03076403,\n",
- " 1.3006622 , -1.5442777 , 0.5676142 , -0.755088 , 2.4009585 ,\n",
- " 0.33378768, -1.1779053 , -0.11361812, -0.46143544, 1.6553828 ,\n",
- " 0.31190038, -2.1039965 , -0.903235 , 2.319655 , -3.0109007 ,\n",
- " -1.284968 , 0.6581418 , 0.40891904, 0.57213986, -2.1724799 ,\n",
- " -1.4901172 , -0.10466211, 0.82121205, 0.0346746 , -0.4013229 ,\n",
- " 0.8444738 , -0.9185106 , 1.9658837 , 1.9450268 , -1.6841023 ,\n",
- " 2.7010896 , 1.1157808 , 0.06317325, 0.4229485 , -0.94922143,\n",
- " -1.4750186 , -1.0483259 , 3.7233133 , 1.9119471 , -0.5080464 ,\n",
- " 0.4889877 , 0.48215535, -0.35629106, -1.8599209 , -1.0194218 ,\n",
- " 0.11349088, 1.1718806 , 1.3258948 , 1.0701228 , -2.3570247 ,\n",
- " -0.42508158, 0.04244204, -1.3229184 , -0.7360056 , 0.05403712,\n",
- " 1.6118884 , 1.5898055 , 1.5195148 , -1.1609313 , 0.43079212,\n",
- " -1.3221414 , 0.17119163, 1.4561695 , 0.8667575 , 0.02400587,\n",
- " -0.55747974, 0.16746764, 1.7400613 , 0.88008255, -0.6901739 ,\n",
- " 0.4686606 , 2.7078378 , 2.7286143 , -0.52630275, -1.3082739 ,\n",
- " 3.9579751 , 0.2908509 , 2.0343082 , -0.05273173, 1.4064884 ,\n",
- " -1.2191583 , 1.6978588 , 2.9528291 , 0.35665286, -1.6854041 ,\n",
- " -3.23004 , 0.20751497, -2.429357 , 2.0009892 , -0.6266644 ,\n",
- " 0.736535 , -1.2620703 , -0.16571261]], dtype=float32)"
- ]
- },
- "metadata": {},
- "execution_count": 18
- }
- ],
- "source": [
- "model.user_factors[1:3]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "38rssdYCBR4E"
- },
- "source": [
- "## Let's save the data and create a empty LanceDB Table using a Pydantic model.\n",
- "A Table is designed to store large numbers of columns and huge quantities of data! For those interested, a LanceDB is columnar-based, and uses Lance, an open data format to store data."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 20,
- "metadata": {
- "id": "3_ykVLT6t5HH"
- },
- "outputs": [],
- "source": [
- "# connect to LanceDB Cloud with previously set credentials\n",
- "uri = \"db://\" + project_slug\n",
- "db = lancedb.connect(uri, api_key=api_key, region=\"us-east-1\")\n"
- ]
- },
- {
- "cell_type": "code",
- "source": [
- "data.head()"
- ],
- "metadata": {
- "id": "9YiqyzadgiQl",
- "outputId": "df0e60c3-eef5-4a1f-efe5-2f0d927a38d4",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 206
- }
- },
- "execution_count": 21,
- "outputs": [
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": [
- " user_id product_id total_orders\n",
- "0 1 196 11\n",
- "1 1 10258 10\n",
- "2 1 10326 1\n",
- "3 1 12427 10\n",
- "4 1 13032 4"
- ],
- "text/html": [
- "\n",
- " \n",
- "
\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " user_id | \n",
- " product_id | \n",
- " total_orders | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 0 | \n",
- " 1 | \n",
- " 196 | \n",
- " 11 | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " 1 | \n",
- " 10258 | \n",
- " 10 | \n",
- "
\n",
- " \n",
- " 2 | \n",
- " 1 | \n",
- " 10326 | \n",
- " 1 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " 1 | \n",
- " 12427 | \n",
- " 10 | \n",
- "
\n",
- " \n",
- " 4 | \n",
- " 1 | \n",
- " 13032 | \n",
- " 4 | \n",
- "
\n",
- " \n",
- "
\n",
- "
\n",
- "
\n",
- "
\n"
- ]
- },
- "metadata": {},
- "execution_count": 21
- }
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 22,
- "metadata": {
- "id": "ufHsF0o4t5HI"
- },
- "outputs": [],
- "source": [
- "class ProductModel(pydantic.BaseModel):\n",
- " product_id: int\n",
- " product_name: str\n",
- " vector: vector(128)\n",
- "schema = pydantic_to_schema(ProductModel)\n",
- "table_name = 'product_recommender'\n",
- "db.drop_table(table_name)\n",
- "try:\n",
- " tbl = db.create_table(table_name, schema=schema)\n",
- "except:\n",
- " tbl = db.open_table(table_name)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "0-2K-g4-t5HJ"
- },
- "source": [
- "Let's now store our item factors into the table via the vector column of `product_entries`."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 23,
- "metadata": {
- "id": "NOOPF9zOt5HJ"
- },
- "outputs": [],
- "source": [
- "# Transform items into factors\n",
- "items_factors = model.item_factors\n",
- "product_entries = products[['product_id', 'product_name']].drop_duplicates()\n",
- "product_entries['product_id'] = product_entries.product_id.astype('int64')\n",
- "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
- "item_embeddings = items_factors[1:].tolist()\n",
- "product_entries['vector'] = item_embeddings\n",
- "\n",
- "tbl.add(product_entries)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "j3aU4z-tSbWE"
- },
- "source": [
- "## Let's create an ANN index in order to speed up retrieval. This might take a while."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 24,
- "metadata": {
- "id": "H8HyvjCFSeaz",
- "outputId": "27519f2a-e95a-4442-97b1-291931180ca8",
- "colab": {
- "base_uri": "https://localhost:8080/"
- }
- },
- "outputs": [
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": [
- "{}"
- ]
- },
- "metadata": {},
- "execution_count": 24
- }
- ],
- "source": [
- "tbl.create_index(vector_column_name=\"vector\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "ibNMrxyRt5HK"
- },
- "source": [
- "This is a helper method for analysing recommendations later.\n",
- "This method returns top N products that someone bought in the past (based on product quantity)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 25,
- "metadata": {
- "id": "Uzgk5Od0t5HK"
- },
- "outputs": [],
- "source": [
- "def products_bought_by_user_in_the_past(user_id: int, top: int = 10):\n",
- "\n",
- " selected = data[data.user_id == user_id].sort_values(by=['total_orders'], ascending=False)\n",
- "\n",
- " selected['product_name'] = selected['product_id'].map(product_entries.set_index('product_id')['product_name'])\n",
- " selected = selected[['product_id', 'product_name', 'total_orders']].reset_index(drop=True)\n",
- " if selected.shape[0] < top:\n",
- " return selected\n",
- "\n",
- " return selected[:top]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "ULyVnHEXt5HK"
- },
- "source": [
- "Let's retrieve our test users so we can query for recommendations."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 26,
- "metadata": {
- "id": "Wwl7yFKTt5HK"
- },
- "outputs": [],
- "source": [
- "test_user_ids = [206210, 206211]\n",
- "test_user_factors = model.user_factors[user_to_index[test_user_ids]]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "wTh61ou3t5HL"
- },
- "source": [
- "## Let's now query LanceDB to retrieve recommendations."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 28,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 868
- },
- "id": "UiZg4Iset5HL",
- "outputId": "edc08e77-c03f-4ded-fd1d-3fd9d8a91376"
- },
- "outputs": [
- {
- "output_type": "display_data",
- "data": {
- "text/plain": [
- " product_id product_name \\\n",
- "0 196 Soda \n",
- "1 46149 Zero Calorie Cola \n",
- "2 40939 Drinking Water \n",
- "3 37710 Trail Mix \n",
- "4 22802 Mineral Water \n",
- "5 41400 Crunchy Oats 'n Honey Granola Bars \n",
- "6 46061 Popcorn \n",
- "7 31651 Extra Fancy Unsalted Mixed Nuts \n",
- "8 5258 Sparkling Water \n",
- "9 38928 0% Greek Strained Yogurt \n",
- "\n",
- " vector _distance \n",
- "0 [-0.0030924827, -0.0042996905, -0.01350651, -0... 35.096085 \n",
- "1 [0.0015008126, -0.014029495, -0.015295635, 0.0... 35.392975 \n",
- "2 [0.0018837166, -0.018152414, -0.015649604, 0.0... 35.864483 \n",
- "3 [-0.0011668581, -0.0025222106, -0.016717039, -... 35.896873 \n",
- "4 [-0.010115783, -0.017115017, -0.011403508, 0.0... 36.035912 \n",
- "5 [0.0040870784, -0.0009994006, -0.018302424, -0... 36.042686 \n",
- "6 [0.0036969625, -0.013887798, -0.002804261, -0.... 36.043732 \n",
- "7 [0.014438897, -0.005578243, -0.0055169673, -0.... 36.117802 \n",
- "8 [-0.022658644, -0.026015628, -0.0083606485, -0... 36.131721 \n",
- "9 [0.0018425643, -0.011489441, -0.0052835834, 0.... 36.139870 "
- ],
- "text/html": [
- "\n",
- " \n",
- "
\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " product_id | \n",
- " product_name | \n",
- " vector | \n",
- " _distance | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 0 | \n",
- " 196 | \n",
- " Soda | \n",
- " [-0.0030924827, -0.0042996905, -0.01350651, -0... | \n",
- " 35.096085 | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " 46149 | \n",
- " Zero Calorie Cola | \n",
- " [0.0015008126, -0.014029495, -0.015295635, 0.0... | \n",
- " 35.392975 | \n",
- "
\n",
- " \n",
- " 2 | \n",
- " 40939 | \n",
- " Drinking Water | \n",
- " [0.0018837166, -0.018152414, -0.015649604, 0.0... | \n",
- " 35.864483 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " 37710 | \n",
- " Trail Mix | \n",
- " [-0.0011668581, -0.0025222106, -0.016717039, -... | \n",
- " 35.896873 | \n",
- "
\n",
- " \n",
- " 4 | \n",
- " 22802 | \n",
- " Mineral Water | \n",
- " [-0.010115783, -0.017115017, -0.011403508, 0.0... | \n",
- " 36.035912 | \n",
- "
\n",
- " \n",
- " 5 | \n",
- " 41400 | \n",
- " Crunchy Oats 'n Honey Granola Bars | \n",
- " [0.0040870784, -0.0009994006, -0.018302424, -0... | \n",
- " 36.042686 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " 46061 | \n",
- " Popcorn | \n",
- " [0.0036969625, -0.013887798, -0.002804261, -0.... | \n",
- " 36.043732 | \n",
- "
\n",
- " \n",
- " 7 | \n",
- " 31651 | \n",
- " Extra Fancy Unsalted Mixed Nuts | \n",
- " [0.014438897, -0.005578243, -0.0055169673, -0.... | \n",
- " 36.117802 | \n",
- "
\n",
- " \n",
- " 8 | \n",
- " 5258 | \n",
- " Sparkling Water | \n",
- " [-0.022658644, -0.026015628, -0.0083606485, -0... | \n",
- " 36.131721 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " 38928 | \n",
- " 0% Greek Strained Yogurt | \n",
- " [0.0018425643, -0.011489441, -0.0052835834, 0.... | \n",
- " 36.139870 | \n",
- "
\n",
- " \n",
- "
\n",
- "
\n",
- "
\n",
- "
\n"
- ]
- },
- "metadata": {}
- },
- {
- "output_type": "display_data",
- "data": {
- "text/plain": [
- " product_id product_name total_orders\n",
- "0 46149 Zero Calorie Cola 50"
- ],
- "text/html": [
- "\n",
- " \n",
- "
\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " product_id | \n",
- " product_name | \n",
- " total_orders | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 0 | \n",
- " 46149 | \n",
- " Zero Calorie Cola | \n",
- " 50 | \n",
- "
\n",
- " \n",
- "
\n",
- "
\n",
- "
\n",
- "
\n"
- ]
- },
- "metadata": {}
- },
- {
- "output_type": "display_data",
- "data": {
- "text/plain": [
- " product_id product_name \\\n",
- "0 26604 Organic Blackberries \n",
- "1 27845 Organic Whole Milk \n",
- "2 27966 Organic Raspberries \n",
- "3 43352 Raspberries \n",
- "4 9076 Blueberries \n",
- "5 21288 Blackberries \n",
- "6 39275 Organic Blueberries \n",
- "7 39928 Organic Kiwi \n",
- "8 11777 Red Raspberries \n",
- "9 21137 Organic Strawberries \n",
- "\n",
- " vector _distance \n",
- "0 [-0.017585486, 0.019628799, 0.0399348, 0.01422... 17.404045 \n",
- "1 [-0.050286394, 0.026924692, 0.030701049, -0.02... 17.404305 \n",
- "2 [-0.006732653, 0.015266006, 0.018316658, -0.00... 17.867121 \n",
- "3 [0.0037516877, 0.013682851, 0.057814274, 0.031... 18.030893 \n",
- "4 [0.0029817792, 0.030459687, 0.04528497, 0.0113... 18.135754 \n",
- "5 [-0.011553102, -0.010046569, 0.037375, 0.02368... 18.141661 \n",
- "6 [0.010543987, 0.006028164, 0.011502461, 0.0004... 18.241520 \n",
- "7 [-0.044292357, -0.031322725, -0.00174381, -0.0... 18.414057 \n",
- "8 [-0.0067819585, -0.023531102, 0.010277328, -0.... 18.468819 \n",
- "9 [0.007023127, 0.0037457773, -0.0061378656, -0.... 18.476973 "
- ],
- "text/html": [
- "\n",
- " \n",
- "
\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " product_id | \n",
- " product_name | \n",
- " vector | \n",
- " _distance | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 0 | \n",
- " 26604 | \n",
- " Organic Blackberries | \n",
- " [-0.017585486, 0.019628799, 0.0399348, 0.01422... | \n",
- " 17.404045 | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " 27845 | \n",
- " Organic Whole Milk | \n",
- " [-0.050286394, 0.026924692, 0.030701049, -0.02... | \n",
- " 17.404305 | \n",
- "
\n",
- " \n",
- " 2 | \n",
- " 27966 | \n",
- " Organic Raspberries | \n",
- " [-0.006732653, 0.015266006, 0.018316658, -0.00... | \n",
- " 17.867121 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " 43352 | \n",
- " Raspberries | \n",
- " [0.0037516877, 0.013682851, 0.057814274, 0.031... | \n",
- " 18.030893 | \n",
- "
\n",
- " \n",
- " 4 | \n",
- " 9076 | \n",
- " Blueberries | \n",
- " [0.0029817792, 0.030459687, 0.04528497, 0.0113... | \n",
- " 18.135754 | \n",
- "
\n",
- " \n",
- " 5 | \n",
- " 21288 | \n",
- " Blackberries | \n",
- " [-0.011553102, -0.010046569, 0.037375, 0.02368... | \n",
- " 18.141661 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " 39275 | \n",
- " Organic Blueberries | \n",
- " [0.010543987, 0.006028164, 0.011502461, 0.0004... | \n",
- " 18.241520 | \n",
- "
\n",
- " \n",
- " 7 | \n",
- " 39928 | \n",
- " Organic Kiwi | \n",
- " [-0.044292357, -0.031322725, -0.00174381, -0.0... | \n",
- " 18.414057 | \n",
- "
\n",
- " \n",
- " 8 | \n",
- " 11777 | \n",
- " Red Raspberries | \n",
- " [-0.0067819585, -0.023531102, 0.010277328, -0.... | \n",
- " 18.468819 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " 21137 | \n",
- " Organic Strawberries | \n",
- " [0.007023127, 0.0037457773, -0.0061378656, -0.... | \n",
- " 18.476973 | \n",
- "
\n",
- " \n",
- "
\n",
- "
\n",
- "
\n",
- "
\n"
- ]
- },
- "metadata": {}
- },
- {
- "output_type": "display_data",
- "data": {
- "text/plain": [
- " product_id product_name total_orders\n",
- "0 27845 Organic Whole Milk 49\n",
- "1 26604 Organic Blackberries 32"
- ],
- "text/html": [
- "\n",
- " \n",
- "
\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " product_id | \n",
- " product_name | \n",
- " total_orders | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 0 | \n",
- " 27845 | \n",
- " Organic Whole Milk | \n",
- " 49 | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " 26604 | \n",
- " Organic Blackberries | \n",
- " 32 | \n",
- "
\n",
- " \n",
- "
\n",
- "
\n",
- "
\n",
- "
\n"
- ]
- },
- "metadata": {}
- }
- ],
- "source": [
- "# Query by user factors\n",
- "test_user_embeddings = test_user_factors.tolist()\n",
- "for embedding, id in zip(test_user_embeddings, test_user_ids):\n",
- " results = tbl.search(embedding).limit(10).to_pandas()\n",
- " display(results)\n",
- " display(products_bought_by_user_in_the_past(id, top=15))"
- ]
- }
- ],
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "gpuType": "T4",
- "provenance": []
- },
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.11.6"
- },
- "vscode": {
- "interpreter": {
- "hash": "5fe10bf018ef3e697f9035d60bf60847932a12bface18908407fd371fe880db9"
- }
- },
- "widgets": {
- "application/vnd.jupyter.widget-state+json": {
- "2c0101b0a3574a14b2a37fc431eb2908": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "HBoxModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "HBoxModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "HBoxView",
- "box_style": "",
- "children": [
- "IPY_MODEL_31c3c90fa42f489796fba11d57799089",
- "IPY_MODEL_e13993dda2da40ff806d6e31a6e987d3",
- "IPY_MODEL_0bff70b647f3404fa15690ec9f3d0c78"
- ],
- "layout": "IPY_MODEL_674cf2d29d044cada59480813e0e8e58"
- }
- },
- "31c3c90fa42f489796fba11d57799089": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "HTMLModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "HTMLModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "HTMLView",
- "description": "",
- "description_tooltip": null,
- "layout": "IPY_MODEL_bfd4ff099ed14ab1bd79233beea7f402",
- "placeholder": "",
- "style": "IPY_MODEL_000f9e8fd1db4bc0a7aceeb822ca2b2e",
- "value": "100%"
- }
- },
- "e13993dda2da40ff806d6e31a6e987d3": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "FloatProgressModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "FloatProgressModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "ProgressView",
- "bar_style": "success",
- "description": "",
- "description_tooltip": null,
- "layout": "IPY_MODEL_75b270d981de425ba1fd9a790b2a68ff",
- "max": 50,
- "min": 0,
- "orientation": "horizontal",
- "style": "IPY_MODEL_baafe1d810594384af1a5ffa4f2f5cb4",
- "value": 50
- }
- },
- "0bff70b647f3404fa15690ec9f3d0c78": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "HTMLModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "HTMLModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "HTMLView",
- "description": "",
- "description_tooltip": null,
- "layout": "IPY_MODEL_bf95fd811f79425bb2248525aeab7da0",
- "placeholder": "",
- "style": "IPY_MODEL_46fb5083adf24ce4ae3fd4ea9aa4772e",
- "value": " 50/50 [17:28<00:00, 20.73s/it]"
- }
- },
- "674cf2d29d044cada59480813e0e8e58": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "bfd4ff099ed14ab1bd79233beea7f402": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "000f9e8fd1db4bc0a7aceeb822ca2b2e": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "DescriptionStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "DescriptionStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "description_width": ""
- }
- },
- "75b270d981de425ba1fd9a790b2a68ff": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "baafe1d810594384af1a5ffa4f2f5cb4": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "ProgressStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "ProgressStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "bar_color": null,
- "description_width": ""
- }
- },
- "bf95fd811f79425bb2248525aeab7da0": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "46fb5083adf24ce4ae3fd4ea9aa4772e": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "DescriptionStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "DescriptionStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "description_width": ""
- }
- },
- "5b98b7b242994c999064688c9210c61b": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "HBoxModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "HBoxModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "HBoxView",
- "box_style": "",
- "children": [
- "IPY_MODEL_d5b1eb34ddc949aebd25b3744b93b726",
- "IPY_MODEL_752d37b9a68b42d284493645962f3782",
- "IPY_MODEL_f0def002c7ca41f6a70e9dba1bc605c7"
- ],
- "layout": "IPY_MODEL_4b0298a9ecf84b509fbf379d43339b9c"
- }
- },
- "d5b1eb34ddc949aebd25b3744b93b726": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "HTMLModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "HTMLModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "HTMLView",
- "description": "",
- "description_tooltip": null,
- "layout": "IPY_MODEL_a37be209d5bb44e18f32c0259073d2c8",
- "placeholder": "",
- "style": "IPY_MODEL_b35984b48d8847eea119ee5eda049b9d",
- "value": "100%"
- }
- },
- "752d37b9a68b42d284493645962f3782": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "FloatProgressModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "FloatProgressModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "ProgressView",
- "bar_style": "success",
- "description": "",
- "description_tooltip": null,
- "layout": "IPY_MODEL_4b20ad4b356645bbbfb94929160943f2",
- "max": 192802,
- "min": 0,
- "orientation": "horizontal",
- "style": "IPY_MODEL_63b8646c732246988f566d0442a070e8",
- "value": 192802
- }
- },
- "f0def002c7ca41f6a70e9dba1bc605c7": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "HTMLModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "HTMLModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "HTMLView",
- "description": "",
- "description_tooltip": null,
- "layout": "IPY_MODEL_ae8581ec76314304b2078759e1dbdd7e",
- "placeholder": "",
- "style": "IPY_MODEL_d0e90066f1ec42afa5f1c02551d3889e",
- "value": " 192802/192802 [02:11<00:00, 1657.77it/s]"
- }
- },
- "4b0298a9ecf84b509fbf379d43339b9c": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "a37be209d5bb44e18f32c0259073d2c8": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "b35984b48d8847eea119ee5eda049b9d": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "DescriptionStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "DescriptionStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "description_width": ""
- }
- },
- "4b20ad4b356645bbbfb94929160943f2": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "63b8646c732246988f566d0442a070e8": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "ProgressStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "ProgressStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "bar_color": null,
- "description_width": ""
- }
- },
- "ae8581ec76314304b2078759e1dbdd7e": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "d0e90066f1ec42afa5f1c02551d3889e": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "DescriptionStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "DescriptionStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "description_width": ""
- }
- }
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
-}
diff --git a/examples/product-recommender/lancedb_cloud/main.py b/examples/product-recommender/lancedb_cloud/main.py
deleted file mode 100644
index af87c578..00000000
--- a/examples/product-recommender/lancedb_cloud/main.py
+++ /dev/null
@@ -1,137 +0,0 @@
-import zipfile
-import numpy as np
-import pandas as pd
-import scipy.sparse
-import torch
-import implicit
-from implicit import evaluation
-import lancedb
-import pydantic
-from lancedb.pydantic import pydantic_to_schema, vector
-import argparse
-
-def products_bought_by_user_in_the_past(user_id: int, top: int = 10):
-
- selected = data[data.user_id == user_id].sort_values(by=['total_orders'], ascending=False)
-
- selected['product_name'] = selected['product_id'].map(product_entries.set_index('product_id')['product_name'])
- selected = selected[['product_id', 'product_name', 'total_orders']].reset_index(drop=True)
- if selected.shape[0] < top:
- return selected
-
- return selected[:top]
-
-def args_parse():
- parser = argparse.ArgumentParser(description='Product Recommender')
- parser.add_argument('--factors', type=int, default=128, help='dimension of latent factor vectors')
- parser.add_argument('--regularization', type=float, default=0.05, help='strength of penalty term')
- parser.add_argument('--iterations', type=int, default=50, help='number of iterations to update')
- parser.add_argument('--num-threads', type=int, default=1, help='amount of parallelization')
- parser.add_argument('--num-partitions', type=int, default=256, help='number of partitions of the index')
- parser.add_argument('--num-sub-vectors', type=int, default=16, help='number of sub-vectors (M) that will be created during Product Quantization (PQ).')
- args = parser.parse_args()
-
- return args
-
-files = [
- 'instacart-market-basket-analysis.zip',
- 'order_products__train.csv.zip',
- 'order_products__prior.csv.zip',
- 'products.csv.zip',
- 'orders.csv.zip'
-]
-
-if __name__ == "__main__":
- args = args_parse()
- for filename in files:
- with zipfile.ZipFile(filename, 'r') as zip_ref:
- zip_ref.extractall('./')
-
- products = pd.read_csv('products.csv')
- orders = pd.read_csv('orders.csv')
- order_products = pd.concat([pd.read_csv('order_products__train.csv'), pd.read_csv('order_products__prior.csv')])
-
- customer_order_products = pd.merge(orders, order_products, how='inner',on='order_id')
-
- # create confidence table
- data = customer_order_products.groupby(['user_id', 'product_id'])[['order_id']].count().reset_index()
- data.columns=["user_id", "product_id", "total_orders"]
- data.product_id = data.product_id.astype('int64')
-
- data_new = pd.DataFrame([[data.user_id.max() + 1, 46149, 50], # user 1 orders 50 Zero Calorie Cola
- [data.user_id.max() + 2, 27845, 49], # user 2 orders 49 Organic Whole Milk
- [data.user_id.max() + 2, 26604, 32] # user 2 orders 32 Organic Blackberries
- ], columns=['user_id', 'product_id', 'total_orders'])
- data = pd.concat([data, data_new]).reset_index(drop = True)
-
- # extract unique user and product ids
- unique_users = list(np.sort(data.user_id.unique()))
- unique_products = list(np.sort(products.product_id.unique()))
- purchases = list(data.total_orders)
-
- # create zero-based index position <-> user/item ID mappings
- index_to_user = pd.Series(unique_users)
-
- # create reverse mappings from user/item ID to index positions
- user_to_index = pd.Series(data=index_to_user.index + 1, index=index_to_user.values)
-
- # create row and column for user and product ids
- users_rows = data.user_id.astype(int)
- products_cols = data.product_id.astype(int)
-
- # create CSR matrix
- matrix = scipy.sparse.csr_matrix((purchases, (users_rows, products_cols)), shape=(len(unique_users) + 1, len(unique_products) + 1))
- matrix.data = np.nan_to_num(matrix.data, copy=False)
-
- #split data into train and test splits
- train, test = evaluation.train_test_split(matrix, train_percentage=0.9)
-
- # initialize the recommender model
- model = implicit.als.AlternatingLeastSquares(factors=args.factors,
- regularization=args.regularization,
- iterations=args.iterations,
- num_threads=args.num_threads)
-
- alpha = 15
- train = (train * alpha).astype('double')
-
- # train the model on CSR matrix
- model.fit(train, show_progress = True)
-
- test = (test * alpha).astype('double')
- evaluation.ranking_metrics_at_k(model, train, test, K=100,
- show_progress=True, num_threads=1)
-
-
- db_url = "db://your-project-name"
- api_key="sk_..."
- region = "us-east-1"
- db = lancedb.connect(db_url, api_key=api_key, region=region)
- class ProductModel(pydantic.BaseModel):
- product_id: int
- product_name: str
- vector: vector(args.factors)
- schema = pydantic_to_schema(ProductModel)
- table_name = 'product_recommender'
- tbl = db.create_table(table_name, schema=schema)
-
- # Transform items into factors
- items_factors = model.item_factors
- product_entries = products[['product_id', 'product_name']].drop_duplicates()
- product_entries['product_id'] = product_entries.product_id.astype('int64')
- device = "cuda" if torch.cuda.is_available() else "cpu"
- item_embeddings = items_factors[1:].tolist()
- product_entries['vector'] = item_embeddings
-
- tbl.add(product_entries)
- tbl.create_index(vector_column_name="vector")
-
- test_user_ids = [206210, 206211]
- test_user_factors = model.user_factors[user_to_index[test_user_ids]]
-
- # Query by user factors
- test_user_embeddings = test_user_factors.tolist()
- for embedding, id in zip(test_user_embeddings, test_user_ids):
- results = tbl.search(embedding).limit(10).to_pandas()
- print(results.drop(columns=['vector']).to_string(max_cols=None))
- print(products_bought_by_user_in_the_past(id, top=15).to_string(max_cols=None))
diff --git a/examples/product-recommender/lancedb_cloud/requirements.txt b/examples/product-recommender/lancedb_cloud/requirements.txt
deleted file mode 100644
index 662caa62..00000000
--- a/examples/product-recommender/lancedb_cloud/requirements.txt
+++ /dev/null
@@ -1,7 +0,0 @@
-numpy
-pandas
-scipy
-kaggle
-implicit
-torch
-lancedb
\ No newline at end of file