Skip to content

Commit 0c3151a

Browse files
committed
Add automatic library installation in after optimization test
1 parent a3f8123 commit 0c3151a

File tree

3 files changed

+47
-10
lines changed

3 files changed

+47
-10
lines changed

Diff for: README.md

+4-3
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@ This code samples show unoptimized/optimized tensorflow workflow.
99

1010
## Requirements
1111

12+
### Hardware Requirements
1213
* x86-64 (AMD64) CPU
1314
* RAM >= 8GiB
1415
* NVIDIA [Computer Capability](https://developer.nvidia.com/cuda-gpus) 7.0+ GPUs
1516
* GPU memory > 12GiB for default batch size
1617

17-
### Test Environment
18+
#### Test Environment
1819
* CPU : Intel(R) Xeon(R) Gold 5218R
1920
* GPU : 2x A100 80GB PCI-E
2021
* RAM : 255GiB
@@ -30,8 +31,8 @@ This code samples show unoptimized/optimized tensorflow workflow.
3031
1. Clone this repo with submodule
3132
```
3233
git clone --recursive https://github.com/ReturnToFirst/FastTFWorkflow.git
33-
```
34-
2. Compare performance between unoptimized/optimized workflow
34+
35+
3. Compare performance between unoptimized/optimized workflow
3536
3637
### For advanced users
3738

Diff for: after_optimization.ipynb

+42-6
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,49 @@
2828
"metadata": {},
2929
"outputs": [],
3030
"source": [
31-
"import tensorflow as tf\n",
32-
"from nvidia.dali import pipeline_def, fn, types\n",
33-
"import nvidia.dali.plugin.tf as dali_tf\n",
34-
"\n",
31+
"# This package will use for install uninstalled package\n",
32+
"import subprocess\n",
33+
"\n",
34+
"# Import TensorFlow and install if not installed\n",
35+
"try:\n",
36+
" import tensorflow as tf\n",
37+
"except(ModuleNotFoundError):\n",
38+
" subprocess.run([\"pip3\", \"install\", \"tensorflow\"])\n",
39+
" import tensorflow as tf\n",
40+
"\n",
41+
"# Import NVIDIA DALI and install if not installed\n",
42+
"try:\n",
43+
" from nvidia.dali import pipeline_def, fn, types\n",
44+
"except(ModuleNotFoundError):\n",
45+
" import re\n",
46+
"\n",
47+
" nv_smi = subprocess.run([\"nvidia-smi\"], stdout=subprocess.PIPE).stdout\n",
48+
" cuda_version = float(re.search(r'CUDA\\sVersion:\\s+(\\d+\\.\\d+)', nv_smi.decode()).group(1))\n",
49+
" if 12 <= cuda_version:\n",
50+
" subprocess.run(['pip' 'install' '--extra-index-url' 'https://developer.download.nvidia.com/compute/redist' '--upgrade nvidia-dali-cuda120'])\n",
51+
" else:\n",
52+
" subprocess.run(['pip' 'install' '--extra-index-url' 'https://developer.download.nvidia.com/compute/redist' '--upgrade nvidia-dali-cuda110'])\n",
53+
"\n",
54+
" from nvidia.dali import pipeline_def, fn, types\n",
55+
"\n",
56+
"# Import NVIDIA DALI TensorFlow plugin and install if not installed\n",
57+
"try:\n",
58+
" import nvidia.dali.plugin.tf as dali_tf\n",
59+
" \n",
60+
"except(ModuleNotFoundError):\n",
61+
" if 12 <= cuda_version:\n",
62+
" subprocess.run(['pip' 'install' '--extra-index-url' 'https://developer.download.nvidia.com/compute/redist' '--upgrade nvidia-dali-tf-plugin-cuda120'])\n",
63+
" else:\n",
64+
" subprocess.run(['pip' 'install' '--extra-index-url' 'https://developer.download.nvidia.com/compute/redist' '--upgrade nvidia-dali-tf-plugin-cuda110'])\n",
65+
" import nvidia.dali.plugin.tf as dali_tf\n",
66+
"\n",
67+
"\n",
68+
"# Import Python internal modules\n",
3569
"import os\n",
3670
"import glob\n",
37-
"import math\n"
71+
"import math\n",
72+
"import subprocess\n",
73+
"import re"
3874
]
3975
},
4076
{
@@ -296,7 +332,7 @@
296332
"name": "python",
297333
"nbconvert_exporter": "python",
298334
"pygments_lexer": "ipython3",
299-
"version": "3.10.11"
335+
"version": "3.8.16"
300336
},
301337
"orig_nbformat": 4
302338
},

Diff for: after_optimization_multi.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"cell_type": "markdown",
55
"metadata": {},
66
"source": [
7-
"# Example : Optimized Pytorch workflow\n",
7+
"# Example : Optimized TensorFlow workflow\n",
88
"\n",
99
"## Summary\n",
1010
"This example is optimized pytorch training workflow.\n",

0 commit comments

Comments
 (0)