Skip to content

Commit

Permalink
systematic execution timing in the initialization.ipynb
Browse files Browse the repository at this point in the history
  • Loading branch information
leilayesufu committed Mar 7, 2024
1 parent a2c1ac7 commit 941e376
Showing 1 changed file with 94 additions and 32 deletions.
126 changes: 94 additions & 32 deletions docs/physics/montecarlo/initialization.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 10,
"id": "426325e5",
"metadata": {},
"outputs": [],
Expand All @@ -85,7 +85,7 @@
"from astropy import units as u\n",
"from tardis import constants as const\n",
"import matplotlib.pyplot as plt\n",
"from numba import jit\n",
"from numba import njit\n",
"from time import time"
]
},
Expand All @@ -100,7 +100,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 3,
"id": "bc34bf33",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -128,7 +128,7 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 4,
"id": "3fb3ca8c",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -169,7 +169,7 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 5,
"id": "925e9e1b",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -214,7 +214,7 @@
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": 6,
"id": "fed35f47",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -251,19 +251,10 @@
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 7,
"id": "916a5e22",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_33774/3743613358.py:16: NumbaDeprecationWarning: \u001b[1mThe 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.\u001b[0m\n",
" @jit\n"
]
}
],
"outputs": [],
"source": [
"h = const.h.cgs\n",
"c2 = const.c.cgs**2\n",
Expand All @@ -278,10 +269,19 @@
" * h\n",
" * nu**3\n",
" / (c2 * (np.exp(h * nu / (kB * temperature_inner)) - 1))\n",
" )\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "3f14b73d-a949-4852-b1fc-3301706bf51d",
"metadata": {},
"outputs": [],
"source": [
"# JIT-compiled function\n",
"@jit\n",
"def jit_planck_function(nu):\n",
"@njit\n",
"def njit_planck_function(nu):\n",
" return (\n",
" 8\n",
" * np.pi\n",
Expand All @@ -303,7 +303,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 15,
"id": "d19d0049-e895-409b-94e9-2ef383e86478",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -336,14 +336,6 @@
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Execution times for original function: [0.0003864765167236328, 0.0003466606140136719, 0.0004162788391113281]\n",
"Execution times for JIT-compiled function: [0.20883679389953613, 3.933906555175781e-05, 3.695487976074219e-05]\n"
]
}
],
"source": [
Expand All @@ -369,7 +361,7 @@
" \n",
" # Measure execution time for JIT-compiled function\n",
" start_time = time()\n",
" luminosity_jitted = jit_planck_function(nus_planck * u.Hz)\n",
" luminosity_jitted = njit_planck_function(nus_planck * u.Hz)\n",
" end_time = time()\n",
" execution_times_jitted.append(end_time - start_time)\n",
" \n",
Expand All @@ -387,9 +379,79 @@
" plt.legend([\"Original Planck Function\", \"Histogram\"])\n",
" plt.show()\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "49927632-053b-455a-b9a2-3d2f4c0f6cb0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean and standard deviation of execution times for original function:\n",
"Bin size: 100, Mean time: 0.0006856918334960938, Standard deviation: 0.0003326150415747104\n",
"Bin size: 200, Mean time: 0.0007090330123901367, Standard deviation: 0.00037616331345380154\n",
"Bin size: 500, Mean time: 0.0006082057952880859, Standard deviation: 0.00011986670416905664\n",
"Mean and standard deviation of execution times for njit-compiled function:\n",
"Bin size: 100, Mean time: 6.556510925292969e-05, Standard deviation: 2.453465036165e-05\n",
"Bin size: 200, Mean time: 6.003379821777344e-05, Standard deviation: 9.53423944032505e-06\n",
"Bin size: 500, Mean time: 6.287097930908203e-05, Standard deviation: 1.4083274673289586e-05\n"
]
}
],
"source": [
"# Define bins and frequency range\n",
"bins = [100, 200, 500]\n",
"num_iterations = 10\n",
"execution_times_original = []\n",
"execution_times_njit = []\n",
"\n",
"for num_bins in bins:\n",
" original_times = []\n",
" njit_times = []\n",
" for _ in range(num_iterations):\n",
" # Set up frequency range\n",
" nus_planck = np.linspace(\n",
" min(packet_collection.initial_nus),\n",
" max(packet_collection.initial_nus),\n",
" num_bins\n",
" ).value\n",
" bin_width = nus_planck[1] - nus_planck[0]\n",
"\n",
" # Measure execution time for original function\n",
" start_time = time()\n",
" luminosity_original = planck_function(nus_planck * u.Hz)\n",
" end_time = time()\n",
" original_times.append(end_time - start_time)\n",
"\n",
" # Measure execution time for njit-compiled function\n",
" start_time = time()\n",
" luminosity_njit = njit_planck_function(nus_planck * u.Hz)\n",
" end_time = time()\n",
" njit_times.append(end_time - start_time)\n",
"\n",
" # Calculate mean and standard deviation of execution times\n",
" mean_original_time = np.mean(original_times)\n",
" std_original_time = np.std(original_times)\n",
" mean_njit_time = np.mean(njit_times)\n",
" std_njit_time = np.std(njit_times)\n",
"\n",
" # Store mean execution times\n",
" execution_times_original.append((mean_original_time, std_original_time))\n",
" execution_times_njit.append((mean_njit_time, std_njit_time))\n",
"\n",
"# Print execution times\n",
"print(\"Execution times for original function:\", execution_times_original)\n",
"print(\"Execution times for JIT-compiled function:\", execution_times_jitted)\n"
"print(\"Mean and standard deviation of execution times for original function:\")\n",
"for i, (mean_time, std_time) in enumerate(execution_times_original):\n",
" print(f\"Bin size: {bins[i]}, Mean time: {mean_time}, Standard deviation: {std_time}\")\n",
"\n",
"print(\"Mean and standard deviation of execution times for njit-compiled function:\")\n",
"for i, (mean_time, std_time) in enumerate(execution_times_njit):\n",
" print(f\"Bin size: {bins[i]}, Mean time: {mean_time}, Standard deviation: {std_time}\")"
]
},
{
Expand Down

0 comments on commit 941e376

Please sign in to comment.