diff --git a/docs/physics/montecarlo/initialization.ipynb b/docs/physics/montecarlo/initialization.ipynb index 5508f7c23d3..9086439701c 100644 --- a/docs/physics/montecarlo/initialization.ipynb +++ b/docs/physics/montecarlo/initialization.ipynb @@ -72,7 +72,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 10, "id": "426325e5", "metadata": {}, "outputs": [], @@ -85,7 +85,7 @@ "from astropy import units as u\n", "from tardis import constants as const\n", "import matplotlib.pyplot as plt\n", - "from numba import jit\n", + "from numba import njit\n", "from time import time" ] }, @@ -100,7 +100,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 3, "id": "bc34bf33", "metadata": {}, "outputs": [], @@ -128,7 +128,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 4, "id": "3fb3ca8c", "metadata": {}, "outputs": [ @@ -169,7 +169,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 5, "id": "925e9e1b", "metadata": {}, "outputs": [ @@ -214,7 +214,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 6, "id": "fed35f47", "metadata": {}, "outputs": [ @@ -251,19 +251,10 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 7, "id": "916a5e22", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/tmp/ipykernel_33774/3743613358.py:16: NumbaDeprecationWarning: \u001b[1mThe 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.\u001b[0m\n", - " @jit\n" - ] - } - ], + "outputs": [], "source": [ "h = const.h.cgs\n", "c2 = const.c.cgs**2\n", @@ -278,10 +269,19 @@ " * h\n", " * nu**3\n", " / (c2 * (np.exp(h * nu / (kB * temperature_inner)) - 1))\n", - " )\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "3f14b73d-a949-4852-b1fc-3301706bf51d", + "metadata": {}, + "outputs": [], + "source": [ "# JIT-compiled function\n", - "@jit\n", - "def jit_planck_function(nu):\n", + "@njit\n", + "def njit_planck_function(nu):\n", " return (\n", " 8\n", " * np.pi\n", @@ -303,7 +303,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 15, "id": "d19d0049-e895-409b-94e9-2ef383e86478", "metadata": {}, "outputs": [ @@ -336,14 +336,6 @@ }, "metadata": {}, "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Execution times for original function: [0.0003864765167236328, 0.0003466606140136719, 0.0004162788391113281]\n", - "Execution times for JIT-compiled function: [0.20883679389953613, 3.933906555175781e-05, 3.695487976074219e-05]\n" - ] } ], "source": [ @@ -369,7 +361,7 @@ " \n", " # Measure execution time for JIT-compiled function\n", " start_time = time()\n", - " luminosity_jitted = jit_planck_function(nus_planck * u.Hz)\n", + " luminosity_jitted = njit_planck_function(nus_planck * u.Hz)\n", " end_time = time()\n", " execution_times_jitted.append(end_time - start_time)\n", " \n", @@ -387,9 +379,79 @@ " plt.legend([\"Original Planck Function\", \"Histogram\"])\n", " plt.show()\n", "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "49927632-053b-455a-b9a2-3d2f4c0f6cb0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mean and standard deviation of execution times for original function:\n", + "Bin size: 100, Mean time: 0.0006856918334960938, Standard deviation: 0.0003326150415747104\n", + "Bin size: 200, Mean time: 0.0007090330123901367, Standard deviation: 0.00037616331345380154\n", + "Bin size: 500, Mean time: 0.0006082057952880859, Standard deviation: 0.00011986670416905664\n", + "Mean and standard deviation of execution times for njit-compiled function:\n", + "Bin size: 100, Mean time: 6.556510925292969e-05, Standard deviation: 2.453465036165e-05\n", + "Bin size: 200, Mean time: 6.003379821777344e-05, Standard deviation: 9.53423944032505e-06\n", + "Bin size: 500, Mean time: 6.287097930908203e-05, Standard deviation: 1.4083274673289586e-05\n" + ] + } + ], + "source": [ + "# Define bins and frequency range\n", + "bins = [100, 200, 500]\n", + "num_iterations = 10\n", + "execution_times_original = []\n", + "execution_times_njit = []\n", + "\n", + "for num_bins in bins:\n", + " original_times = []\n", + " njit_times = []\n", + " for _ in range(num_iterations):\n", + " # Set up frequency range\n", + " nus_planck = np.linspace(\n", + " min(packet_collection.initial_nus),\n", + " max(packet_collection.initial_nus),\n", + " num_bins\n", + " ).value\n", + " bin_width = nus_planck[1] - nus_planck[0]\n", + "\n", + " # Measure execution time for original function\n", + " start_time = time()\n", + " luminosity_original = planck_function(nus_planck * u.Hz)\n", + " end_time = time()\n", + " original_times.append(end_time - start_time)\n", + "\n", + " # Measure execution time for njit-compiled function\n", + " start_time = time()\n", + " luminosity_njit = njit_planck_function(nus_planck * u.Hz)\n", + " end_time = time()\n", + " njit_times.append(end_time - start_time)\n", + "\n", + " # Calculate mean and standard deviation of execution times\n", + " mean_original_time = np.mean(original_times)\n", + " std_original_time = np.std(original_times)\n", + " mean_njit_time = np.mean(njit_times)\n", + " std_njit_time = np.std(njit_times)\n", + "\n", + " # Store mean execution times\n", + " execution_times_original.append((mean_original_time, std_original_time))\n", + " execution_times_njit.append((mean_njit_time, std_njit_time))\n", + "\n", "# Print execution times\n", - "print(\"Execution times for original function:\", execution_times_original)\n", - "print(\"Execution times for JIT-compiled function:\", execution_times_jitted)\n" + "print(\"Mean and standard deviation of execution times for original function:\")\n", + "for i, (mean_time, std_time) in enumerate(execution_times_original):\n", + " print(f\"Bin size: {bins[i]}, Mean time: {mean_time}, Standard deviation: {std_time}\")\n", + "\n", + "print(\"Mean and standard deviation of execution times for njit-compiled function:\")\n", + "for i, (mean_time, std_time) in enumerate(execution_times_njit):\n", + " print(f\"Bin size: {bins[i]}, Mean time: {mean_time}, Standard deviation: {std_time}\")" ] }, {