From 52dc5d8afd4761e543ad18a2fe3127991f09b4f4 Mon Sep 17 00:00:00 2001 From: Ayush Date: Wed, 24 Jan 2024 01:10:30 +0530 Subject: [PATCH] fix product recommender --- examples/product-recommender/main.ipynb | 5445 ++++++++++++----------- 1 file changed, 2904 insertions(+), 2541 deletions(-) diff --git a/examples/product-recommender/main.ipynb b/examples/product-recommender/main.ipynb index 53c4d96f..7c631fae 100644 --- a/examples/product-recommender/main.ipynb +++ b/examples/product-recommender/main.ipynb @@ -1,2595 +1,2958 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "YmdWGrw4t5G2" - }, - "source": [ - "# Product Recommender using Collaborative Filtering and LanceDB\n", - "\n", - "We are going to use **LanceDB** and **Collaborative Filtering** to recommend products based on a user's past buying history. We used the **Instacart dataset** as our data for this example.\n", - "\n", - "![picture](https://daxg39y63pxwu.cloudfront.net/images/blog/product-recommendation-system-projects/Product_Recommendation_System_Project_Ideas_and_Examples.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lXd46ecEt5G7" - }, - "source": [ - "To run this example, you must first create a Kaggle account. Then, go to the 'Account' tab of your user profile and select 'Create New Token'. This will trigger the download of kaggle.json, a file containing your API credentials.\n", - "\n", - "Add Kaggle credentials to `~/.kaggle/kaggle.json` on Linux, OSX, and other UNIX-based operating systems or `C:\\Users\\\\.kaggle\\kaggle.json` for Window's users.\n", - "\n", - "In Google Colab, run the snippet below." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "l6TTPIF_omEy" - }, - "outputs": [], - "source": [ - "import json\n", - "\n", - "with open('/root/.kaggle/kaggle.json', 'w') as fp:\n", - " fp.write(json.dumps({\"username\":\"YOUR_USERNAME\",\"key\":\"YOUR_SECRET_KEY\"}))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Install dependencies" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "YmdWGrw4t5G2" + }, + "source": [ + "# Product Recommender using Collaborative Filtering and LanceDB\n", + "\n", + "We are going to use **LanceDB** and **Collaborative Filtering** to recommend products based on a user's past buying history. We used the **Instacart dataset** as our data for this example.\n", + "\n", + "![picture](https://daxg39y63pxwu.cloudfront.net/images/blog/product-recommendation-system-projects/Product_Recommendation_System_Project_Ideas_and_Examples.png)" + ] }, - "id": "R3_Hq2VC4_zT", - "outputId": "f55d20d1-9953-457c-f41b-d09912b06188" - }, - "outputs": [], - "source": [ - "!pip install numpy pandas scipy kaggle implicit torch lancedb" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "i_eatRhaIGIz" - }, - "source": [ - "### Importing libraries" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + { + "cell_type": "markdown", + "metadata": { + "id": "lXd46ecEt5G7" + }, + "source": [ + "To run this example, you must first create a Kaggle account. Then, go to the 'Account' tab of your user profile and select 'Create New Token'. This will trigger the download of kaggle.json, a file containing your API credentials.\n", + "\n", + "Add Kaggle credentials to `~/.kaggle/kaggle.json` on Linux, OSX, and other UNIX-based operating systems or `C:\\Users\\\\.kaggle\\kaggle.json` for Window's users.\n", + "\n", + "In Google Colab, run the snippet below." + ] }, - "id": "emp_MSXZt5G8", - "outputId": "b9719d21-25a6-461e-d571-164f04e599d0" - }, - "outputs": [], - "source": [ - "import zipfile\n", - "import numpy as np\n", - "import pandas as pd\n", - "import scipy.sparse\n", - "import torch\n", - "import implicit\n", - "from implicit import evaluation\n", - "import pydantic\n", - "import lancedb\n", - "from lancedb.pydantic import pydantic_to_schema, vector" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bUGkN85V4_zY" - }, - "source": [ - "### Load the dataset\n", - "Now we can download the dataset. You will need to accept the rules of the `instacart-market-basket-analysis` competition, which you can do so [here](https://www.kaggle.com/competitions/instacart-market-basket-analysis/rules)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "id": "l6TTPIF_omEy", + "outputId": "d2cf1685-103e-4b62-bae3-a16d171a928f", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Kaggle API key file created and moved successfully.\n" + ] + } + ], + "source": [ + "import json\n", + "import os\n", + "\n", + "# Set the file path\n", + "kaggle_json_path = '/content/kaggle.json'\n", + "\n", + "# Write Kaggle API key to the file\n", + "with open(kaggle_json_path, 'w') as fp:\n", + " json.dump({\"username\": \"\", \"key\": \"\"}, fp)\n", + "\n", + "# Move the file to the correct location\n", + "os.system('mkdir -p ~/.kaggle')\n", + "os.system(f'mv {kaggle_json_path} ~/.kaggle/kaggle.json')\n", + "\n", + "# Set permissions\n", + "os.system('chmod 600 ~/.kaggle/kaggle.json')\n", + "\n", + "print(\"Kaggle API key file created and moved successfully.\")" + ] }, - "id": "09gdQyBu4_zY", - "outputId": "4a68ac04-2c9f-4a11-a7da-f24e34270256" - }, - "outputs": [], - "source": [ - "!kaggle competitions download -c instacart-market-basket-analysis" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "K4Q4cOX-4_zY" - }, - "source": [ - "We must now extract the zip files." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "f3g296nL4_zZ" - }, - "outputs": [], - "source": [ - "files = [\n", - " 'instacart-market-basket-analysis.zip',\n", - " 'order_products__train.csv.zip',\n", - " 'order_products__prior.csv.zip',\n", - " 'products.csv.zip',\n", - " 'orders.csv.zip'\n", - "]\n", - "\n", - "for filename in files:\n", - " with zipfile.ZipFile(filename, 'r') as zip_ref:\n", - " zip_ref.extractall('./')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "oLgkRIfq4_zZ" - }, - "source": [ - "Now we can move on to loading the dataset. We'll first read the csv files and create dataframes." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "cBbbR7Rut5G_" - }, - "outputs": [], - "source": [ - "products = pd.read_csv('products.csv')\n", - "orders = pd.read_csv('orders.csv')\n", - "order_products = pd.concat([pd.read_csv('order_products__train.csv'), pd.read_csv('order_products__prior.csv')])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5FV_GGjst5HA" - }, - "source": [ - "Since there isn't a user rating attribute, we'll gather \"confidence\" data by looking at the frequency of each item purchased by a user, and store this in the `data` dataframe." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Data Manipulation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ZjRh7RYpt5HB" - }, - "outputs": [], - "source": [ - "customer_order_products = pd.merge(orders, order_products, how='inner',on='order_id')\n", - "\n", - "# create confidence table\n", - "data = customer_order_products.groupby(['user_id', 'product_id'])[['order_id']].count().reset_index()\n", - "data.columns=[\"user_id\", \"product_id\", \"total_orders\"]\n", - "data.product_id = data.product_id.astype('int64')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "77lvwm0St5HC" - }, - "source": [ - "Let's create a couple of test users to examine the recommendations later:\n", - "- 1st test user: buys 50 sodas: **Zero Calorie Cola**\n", - "- 2nd test user: buys organic produce: **Organic Whole Milk** and **Organic Blackberries**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 206 + { + "cell_type": "markdown", + "metadata": { + "id": "c6G45HrUqNx5" + }, + "source": [ + "### Install dependencies" + ] }, - "id": "A06EfAf-t5HC", - "outputId": "135aa843-f46a-4fef-d2ae-ddfc00512ad2" - }, - "outputs": [ { - "data": { - "text/html": [ - "\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_idproduct_idtotal_orders
13863744206209486971
13863745206209487422
138637462062104614950
138637472062112784549
138637482062112660432
\n", - "
\n", - "
\n", - "\n", - "
\n", - " \n", - "\n", - " \n", - "\n", - " \n", - "
\n", - "\n", - "\n", - "
\n", - " \n", - "\n", - "\n", - "\n", - " \n", - "
\n", - "
\n", - "
\n" + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "R3_Hq2VC4_zT", + "outputId": "ee47bbd5-d1c3-4900-894e-2530190e17e7" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (1.23.5)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (1.5.3)\n", + "Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (1.11.4)\n", + "Requirement already satisfied: kaggle in /usr/local/lib/python3.10/dist-packages (1.5.16)\n", + "Collecting implicit\n", + " Downloading implicit-0.7.2-cp310-cp310-manylinux2014_x86_64.whl (8.9 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.9/8.9 MB\u001b[0m \u001b[31m18.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.1.0+cu121)\n", + "Collecting lancedb\n", + " Downloading lancedb-0.5.0-py3-none-any.whl (87 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m87.4/87.4 kB\u001b[0m \u001b[31m13.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas) (2023.3.post1)\n", + "Requirement already satisfied: six>=1.10 in /usr/local/lib/python3.10/dist-packages (from kaggle) (1.16.0)\n", + "Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from kaggle) (2023.11.17)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.31.0)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from kaggle) (4.66.1)\n", + "Requirement already satisfied: python-slugify in /usr/local/lib/python3.10/dist-packages (from kaggle) (8.0.1)\n", + "Requirement already satisfied: urllib3 in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.0.7)\n", + "Requirement already satisfied: bleach in /usr/local/lib/python3.10/dist-packages (from kaggle) (6.1.0)\n", + "Requirement already satisfied: threadpoolctl in /usr/local/lib/python3.10/dist-packages (from implicit) (3.2.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.13.1)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch) (4.5.0)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.12)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.2.1)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.3)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2023.6.0)\n", + "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch) (2.1.0)\n", + "Collecting deprecation (from lancedb)\n", + " Downloading deprecation-2.1.0-py2.py3-none-any.whl (11 kB)\n", + "Collecting pylance==0.9.6 (from lancedb)\n", + " Downloading pylance-0.9.6-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m18.6/18.6 MB\u001b[0m \u001b[31m58.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ratelimiter~=1.0 (from lancedb)\n", + " Downloading ratelimiter-1.2.0.post0-py3-none-any.whl (6.6 kB)\n", + "Collecting retry>=0.9.2 (from lancedb)\n", + " Downloading retry-0.9.2-py2.py3-none-any.whl (8.0 kB)\n", + "Requirement already satisfied: pydantic>=1.10 in /usr/local/lib/python3.10/dist-packages (from lancedb) (1.10.13)\n", + "Requirement already satisfied: attrs>=21.3.0 in /usr/local/lib/python3.10/dist-packages (from lancedb) (23.2.0)\n", + "Collecting semver>=3.0 (from lancedb)\n", + " Downloading semver-3.0.2-py3-none-any.whl (17 kB)\n", + "Requirement already satisfied: cachetools in /usr/local/lib/python3.10/dist-packages (from lancedb) (5.3.2)\n", + "Requirement already satisfied: pyyaml>=6.0 in /usr/local/lib/python3.10/dist-packages (from lancedb) (6.0.1)\n", + "Requirement already satisfied: click>=8.1.7 in /usr/local/lib/python3.10/dist-packages (from lancedb) (8.1.7)\n", + "Collecting overrides>=0.7 (from lancedb)\n", + " Downloading overrides-7.6.0-py3-none-any.whl (17 kB)\n", + "Collecting pyarrow>=12 (from pylance==0.9.6->lancedb)\n", + " Downloading pyarrow-15.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (38.3 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m38.3/38.3 MB\u001b[0m \u001b[31m13.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle) (3.6)\n", + "Requirement already satisfied: decorator>=3.4.2 in /usr/local/lib/python3.10/dist-packages (from retry>=0.9.2->lancedb) (4.4.2)\n", + "Collecting py<2.0.0,>=1.4.26 (from retry>=0.9.2->lancedb)\n", + " Downloading py-1.11.0-py2.py3-none-any.whl (98 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m98.7/98.7 kB\u001b[0m \u001b[31m14.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: webencodings in /usr/local/lib/python3.10/dist-packages (from bleach->kaggle) (0.5.1)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from deprecation->lancedb) (23.2)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.3)\n", + "Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.10/dist-packages (from python-slugify->kaggle) (1.3)\n", + "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n", + "Installing collected packages: ratelimiter, semver, pyarrow, py, overrides, deprecation, retry, pylance, implicit, lancedb\n", + " Attempting uninstall: pyarrow\n", + " Found existing installation: pyarrow 10.0.1\n", + " Uninstalling pyarrow-10.0.1:\n", + " Successfully uninstalled pyarrow-10.0.1\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "ibis-framework 7.1.0 requires pyarrow<15,>=2, but you have pyarrow 15.0.0 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0mSuccessfully installed deprecation-2.1.0 implicit-0.7.2 lancedb-0.5.0 overrides-7.6.0 py-1.11.0 pyarrow-15.0.0 pylance-0.9.6 ratelimiter-1.2.0.post0 retry-0.9.2 semver-3.0.2\n" + ] + } ], - "text/plain": [ - " user_id product_id total_orders\n", - "13863744 206209 48697 1\n", - "13863745 206209 48742 2\n", - "13863746 206210 46149 50\n", - "13863747 206211 27845 49\n", - "13863748 206211 26604 32" + "source": [ + "!pip install numpy pandas scipy kaggle implicit torch lancedb" ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "data_new = pd.DataFrame([[data.user_id.max() + 1, 46149, 50],\n", - " [data.user_id.max() + 2, 27845, 49],\n", - " [data.user_id.max() + 2, 26604, 32]\n", - " ], columns=['user_id', 'product_id', 'total_orders'])\n", - "data = pd.concat([data, data_new]).reset_index(drop = True)\n", - "data.tail()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xBC-8PFTt5HD" - }, - "source": [ - "In the next step, we will extract user and product unique ids, in order to create a `CSR (Compressed Sparse Row)` matrix. This will allow us to perform collaborative filtering.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "v2_2R7zmt5HE" - }, - "outputs": [], - "source": [ - "# extract unique user and product ids\n", - "unique_users = list(np.sort(data.user_id.unique()))\n", - "unique_products = list(np.sort(products.product_id.unique()))\n", - "purchases = list(data.total_orders)\n", - "\n", - "# create zero-based index position <-> user/item ID mappings\n", - "index_to_user = pd.Series(unique_users)\n", - "\n", - "# create reverse mappings from user/item ID to index positions\n", - "user_to_index = pd.Series(data=index_to_user.index + 1, index=index_to_user.values)\n", - "\n", - "# create row and column for user and product ids\n", - "users_rows = data.user_id.astype(int)\n", - "products_cols = data.product_id.astype(int)\n", - "\n", - "# create CSR matrix\n", - "matrix = scipy.sparse.csr_matrix((purchases, (users_rows, products_cols)), shape=(len(unique_users) + 1, len(unique_products) + 1))\n", - "matrix.data = np.nan_to_num(matrix.data, copy=False)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "II6wOH96t5HF" - }, - "source": [ - "Let's now create a recommender model using the **implicit** library. The recommendation model is based off the algorithms described in the paper [Collaborative Filtering for Implicit Feedback Datasets](https://www.researchgate.net/publication/220765111_Collaborative_Filtering_for_Implicit_Feedback_Datasets) with performance optimizations described in [Applications of the Conjugate Gradient Method for Implicit Feedback Collaborative Filtering](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.379.6473&rep=rep1&type=pdf).\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Difference between colloborative and content filtering\n", - "\n", - "![picture](https://miro.medium.com/v2/resize:fit:1400/0*R8qw_CXxCc4600bQ.png)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 49, - "referenced_widgets": [ - "1a0967f7267a4774ad99cedcb45f4cfc", - "ae54a8d76aca4327a938dda5c93c78ca", - "be5ee6d5b76b4deab6e84fdc8931b899", - "7e17ec4dd0484ac1b2a0ace262b6b518", - "ac143cc6e1394eb6963d1e6995d09447", - "5e66de2e9882465b8ba1f252f049785a", - "4cf95f3891134632b14c606c744fded3", - "734cebd21ff5403891fe799826d1d98b", - "aafb9c85a5b74f18bbbc41fc3132dd1e", - "261d4713dfd040318b5f542609a05129", - "6ca9eb350a6b4a1b8bfdb500ce00c0c4" - ] }, - "id": "k0GW99kxt5HF", - "outputId": "3d8e7947-bdfe-460a-a8bc-8f32631aaa36" - }, - "outputs": [], - "source": [ - "#split data into train and test splits\n", - "train, test = evaluation.train_test_split(matrix, train_percentage=0.9)\n", - "\n", - "# initialize the recommender model\n", - "model = implicit.als.AlternatingLeastSquares(factors=128,\n", - " regularization=0.05,\n", - " iterations=50,\n", - " num_threads=1)\n", - "\n", - "alpha = 15\n", - "train = (train * alpha).astype('double')\n", - "\n", - "# train the model on CSR matrix\n", - "model.fit(train, show_progress = True)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yN80hSojt5HF" - }, - "source": [ - "## Let's now evaluate the model." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 120, - "referenced_widgets": [ - "9bff71b52a854ba788601810574f7689", - "1083778f7a0941e18a436cc7c6cddda8", - "6336bd02113245f7883a48961d80ed55", - "bffd8e57b04946ae9289af62a52b9fc8", - "03397ba7fce7452593eddb4fa4fd27f4", - "ccd7286cfb904ad8a5486738357a6d39", - "d157f99a712c43cdb7a26c6470d0e615", - "c9110b853367476a8743eea0c522a265", - "4ae829ed16f846edade21e0730ff6f37", - "95e65e6bceed40db9cea2a0adffed3fb", - "cbff0accbd0f492597900c93a340d59f" - ] - }, - "id": "BbD8of_nt5HG", - "outputId": "91ce36fd-f569-41c9-859e-71cf9709f8d4" - }, - "outputs": [], - "source": [ - "test = (test * alpha).astype('double')\n", - "evaluation.ranking_metrics_at_k(model, train, test, K=100,\n", - " show_progress=True, num_threads=1)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "LNmva3Dlt5HG" - }, - "source": [ - "From the model, we'll be able to retrieve item and user factors, which we can use later on to store in LanceDB as vector embeddings." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + { + "cell_type": "markdown", + "metadata": { + "id": "i_eatRhaIGIz" + }, + "source": [ + "### Importing libraries" + ] }, - "id": "JUtCROQKt5HG", - "outputId": "38c55345-deec-4555-f9fd-33d88742b8de" - }, - "outputs": [ { - "data": { - "text/plain": [ - "Matrix([[-4.78145870e-04 1.20844017e-03 -1.05093475e-02 4.69897687e-03\n", - " -3.42543889e-03 -2.47619092e-03 -1.37619404e-02 7.47181184e-04\n", - " 1.07308161e-02 -2.85757496e-03 -3.43888951e-03 2.04937998e-02\n", - " 1.51449145e-04 -2.15489650e-03 4.52879071e-03 2.36469251e-03\n", - " -4.35322057e-03 1.33156916e-02 1.48590095e-02 4.69916826e-03\n", - " 6.51248451e-03 4.22086829e-04 -8.89686961e-03 -1.12102665e-02\n", - " -1.30706644e-02 2.08967202e-03 1.39534501e-02 -5.01955580e-03\n", - " 9.95233562e-03 1.66954547e-02 -2.05567423e-02 2.21748278e-03\n", - " -1.06390044e-02 1.60855800e-02 -5.87939285e-03 2.46607186e-03\n", - " -4.01218655e-03 -6.49328623e-03 6.99202390e-03 1.05327908e-02\n", - " 6.51289755e-03 -9.16731264e-03 -4.96822828e-03 8.40877462e-03\n", - " -2.60996539e-03 5.20697143e-03 -4.72197018e-04 -6.58531254e-03\n", - " -1.40383225e-02 -3.83673515e-03 -1.17233172e-02 8.79578851e-03\n", - " 3.15940916e-03 -1.96590065e-03 3.96021921e-03 7.77690002e-05\n", - " -1.72196236e-03 -9.86298453e-03 -1.31952651e-02 -3.54522630e-03\n", - " -1.41597521e-02 6.59553846e-03 -1.51649360e-02 1.02147097e-02\n", - " -2.20241174e-02 -3.74605908e-04 -8.67480133e-03 5.97078400e-03\n", - " -2.66844733e-03 -1.42379839e-03 2.27073785e-02 5.13940537e-03\n", - " 1.03549734e-02 6.50122017e-03 -1.99826285e-02 9.24922712e-03\n", - " -2.38972181e-03 -4.23666090e-03 -1.25027373e-02 1.88295334e-03\n", - " 3.62423621e-03 3.57046886e-03 -3.08311475e-03 5.31456666e-03\n", - " 1.25467628e-02 3.57796950e-03 7.07062241e-03 5.72030572e-03\n", - " 5.58650494e-03 -6.72255456e-03 -1.83661487e-02 5.99548966e-03\n", - " -2.72030273e-04 -1.51731428e-02 -1.21864192e-02 8.52781814e-03\n", - " -4.08741675e-04 7.61474657e-04 2.11992972e-02 6.37246016e-03\n", - " 5.38055226e-03 4.97937202e-03 -2.71766027e-03 -9.19679552e-03\n", - " -7.30522675e-03 1.29338922e-02 1.18429838e-02 6.61867904e-03\n", - " 1.70149212e-03 -1.11316633e-03 8.28325748e-03 6.30424637e-03\n", - " 3.23107629e-03 -7.33668311e-03 -4.62154718e-03 5.39950328e-03\n", - " -1.39445011e-02 -1.06749346e-03 -2.00939714e-03 -1.24352388e-02\n", - " -1.44120287e-02 1.06749088e-02 1.21331098e-03 8.06814246e-03\n", - " -1.87468696e-02 -2.49057007e-03 -1.15206158e-02 7.64290895e-03]\n", - " [-1.86744879e-03 -4.66924999e-03 2.62058136e-04 3.96398274e-04\n", - " 8.63511232e-04 8.07967386e-04 -1.68662157e-03 -1.08173559e-03\n", - " -1.80625229e-03 -2.72132549e-03 2.05651508e-03 9.63525788e-04\n", - " 4.06941352e-03 1.13172247e-03 -1.96925621e-03 -1.47615559e-03\n", - " -1.19289500e-04 -1.21552637e-03 -6.69758883e-04 3.02351615e-03\n", - " -2.16358632e-04 -2.93541444e-03 1.15798367e-03 -3.25729634e-05\n", - " 1.56907842e-03 2.24484876e-03 7.82305433e-04 4.21060366e-04\n", - " -1.10215694e-03 -4.49104747e-03 -5.18820365e-04 1.77382247e-03\n", - " -1.31249835e-03 -1.62468455e-03 -8.22964241e-04 4.19229531e-04\n", - " -1.61156946e-04 -4.35678870e-04 4.96705237e-04 1.69218823e-04\n", - " -3.39794345e-03 -1.33699877e-03 2.80402927e-03 6.60935591e-04\n", - " -1.85639539e-03 -1.55835948e-03 4.32507455e-04 -1.05694158e-03\n", - " -4.86606266e-04 -1.27697759e-03 -1.47943050e-04 1.07673078e-03\n", - " -1.01191479e-04 -9.65903630e-04 -2.48108618e-03 8.08388926e-04\n", - " 1.76821987e-03 2.34587677e-03 1.47300051e-03 2.15368299e-03\n", - " -8.25742783e-04 6.46742876e-04 2.10412848e-03 2.22956366e-03\n", - " -7.51944390e-05 -4.20888158e-04 -9.83789214e-04 -2.41565169e-03\n", - " 2.28196708e-03 -1.67180516e-03 -1.01774244e-03 -1.60270859e-03\n", - " -8.59596417e-04 -1.55605783e-04 5.09508187e-04 -1.87147816e-03\n", - " 1.08829510e-04 9.37477977e-04 1.45089245e-04 1.11689174e-03\n", - " 1.42429711e-03 2.01329254e-04 2.59526510e-04 -5.31102007e-04\n", - " 7.97511311e-05 -8.36311374e-04 2.43162527e-03 1.07298889e-04\n", - " 3.93533992e-04 3.79959791e-04 -2.30675470e-03 9.03452164e-04\n", - " -1.05375238e-03 -1.16557023e-03 -2.24862527e-03 -8.47899537e-06\n", - " 3.58379795e-03 -1.97861204e-03 7.35816080e-04 2.40693684e-03\n", - " 4.44425794e-04 1.32714782e-03 2.48055498e-04 -2.01600906e-03\n", - " 3.14348348e-04 2.91592046e-03 -4.29794611e-03 8.71473225e-04\n", - " -4.66790429e-04 -2.59004976e-03 3.57797777e-04 -5.13199251e-04\n", - " -3.72246868e-04 -4.57473885e-04 -1.13439921e-03 -5.64167567e-04\n", - " 4.08002222e-03 1.71154004e-03 -1.18066662e-03 1.10202690e-03\n", - " -9.25508677e-04 9.13127908e-04 -3.54050135e-04 -1.86034129e-03\n", - " 6.16788166e-04 -1.34601502e-03 -6.19047554e-04 6.99885830e-04]])" + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "emp_MSXZt5G8" + }, + "outputs": [], + "source": [ + "import zipfile\n", + "import numpy as np\n", + "import pandas as pd\n", + "import scipy.sparse\n", + "import torch\n", + "import implicit\n", + "from implicit import evaluation\n", + "import pydantic\n", + "import lancedb\n", + "from lancedb.pydantic import pydantic_to_schema, vector" ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model.item_factors[1:3]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "O3onbJmnt5HG", - "outputId": "988d73ee-a007-4848-d055-bc5b34f8900b" - }, - "outputs": [ { - "data": { - "text/plain": [ - "Matrix([[ 1.76306021e+00 5.84278941e-01 -1.31811535e+00 2.58672982e-02\n", - " -1.47269890e-01 -1.99104631e+00 2.27232277e-01 -1.44048560e+00\n", - " 6.22047544e-01 -6.87879205e-01 -4.23879363e-02 -1.50391304e+00\n", - " -1.04289226e-01 -9.02593374e-01 6.43032670e-01 -3.58335793e-01\n", - " -6.54706135e-02 8.50856245e-01 7.79175341e-01 -4.51985866e-01\n", - " 8.55366886e-01 1.36438921e-01 -5.49016356e-01 4.83980298e-01\n", - " -1.40851259e-01 -5.33492684e-01 5.68639338e-01 1.45152867e-01\n", - " 1.76261580e+00 5.22969246e-01 -2.21816874e+00 1.65968144e+00\n", - " -8.83751035e-01 7.76956260e-01 1.25151992e+00 -3.25308472e-01\n", - " -1.49347281e+00 -1.01729310e+00 -4.59959418e-01 7.20718205e-01\n", - " -7.15589583e-01 6.45604208e-02 -8.51610005e-01 3.01664054e-01\n", - " -4.82483774e-01 -1.79249153e-01 -1.40011147e-01 -4.23951089e-01\n", - " -1.13460481e+00 1.13597369e+00 -7.75141537e-01 1.12328935e+00\n", - " 8.09678361e-02 3.74672678e-03 -5.49308121e-01 3.00086081e-01\n", - " -3.05186462e+00 8.97184253e-01 3.36005628e-01 -9.53863025e-01\n", - " -1.40538502e+00 8.67709160e-01 -1.28687751e+00 2.50722909e+00\n", - " -1.39235985e+00 8.41415584e-01 -3.63795489e-01 -4.97452825e-01\n", - " 2.82425970e-01 -1.41166592e+00 1.87965715e+00 2.28982329e+00\n", - " 2.65981466e-01 1.40662909e+00 -5.27400672e-01 2.24494290e+00\n", - " -1.02008140e+00 9.12805080e-01 -1.06372535e+00 -8.54292572e-01\n", - " 7.36877501e-01 2.80537099e-01 -1.14187610e+00 3.33846539e-01\n", - " 5.12147248e-01 -1.92265600e-01 -9.45136070e-01 6.38769329e-01\n", - " -1.62952110e-01 1.63383838e-02 -2.03721905e+00 1.65726721e+00\n", - " 1.24682081e+00 1.13641596e+00 -1.21482491e+00 1.33682203e+00\n", - " 1.25738907e+00 1.26730061e+00 4.19435412e-01 5.83988726e-01\n", - " -8.63369167e-01 -2.17941441e-02 -9.48838413e-01 -1.69335604e-01\n", - " -7.81806767e-01 1.10969472e+00 -3.43311757e-01 7.97569871e-01\n", - " 7.03379154e-01 5.70429564e-02 -1.25305027e-01 7.07163095e-01\n", - " -5.44186294e-01 1.48075533e+00 -4.71153051e-01 5.30716360e-01\n", - " -3.40850413e-01 -6.89384997e-01 3.17025274e-01 -8.00609946e-01\n", - " -1.77810180e+00 7.57392883e-01 -8.67088586e-02 -1.57348573e+00\n", - " -1.63898838e+00 5.90374112e-01 -9.29109752e-02 -1.37256697e-01]\n", - " [ 3.68525267e+00 2.21822786e+00 -3.72133762e-01 -9.38002110e-01\n", - " -1.82598040e-01 -1.57692325e+00 -1.60580468e+00 7.22465396e-01\n", - " -1.39234257e+00 -1.21537292e+00 -2.97544384e+00 -1.43866599e+00\n", - " 2.61079460e-01 -7.03472435e-01 4.57634866e-01 -9.56056535e-01\n", - " 1.93661535e+00 -1.24941671e+00 -2.09334850e+00 1.10176265e+00\n", - " 2.63000101e-01 5.81728339e-01 7.04845428e-01 -3.02078456e-01\n", - " 1.77111065e+00 1.04489577e+00 1.51791990e+00 2.20976304e-02\n", - " -7.39717960e-01 -6.22645319e-01 -1.56369075e-01 1.04076886e+00\n", - " 1.25483823e+00 -3.32873642e-01 4.33149636e-01 -1.22933829e+00\n", - " -7.20005214e-01 5.78506827e-01 1.00301087e+00 8.72102618e-01\n", - " 2.50873059e-01 8.88309538e-01 -2.71842957e-01 -9.80645061e-01\n", - " 1.66997731e+00 -1.29030275e+00 1.46061301e-01 3.29237163e-01\n", - " -2.51785922e+00 -3.69922400e-01 4.70215261e-01 2.95814991e+00\n", - " 7.99638331e-01 -7.45547190e-02 1.54420769e+00 1.31914115e+00\n", - " -5.89660287e-01 -9.22154129e-01 -1.61234534e+00 -1.25994372e+00\n", - " 1.22705674e+00 7.55788326e-01 -9.14623439e-01 -1.08628643e+00\n", - " 1.54585168e-01 -6.50843084e-01 7.14638084e-02 2.60113406e+00\n", - " -3.26147604e+00 -7.06203341e-01 -1.01394987e+00 1.81137729e+00\n", - " -1.33361208e+00 6.01522505e-01 -2.62121201e+00 1.60872042e+00\n", - " 1.40612996e+00 -2.63518363e-01 2.49749565e+00 7.56558836e-01\n", - " 9.12506133e-03 4.42859173e-01 3.19830328e-01 -1.19115925e+00\n", - " -3.87303567e+00 4.34802413e-01 -1.38765180e+00 1.02910614e+00\n", - " -6.26348317e-01 8.68485034e-01 -1.63119793e+00 6.86433792e-01\n", - " 4.61489171e-01 -1.19542277e+00 3.53666723e-01 -7.63388455e-01\n", - " 1.00943439e-01 2.46762371e+00 3.33791876e+00 4.14971447e+00\n", - " -1.56401730e+00 5.20050287e-01 -1.22738254e+00 5.72006464e-01\n", - " 7.93922722e-01 1.36654353e+00 1.30395845e-01 -1.13550574e-01\n", - " -1.27993798e+00 1.18952966e+00 -2.28072000e+00 -1.86128521e+00\n", - " -1.75689876e+00 2.94097590e+00 -1.61757767e-01 6.65994763e-01\n", - " -2.79640174e+00 -7.66005337e-01 1.12821102e-01 -3.13482106e-01\n", - " 5.64184427e-01 -2.94495314e-01 1.34794497e+00 -1.06749706e-01\n", - " 2.95547748e+00 7.01298356e-01 2.26510024e+00 3.02275372e+00]])" + "cell_type": "markdown", + "metadata": { + "id": "bUGkN85V4_zY" + }, + "source": [ + "### Load the dataset\n", + "Now we can download the dataset. You will need to accept the rules of the `instacart-market-basket-analysis` competition, which you can do so [here](https://www.kaggle.com/competitions/instacart-market-basket-analysis/rules)." ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model.user_factors[1:3]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "38rssdYCBR4E" - }, - "source": [ - "## Let's save the data and create a empty LanceDB Table using a Pydantic model.\n", - "A Table is designed to store large numbers of columns and huge quantities of data! For those interested, a LanceDB is columnar-based, and uses Lance, an open data format to store data." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "3_ykVLT6t5HH" - }, - "outputs": [], - "source": [ - "db = lancedb.connect(\"data/lancedb\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ufHsF0o4t5HI" - }, - "outputs": [], - "source": [ - "class ProductModel(pydantic.BaseModel):\n", - " product_id: int\n", - " product_name: str\n", - " vector: vector(128)\n", - "schema = pydantic_to_schema(ProductModel)\n", - "table_name = 'product_recommender'\n", - "tbl = db.create_table(table_name, schema=schema, mode=\"overwrite\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0-2K-g4-t5HJ" - }, - "source": [ - "Let's now store our item factors into the table via the vector column of `product_entries`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "NOOPF9zOt5HJ" - }, - "outputs": [], - "source": [ - "# Transform items into factors\n", - "items_factors = model.item_factors\n", - "product_entries = products[['product_id', 'product_name']].drop_duplicates()\n", - "product_entries['product_id'] = product_entries.product_id.astype('int64')\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - "item_embeddings = items_factors[1:].to_numpy().tolist() if device == \"cuda\" else items_factors[1:].tolist()\n", - "product_entries['vector'] = item_embeddings\n", - "\n", - "tbl.add(product_entries)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "j3aU4z-tSbWE" - }, - "source": [ - "## Let's create an ANN index in order to speed up retrieval. This might take a while." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "H8HyvjCFSeaz" - }, - "outputs": [], - "source": [ - "tbl.create_index(num_partitions=256, num_sub_vectors=16)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ibNMrxyRt5HK" - }, - "source": [ - "This is a helper method for analysing recommendations later.\n", - "This method returns top N products that someone bought in the past (based on product quantity)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Uzgk5Od0t5HK" - }, - "outputs": [], - "source": [ - "def products_bought_by_user_in_the_past(user_id: int, top: int = 10):\n", - "\n", - " selected = data[data.user_id == user_id].sort_values(by=['total_orders'], ascending=False)\n", - "\n", - " selected['product_name'] = selected['product_id'].map(product_entries.set_index('product_id')['product_name'])\n", - " selected = selected[['product_id', 'product_name', 'total_orders']].reset_index(drop=True)\n", - " if selected.shape[0] < top:\n", - " return selected\n", - "\n", - " return selected[:top]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ULyVnHEXt5HK" - }, - "source": [ - "Let's retrieve our test users so we can query for recommendations." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Wwl7yFKTt5HK" - }, - "outputs": [], - "source": [ - "test_user_ids = [206210, 206211]\n", - "test_user_factors = model.user_factors[user_to_index[test_user_ids]]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wTh61ou3t5HL" - }, - "source": [ - "## Let's now query LanceDB to retrieve recommendations." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 868 }, - "id": "UiZg4Iset5HL", - "outputId": "797e1876-b943-48c9-eced-a595011ca33a" - }, - "outputs": [ { - "data": { - "text/html": [ - "\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
product_idproduct_namevector_distance
046149Zero Calorie Cola[0.0022252752, 0.006192103, -0.030976068, -0.0...42.362167
1196Soda[-0.026970126, -0.018141357, -0.058909897, -0....42.566917
240939Drinking Water[-0.0022220984, 0.0038040192, -0.014334261, -0...42.823524
341400Crunchy Oats 'n Honey Granola Bars[0.00020495932, 0.01123721, -0.017992454, 0.00...42.825871
437710Trail Mix[-0.00409454, 0.013049758, -0.015458386, 0.002...42.887043
546061Popcorn[-0.0024577982, 0.0062640505, 0.0007137757, -0...42.894321
6389280% Greek Strained Yogurt[0.0039202925, -0.0039743707, -0.012298337, 0....42.912727
731651Extra Fancy Unsalted Mixed Nuts[0.0035735483, -0.006829423, -0.009457169, 0.0...42.922546
822802Mineral Water[0.02247884, 0.003889028, -0.020661984, -0.031...42.935951
939657Milk Chocolate Almonds[0.00749532, 0.00577313, -0.016842585, -0.0015...42.946739
\n", - "
\n", - "
\n", - "\n", - "
\n", - " \n", - "\n", - " \n", - "\n", - " \n", - "
\n", - "\n", - "\n", - "
\n", - " \n", - "\n", - "\n", - "\n", - " \n", - "
\n", - "
\n", - "
\n" + "cell_type": "code", + "execution_count": 11, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "09gdQyBu4_zY", + "outputId": "bb92fb9e-df75-47a5-b50d-290ed0555ef4" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading instacart-market-basket-analysis.zip to /content\n", + " 92% 181M/196M [00:01<00:00, 81.3MB/s]\n", + "100% 196M/196M [00:01<00:00, 105MB/s] \n" + ] + } ], - "text/plain": [ - " product_id product_name \\\n", - "0 46149 Zero Calorie Cola \n", - "1 196 Soda \n", - "2 40939 Drinking Water \n", - "3 41400 Crunchy Oats 'n Honey Granola Bars \n", - "4 37710 Trail Mix \n", - "5 46061 Popcorn \n", - "6 38928 0% Greek Strained Yogurt \n", - "7 31651 Extra Fancy Unsalted Mixed Nuts \n", - "8 22802 Mineral Water \n", - "9 39657 Milk Chocolate Almonds \n", - "\n", - " vector _distance \n", - "0 [0.0022252752, 0.006192103, -0.030976068, -0.0... 42.362167 \n", - "1 [-0.026970126, -0.018141357, -0.058909897, -0.... 42.566917 \n", - "2 [-0.0022220984, 0.0038040192, -0.014334261, -0... 42.823524 \n", - "3 [0.00020495932, 0.01123721, -0.017992454, 0.00... 42.825871 \n", - "4 [-0.00409454, 0.013049758, -0.015458386, 0.002... 42.887043 \n", - "5 [-0.0024577982, 0.0062640505, 0.0007137757, -0... 42.894321 \n", - "6 [0.0039202925, -0.0039743707, -0.012298337, 0.... 42.912727 \n", - "7 [0.0035735483, -0.006829423, -0.009457169, 0.0... 42.922546 \n", - "8 [0.02247884, 0.003889028, -0.020661984, -0.031... 42.935951 \n", - "9 [0.00749532, 0.00577313, -0.016842585, -0.0015... 42.946739 " + "source": [ + "!kaggle competitions download -c instacart-market-basket-analysis" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "K4Q4cOX-4_zY" + }, + "source": [ + "We must now extract the zip files." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "id": "f3g296nL4_zZ" + }, + "outputs": [], + "source": [ + "files = [\n", + " 'instacart-market-basket-analysis.zip',\n", + " 'order_products__train.csv.zip',\n", + " 'order_products__prior.csv.zip',\n", + " 'products.csv.zip',\n", + " 'orders.csv.zip'\n", + "]\n", + "\n", + "for filename in files:\n", + " with zipfile.ZipFile(filename, 'r') as zip_ref:\n", + " zip_ref.extractall('./')" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "text/html": [ - "\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
product_idproduct_nametotal_orders
046149Zero Calorie Cola50
\n", - "
\n", - "
\n", - "\n", - "
\n", - " \n", - "\n", - " \n", - "\n", - " \n", - "
\n", - "\n", - "\n", - "
\n", - " \n", - "\n", - "\n", - "\n", - " \n", - "
\n", - "
\n", - "
\n" + "cell_type": "markdown", + "metadata": { + "id": "oLgkRIfq4_zZ" + }, + "source": [ + "Now we can move on to loading the dataset. We'll first read the csv files and create dataframes." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "id": "cBbbR7Rut5G_" + }, + "outputs": [], + "source": [ + "products = pd.read_csv('products.csv')\n", + "orders = pd.read_csv('orders.csv')\n", + "order_products = pd.concat([pd.read_csv('order_products__train.csv'), pd.read_csv('order_products__prior.csv')])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5FV_GGjst5HA" + }, + "source": [ + "Since there isn't a user rating attribute, we'll gather \"confidence\" data by looking at the frequency of each item purchased by a user, and store this in the `data` dataframe." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YNgjd2nnqNx7" + }, + "source": [ + "### Data Manipulation" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "id": "ZjRh7RYpt5HB" + }, + "outputs": [], + "source": [ + "customer_order_products = pd.merge(orders, order_products, how='inner',on='order_id')\n", + "\n", + "# create confidence table\n", + "data = customer_order_products.groupby(['user_id', 'product_id'])[['order_id']].count().reset_index()\n", + "data.columns=[\"user_id\", \"product_id\", \"total_orders\"]\n", + "data.product_id = data.product_id.astype('int64')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "77lvwm0St5HC" + }, + "source": [ + "Let's create a couple of test users to examine the recommendations later:\n", + "- 1st test user: buys 50 sodas: **Zero Calorie Cola**\n", + "- 2nd test user: buys organic produce: **Organic Whole Milk** and **Organic Blackberries**" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 206 + }, + "id": "A06EfAf-t5HC", + "outputId": "95a1f51f-ced1-437a-8b62-569bb915262c" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " user_id product_id total_orders\n", + "13863744 206209 48697 1\n", + "13863745 206209 48742 2\n", + "13863746 206210 46149 50\n", + "13863747 206211 27845 49\n", + "13863748 206211 26604 32" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idproduct_idtotal_orders
13863744206209486971
13863745206209487422
138637462062104614950
138637472062112784549
138637482062112660432
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "
\n", + "
\n" + ] + }, + "metadata": {}, + "execution_count": 15 + } ], - "text/plain": [ - " product_id product_name total_orders\n", - "0 46149 Zero Calorie Cola 50" + "source": [ + "data_new = pd.DataFrame([[data.user_id.max() + 1, 46149, 50],\n", + " [data.user_id.max() + 2, 27845, 49],\n", + " [data.user_id.max() + 2, 26604, 32]\n", + " ], columns=['user_id', 'product_id', 'total_orders'])\n", + "data = pd.concat([data, data_new]).reset_index(drop = True)\n", + "data.tail()" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "text/html": [ - "\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
product_idproduct_namevector_distance
026604Organic Blackberries[-0.006724186, 0.025339324, 0.026328607, -0.00...19.517328
127966Organic Raspberries[-0.008532436, 0.012350272, 0.00061730895, 0.0...19.688234
29076Blueberries[-0.02710966, 0.04093987, 0.051150266, -0.0465...19.795025
343352Raspberries[-0.00842552, 0.01970873, 0.043075223, -0.0089...19.805353
439275Organic Blueberries[-0.01799259, 0.0049827565, 0.0029076852, 0.02...19.961214
527845Organic Whole Milk[0.0005443055, -0.013880691, 0.008969757, -0.0...19.976284
621288Blackberries[-0.007392233, -0.01224536, 0.03930769, 0.0020...19.990463
711777Red Raspberries[-0.011827968, 0.02923465, 0.006089752, -0.033...20.038776
821137Organic Strawberries[-0.018719932, 0.004096488, -0.016034253, 0.02...20.056273
947209Organic Hass Avocado[0.016230278, 0.0025620027, -0.0056362785, 0.0...20.078579
\n", - "
\n", - "
\n", - "\n", - "
\n", - " \n", - "\n", - " \n", - "\n", - " \n", - "
\n", - "\n", - "\n", - "
\n", - " \n", - "\n", - "\n", - "\n", - " \n", - "
\n", - "
\n", - "
\n" + "cell_type": "markdown", + "metadata": { + "id": "xBC-8PFTt5HD" + }, + "source": [ + "In the next step, we will extract user and product unique ids, in order to create a `CSR (Compressed Sparse Row)` matrix. This will allow us to perform collaborative filtering.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "id": "v2_2R7zmt5HE" + }, + "outputs": [], + "source": [ + "# extract unique user and product ids\n", + "unique_users = list(np.sort(data.user_id.unique()))\n", + "unique_products = list(np.sort(products.product_id.unique()))\n", + "purchases = list(data.total_orders)\n", + "\n", + "# create zero-based index position <-> user/item ID mappings\n", + "index_to_user = pd.Series(unique_users)\n", + "\n", + "# create reverse mappings from user/item ID to index positions\n", + "user_to_index = pd.Series(data=index_to_user.index + 1, index=index_to_user.values)\n", + "\n", + "# create row and column for user and product ids\n", + "users_rows = data.user_id.astype(int)\n", + "products_cols = data.product_id.astype(int)\n", + "\n", + "# create CSR matrix\n", + "matrix = scipy.sparse.csr_matrix((purchases, (users_rows, products_cols)), shape=(len(unique_users) + 1, len(unique_products) + 1))\n", + "matrix.data = np.nan_to_num(matrix.data, copy=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "II6wOH96t5HF" + }, + "source": [ + "Let's now create a recommender model using the **implicit** library. The recommendation model is based off the algorithms described in the paper [Collaborative Filtering for Implicit Feedback Datasets](https://www.researchgate.net/publication/220765111_Collaborative_Filtering_for_Implicit_Feedback_Datasets) with performance optimizations described in [Applications of the Conjugate Gradient Method for Implicit Feedback Collaborative Filtering](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.379.6473&rep=rep1&type=pdf).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JDwIxGMnqNx8" + }, + "source": [ + "# Difference between colloborative and content filtering\n", + "\n", + "![picture](https://miro.medium.com/v2/resize:fit:1400/0*R8qw_CXxCc4600bQ.png)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 104, + "referenced_widgets": [ + "51febb09c3d54a1a9cf5dd896f3a24f6", + "91b083fde4f14c39bbafb6fd099d44bd", + "84fca55b676b4ef2add284492c8f4c3c", + "bb2c985a09564562b6f040e31d817f07", + "cc06b425a9364b6eb07ef77c4ff6fc48", + "e2e92925bbb442f8a77e2d55886bfbfa", + "bc7f6859319f455da1f552b66a6cf026", + "66396eb857864cc8af94d7e2ced3102c", + "38ddb81c475a472d8439dcf72261b727", + "c095ad1b03a34c4e8b2077e373c82a5b", + "692c702c31904e058c809ae772f1579a" + ] + }, + "id": "k0GW99kxt5HF", + "outputId": "548c2514-6194-43e4-dd24-6861f1808f5b" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.10/dist-packages/implicit/cpu/als.py:95: RuntimeWarning: OpenBLAS is configured to use 2 threads. It is highly recommended to disable its internal threadpool by setting the environment variable 'OPENBLAS_NUM_THREADS=1' or by calling 'threadpoolctl.threadpool_limits(1, \"blas\")'. Having OpenBLAS use a threadpool can lead to severe performance issues here.\n", + " check_blas_config()\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + " 0%| | 0/50 [00:00\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
product_idproduct_nametotal_orders
027845Organic Whole Milk49
126604Organic Blackberries32
\n", - "
\n", - "
\n", - "\n", - "
\n", - " \n", - "\n", - " \n", - "\n", - " \n", - "
\n", - "\n", - "\n", - "
\n", - " \n", - "\n", - "\n", - "\n", - " \n", - "
\n", - "
\n", - " \n" + "cell_type": "code", + "execution_count": 18, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 118, + "referenced_widgets": [ + "ae94cc355e0c4f8b8b73824ae2ef5632", + "b07491e5db2d42b499fce4d7caddfe6f", + "f79ee40b5f854a8b99c57b7c5156d3cd", + "3c2bc0b631644bb992905d55dfe0a7a8", + "41ebafe8393c451f83bfd2132a677a67", + "9b1e85ca94ef442fbd546647f72e6905", + "ae1b3cb276f44f5ebab3eaf8f7b85e67", + "2782769e3daa491385bcc8ae34f24f3b", + "5d41569b941445bea2497c89d3c8e6cb", + "5e7dd2740d174064ac2d1cbc75cb5909", + "a67972dc3f264b3699816257f1ad9ed7" + ] + }, + "id": "BbD8of_nt5HG", + "outputId": "0fd51c13-6aad-408c-8732-f634b900d88e" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + " 0%| | 0/192941 [00:00\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
product_idproduct_namevector_distance
046149Zero Calorie Cola[-0.014371638, -0.016776536, -0.026950998, -0....36.209068
1196Soda[-0.031917833, -0.050772455, 0.013827451, -0.0...36.464764
240939Drinking Water[-0.013426425, 0.0053616967, -0.01992105, -0.0...36.504112
322802Mineral Water[-0.0062663523, -0.00076926383, -0.013624842, ...36.615498
437710Trail Mix[-0.01988333, -0.014069387, -0.021995109, -0.0...36.650448
542500Orange & Lemon Flavor Variety Pack Sparkling F...[-0.009584657, -0.023491196, -0.033104196, -0....36.696648
611759Organic Simply Naked Pita Chips[-0.009341286, -0.014609524, -0.0064758006, -0...36.705814
741400Crunchy Oats 'n Honey Granola Bars[-0.013461881, -0.021371827, -0.02064814, -0.0...36.709579
846061Popcorn[0.0019679032, 0.00719048, -0.01262015, -0.005...36.714954
926348Mixed Fruit Fruit Snacks[-0.0017672281, 0.0020188452, 0.012172974, -0....36.716858
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "
\n", + " \n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + " product_id product_name total_orders\n", + "0 46149 Zero Calorie Cola 50" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
product_idproduct_nametotal_orders
046149Zero Calorie Cola50
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + " product_id product_name \\\n", + "0 26604 Organic Blackberries \n", + "1 43352 Raspberries \n", + "2 27845 Organic Whole Milk \n", + "3 21288 Blackberries \n", + "4 27966 Organic Raspberries \n", + "5 9076 Blueberries \n", + "6 11777 Red Raspberries \n", + "7 39275 Organic Blueberries \n", + "8 21137 Organic Strawberries \n", + "9 13176 Bag of Organic Bananas \n", + "\n", + " vector _distance \n", + "0 [0.045252558, 0.04258531, 0.011869884, -0.0111... 17.445852 \n", + "1 [0.059606433, 0.014409931, 0.008712215, -0.007... 17.617174 \n", + "2 [-0.03977351, 0.012210161, 0.024828656, 0.0155... 17.692816 \n", + "3 [0.030181486, 0.049021076, 0.003293778, -0.038... 17.696075 \n", + "4 [0.020116415, 0.045062356, 0.00675044, 0.01640... 17.872534 \n", + "5 [0.0482006, 0.06329333, -0.015093377, 0.000180... 17.879623 \n", + "6 [0.05492493, 0.008120705, 0.020613482, 0.00779... 17.931437 \n", + "7 [0.005109854, 0.032895964, -0.013481544, 0.010... 17.970798 \n", + "8 [0.0017651353, 0.033547334, -0.005775958, 0.02... 17.986570 \n", + "9 [0.004607136, 0.02749164, -0.006206838, 0.0187... 18.092993 " + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
product_idproduct_namevector_distance
026604Organic Blackberries[0.045252558, 0.04258531, 0.011869884, -0.0111...17.445852
143352Raspberries[0.059606433, 0.014409931, 0.008712215, -0.007...17.617174
227845Organic Whole Milk[-0.03977351, 0.012210161, 0.024828656, 0.0155...17.692816
321288Blackberries[0.030181486, 0.049021076, 0.003293778, -0.038...17.696075
427966Organic Raspberries[0.020116415, 0.045062356, 0.00675044, 0.01640...17.872534
59076Blueberries[0.0482006, 0.06329333, -0.015093377, 0.000180...17.879623
611777Red Raspberries[0.05492493, 0.008120705, 0.020613482, 0.00779...17.931437
739275Organic Blueberries[0.005109854, 0.032895964, -0.013481544, 0.010...17.970798
821137Organic Strawberries[0.0017651353, 0.033547334, -0.005775958, 0.02...17.986570
913176Bag of Organic Bananas[0.004607136, 0.02749164, -0.006206838, 0.0187...18.092993
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "
\n", + "
\n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + " product_id product_name total_orders\n", + "0 27845 Organic Whole Milk 49\n", + "1 26604 Organic Blackberries 32" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
product_idproduct_nametotal_orders
027845Organic Whole Milk49
126604Organic Blackberries32
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "
\n", + "
\n" + ] + }, + "metadata": {} + } + ], + "source": [ + "# Query by user factors\n", + "test_user_embeddings = test_user_factors.tolist()\n", + "for embedding, id in zip(test_user_embeddings, test_user_ids):\n", + " results = tbl.search(embedding).limit(10).to_pandas()\n", + " display(results)\n", + " display(products_bought_by_user_in_the_past(id, top=15))" + ] }, - "bffd8e57b04946ae9289af62a52b9fc8": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_95e65e6bceed40db9cea2a0adffed3fb", - "placeholder": "​", - "style": "IPY_MODEL_cbff0accbd0f492597900c93a340d59f", - "value": " 192847/192847 [00:02<00:00, 94872.39it/s]" - } + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "-kWR644v1ZJp" + }, + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] }, - "c9110b853367476a8743eea0c522a265": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" }, - "cbff0accbd0f492597900c93a340d59f": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.1" }, - "ccd7286cfb904ad8a5486738357a6d39": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } + "vscode": { + "interpreter": { + "hash": "5fe10bf018ef3e697f9035d60bf60847932a12bface18908407fd371fe880db9" + } }, - "d157f99a712c43cdb7a26c6470d0e615": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "51febb09c3d54a1a9cf5dd896f3a24f6": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_91b083fde4f14c39bbafb6fd099d44bd", + "IPY_MODEL_84fca55b676b4ef2add284492c8f4c3c", + "IPY_MODEL_bb2c985a09564562b6f040e31d817f07" + ], + "layout": "IPY_MODEL_cc06b425a9364b6eb07ef77c4ff6fc48" + } + }, + "91b083fde4f14c39bbafb6fd099d44bd": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e2e92925bbb442f8a77e2d55886bfbfa", + "placeholder": "​", + "style": "IPY_MODEL_bc7f6859319f455da1f552b66a6cf026", + "value": "100%" + } + }, + "84fca55b676b4ef2add284492c8f4c3c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_66396eb857864cc8af94d7e2ced3102c", + "max": 50, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_38ddb81c475a472d8439dcf72261b727", + "value": 50 + } + }, + "bb2c985a09564562b6f040e31d817f07": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c095ad1b03a34c4e8b2077e373c82a5b", + "placeholder": "​", + "style": "IPY_MODEL_692c702c31904e058c809ae772f1579a", + "value": " 50/50 [15:32<00:00, 18.28s/it]" + } + }, + "cc06b425a9364b6eb07ef77c4ff6fc48": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e2e92925bbb442f8a77e2d55886bfbfa": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "bc7f6859319f455da1f552b66a6cf026": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "66396eb857864cc8af94d7e2ced3102c": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "38ddb81c475a472d8439dcf72261b727": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "c095ad1b03a34c4e8b2077e373c82a5b": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "692c702c31904e058c809ae772f1579a": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "ae94cc355e0c4f8b8b73824ae2ef5632": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_b07491e5db2d42b499fce4d7caddfe6f", + "IPY_MODEL_f79ee40b5f854a8b99c57b7c5156d3cd", + "IPY_MODEL_3c2bc0b631644bb992905d55dfe0a7a8" + ], + "layout": "IPY_MODEL_41ebafe8393c451f83bfd2132a677a67" + } + }, + "b07491e5db2d42b499fce4d7caddfe6f": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_9b1e85ca94ef442fbd546647f72e6905", + "placeholder": "​", + "style": "IPY_MODEL_ae1b3cb276f44f5ebab3eaf8f7b85e67", + "value": "100%" + } + }, + "f79ee40b5f854a8b99c57b7c5156d3cd": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2782769e3daa491385bcc8ae34f24f3b", + "max": 192941, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_5d41569b941445bea2497c89d3c8e6cb", + "value": 192941 + } + }, + "3c2bc0b631644bb992905d55dfe0a7a8": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_5e7dd2740d174064ac2d1cbc75cb5909", + "placeholder": "​", + "style": "IPY_MODEL_a67972dc3f264b3699816257f1ad9ed7", + "value": " 192941/192941 [02:05<00:00, 1812.48it/s]" + } + }, + "41ebafe8393c451f83bfd2132a677a67": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9b1e85ca94ef442fbd546647f72e6905": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "ae1b3cb276f44f5ebab3eaf8f7b85e67": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "2782769e3daa491385bcc8ae34f24f3b": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5d41569b941445bea2497c89d3c8e6cb": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "5e7dd2740d174064ac2d1cbc75cb5909": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a67972dc3f264b3699816257f1ad9ed7": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + } + } } - } - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file