Skip to content

Commit 23db4e8

Browse files
Remove update argument from model save (#200)
* remove update argument from _base.py
* remove update argument from pytorch
* remove update argument from sklearn
* remove update argument from tensorflow
* update pytorch model example
* update sklearn model example
* update tensorflow model example
* upgrade isort to 5.11.5
* fix failing pytorch model test
1 parent 92f594d commit 23db4e8

9 files changed

+126
-130
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ repos:
44
hooks:
55
- id: black
66
- repo: https://github.com/pycqa/isort
7-
rev: 5.10.1
7+
rev: 5.11.5
88
hooks:
99
- id: isort
1010
args: ["--profile", "black"]

examples/models/pytorch_tiledb_models_example.ipynb

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
"outputs": [
5454
{
5555
"data": {
56-
"text/plain": "<torch._C.Generator at 0x1218c0b70>"
56+
"text/plain": "<torch._C.Generator at 0x1193eb430>"
5757
},
5858
"execution_count": 2,
5959
"metadata": {},
@@ -132,7 +132,20 @@
132132
"text": [
133133
"Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw\n",
134134
"\n",
135-
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
135+
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n"
136+
]
137+
},
138+
{
139+
"name": "stderr",
140+
"output_type": "stream",
141+
"text": [
142+
"15.9%"
143+
]
144+
},
145+
{
146+
"name": "stdout",
147+
"output_type": "stream",
148+
"text": [
136149
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
137150
]
138151
},
@@ -276,14 +289,6 @@
276289
}
277290
},
278291
"outputs": [
279-
{
280-
"name": "stderr",
281-
"output_type": "stream",
282-
"text": [
283-
"2022-12-07 17:00:23.979857: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
284-
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
285-
]
286-
},
287292
{
288293
"name": "stdout",
289294
"output_type": "stream",
@@ -388,13 +393,21 @@
388393
"name": "#%%\n"
389394
}
390395
},
391-
"outputs": [],
396+
"outputs": [
397+
{
398+
"name": "stderr",
399+
"output_type": "stream",
400+
"text": [
401+
"/Users/george/PycharmProjects/TileDB-ML/.venv/lib/python3.9/site-packages/tiledb/ctx.py:410: UserWarning: tiledb.default_ctx and scope_ctx will not function correctly due to bug in IPython contextvar support. You must supply a Ctx object to each function for custom configuration options. Please consider upgrading to ipykernel >= 6!Please see https://github.com/TileDB-Inc/TileDB-Py/issues/667 for more information.\n",
402+
" warnings.warn(\n"
403+
]
404+
}
405+
],
392406
"source": [
393407
"uri = os.path.join(data_home, 'pytorch-mnist-1')\n",
394408
"tiledb_model_1 = PyTorchTileDBModel(uri=uri, model=network, optimizer=optimizer)\n",
395409
"\n",
396-
"tiledb_model_1.save(update=False,\n",
397-
" meta={'epochs': epochs,\n",
410+
"tiledb_model_1.save(meta={'epochs': epochs,\n",
398411
" 'train_loss': train_losses},\n",
399412
" summary_writer=writer)"
400413
]
@@ -438,9 +451,10 @@
438451
")\n",
439452
"Key: TILEDB_ML_MODEL_PYTHON_VERSION, Value: 3.9.9\n",
440453
"Key: TILEDB_ML_MODEL_STAGE, Value: STAGING\n",
454+
"Key: TILEDB_ML_MODEL_VERSION, Value: \n",
441455
"Key: epochs, Value: 1\n",
442-
"Key: model_state_dict_size, Value: 90053\n",
443-
"Key: optimizer_state_dict_size, Value: 90064\n",
456+
"Key: model_size, Value: 90053\n",
457+
"Key: optimizer_size, Value: 90064\n",
444458
"Key: tensorboard_size, Value: 22674\n",
445459
"Key: train_loss, Value: (2.358812093734741, 2.285137891769409, 2.3066349029541016, 2.2708795070648193, 2.2367401123046875, 2.24334716796875, 2.1832549571990967, 2.1485116481781006, 2.1049115657806396, 2.0044069290161133, 1.8622523546218872, 1.8843708038330078, 1.7973158359527588, 1.6879109144210815, 1.508046269416809, 1.764279842376709, 1.4700727462768555, 1.3514467477798462, 1.2905819416046143, 1.0177571773529053, 1.042162299156189, 1.0987662076950073, 1.2285516262054443, 1.1495932340621948, 0.8452475070953369, 0.9741130471229553, 0.8569056987762451, 0.9234588146209717, 1.0218565464019775, 0.8069543242454529, 0.8789511919021606, 0.8185049891471863, 0.8055434226989746, 0.8231522440910339, 0.8543609976768494, 0.7746452689170837, 0.718348503112793, 0.5433375239372253, 0.7593768239021301, 0.65492182970047, 0.6999298930168152, 0.8053513765335083, 0.790733814239502, 0.7599329948425293, 0.540409505367279, 0.6412327885627747, 0.6593738198280334)\n"
446460
]
@@ -492,10 +506,11 @@
492506
")\n",
493507
"Key: TILEDB_ML_MODEL_PYTHON_VERSION, Value: 3.9.9\n",
494508
"Key: TILEDB_ML_MODEL_STAGE, Value: STAGING\n",
509+
"Key: TILEDB_ML_MODEL_VERSION, Value: \n",
495510
"Key: epochs, Value: 1\n",
496-
"Key: model_state_dict_size, Value: 90053\n",
511+
"Key: model_size, Value: 90053\n",
497512
"Key: new_meta, Value: [\"Any kind of info\"]\n",
498-
"Key: optimizer_state_dict_size, Value: 90064\n",
513+
"Key: optimizer_size, Value: 90064\n",
499514
"Key: tensorboard_size, Value: 22674\n",
500515
"Key: train_loss, Value: (2.358812093734741, 2.285137891769409, 2.3066349029541016, 2.2708795070648193, 2.2367401123046875, 2.24334716796875, 2.1832549571990967, 2.1485116481781006, 2.1049115657806396, 2.0044069290161133, 1.8622523546218872, 1.8843708038330078, 1.7973158359527588, 1.6879109144210815, 1.508046269416809, 1.764279842376709, 1.4700727462768555, 1.3514467477798462, 1.2905819416046143, 1.0177571773529053, 1.042162299156189, 1.0987662076950073, 1.2285516262054443, 1.1495932340621948, 0.8452475070953369, 0.9741130471229553, 0.8569056987762451, 0.9234588146209717, 1.0218565464019775, 0.8069543242454529, 0.8789511919021606, 0.8185049891471863, 0.8055434226989746, 0.8231522440910339, 0.8543609976768494, 0.7746452689170837, 0.718348503112793, 0.5433375239372253, 0.7593768239021301, 0.65492182970047, 0.6999298930168152, 0.8053513765335083, 0.790733814239502, 0.7599329948425293, 0.540409505367279, 0.6412327885627747, 0.6593738198280334)\n"
501516
]
@@ -669,14 +684,14 @@
669684
"number of fragments: 2\n",
670685
"\n",
671686
"===== FRAGMENT NUMBER 0 =====\n",
672-
"fragment uri: file:///Users/george/PycharmProjects/TileDB-ML/examples/data/pytorch-mnist-1/__fragments/__1670425246498_1670425246498_5ca20757611a43009e22606647ee9b22_16\n",
673-
"timestamp range: (1670425246498, 1670425246498)\n",
687+
"fragment uri: file:///Users/george/PycharmProjects/TileDB-ML/examples/data/pytorch-mnist-1/__fragments/__1675169603990_1675169603990_aae704a1499649fe8ad3fd0e61d8f9b9_16\n",
688+
"timestamp range: (1675169603990, 1675169603990)\n",
674689
"number of unconsolidated metadata: 2\n",
675690
"version: 16\n",
676691
"\n",
677692
"===== FRAGMENT NUMBER 1 =====\n",
678-
"fragment uri: file:///Users/george/PycharmProjects/TileDB-ML/examples/data/pytorch-mnist-1/__fragments/__1670425278236_1670425278236_8e60255a3abe4173b21458369995c20c_16\n",
679-
"timestamp range: (1670425278236, 1670425278236)\n",
693+
"fragment uri: file:///Users/george/PycharmProjects/TileDB-ML/examples/data/pytorch-mnist-1/__fragments/__1675169635431_1675169635431_e064ef7a982e45a1be7bb678d1949b97_16\n",
694+
"timestamp range: (1675169635431, 1675169635431)\n",
680695
"number of unconsolidated metadata: 2\n",
681696
"version: 16\n"
682697
]
@@ -692,8 +707,7 @@
692707
"\n",
693708
"# and update\n",
694709
"tiledb_model_1 = PyTorchTileDBModel(uri=uri, model=network, optimizer=optimizer)\n",
695-
"tiledb_model_1.save(update=True, \n",
696-
" meta={'epochs': epochs,\n",
710+
"tiledb_model_1.save(meta={'epochs': epochs,\n",
697711
" 'train_loss': train_losses})\n",
698712
"\n",
699713
"# Check array directory\n",
@@ -845,8 +859,7 @@
845859
"uri2 = os.path.join(data_home, 'pytorch-mnist-2')\n",
846860
"tiledb_model_2 = PyTorchTileDBModel(uri=uri2, model=network, optimizer=optimizer)\n",
847861
"\n",
848-
"tiledb_model_2.save(update=False, \n",
849-
" meta={'epochs': epochs,\n",
862+
"tiledb_model_2.save(meta={'epochs': epochs,\n",
850863
" 'train_loss': train_losses})"
851864
]
852865
},

examples/models/sklearn_tiledb_models_example.ipynb

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@
133133
"text": [
134134
"Model fit...\n",
135135
"Model score...\n",
136-
"Sparsity with L1 penalty: 81.12%\n",
137-
"Test score with L1 penalty: 0.8317\n"
136+
"Sparsity with L1 penalty: 80.33%\n",
137+
"Test score with L1 penalty: 0.8401\n"
138138
]
139139
}
140140
],
@@ -178,7 +178,16 @@
178178
"name": "#%%\n"
179179
}
180180
},
181-
"outputs": [],
181+
"outputs": [
182+
{
183+
"name": "stderr",
184+
"output_type": "stream",
185+
"text": [
186+
"/Users/george/PycharmProjects/TileDB-ML/.venv/lib/python3.9/site-packages/tiledb/ctx.py:410: UserWarning: tiledb.default_ctx and scope_ctx will not function correctly due to bug in IPython contextvar support. You must supply a Ctx object to each function for custom configuration options. Please consider upgrading to ipykernel >= 6!Please see https://github.com/TileDB-Inc/TileDB-Py/issues/667 for more information.\n",
187+
" warnings.warn(\n"
188+
]
189+
}
190+
],
182191
"source": [
183192
"uri = os.path.join(data_home, 'sklearn-mnist-1')\n",
184193
"tiledb_model_1 = SklearnTileDBModel(uri=uri, model=clf)\n",
@@ -215,14 +224,15 @@
215224
" '../data/sklearn-mnist-1/__schema',\n",
216225
" '../data/sklearn-mnist-1/__fragments']\n",
217226
"\n",
218-
"Key: Sparsity_with_L1_penalty, Value: 81.12244897959184\n",
227+
"Key: Sparsity_with_L1_penalty, Value: 80.33163265306122\n",
219228
"Key: TILEDB_ML_MODEL_ML_FRAMEWORK, Value: SKLEARN\n",
220229
"Key: TILEDB_ML_MODEL_ML_FRAMEWORK_VERSION, Value: 1.1.3\n",
221230
"Key: TILEDB_ML_MODEL_PREVIEW, Value: LogisticRegression(C=0.01, penalty='l1', solver='saga', tol=0.1)\n",
222231
"Key: TILEDB_ML_MODEL_PYTHON_VERSION, Value: 3.9.9\n",
223232
"Key: TILEDB_ML_MODEL_STAGE, Value: STAGING\n",
233+
"Key: TILEDB_ML_MODEL_VERSION, Value: \n",
224234
"Key: model_size, Value: 63531\n",
225-
"Key: score, Value: 0.8317\n"
235+
"Key: score, Value: 0.8401\n"
226236
]
227237
}
228238
],
@@ -262,15 +272,16 @@
262272
"name": "stdout",
263273
"output_type": "stream",
264274
"text": [
265-
"Key: Sparsity_with_L1_penalty, Value: 81.12244897959184\n",
275+
"Key: Sparsity_with_L1_penalty, Value: 80.33163265306122\n",
266276
"Key: TILEDB_ML_MODEL_ML_FRAMEWORK, Value: SKLEARN\n",
267277
"Key: TILEDB_ML_MODEL_ML_FRAMEWORK_VERSION, Value: 1.1.3\n",
268278
"Key: TILEDB_ML_MODEL_PREVIEW, Value: LogisticRegression(C=0.01, penalty='l1', solver='saga', tol=0.1)\n",
269279
"Key: TILEDB_ML_MODEL_PYTHON_VERSION, Value: 3.9.9\n",
270280
"Key: TILEDB_ML_MODEL_STAGE, Value: STAGING\n",
281+
"Key: TILEDB_ML_MODEL_VERSION, Value: \n",
271282
"Key: model_size, Value: 63531\n",
272283
"Key: new_meta, Value: [\"Any kind of info\"]\n",
273-
"Key: score, Value: 0.8317\n"
284+
"Key: score, Value: 0.8401\n"
274285
]
275286
}
276287
],
@@ -309,12 +320,12 @@
309320
"output_type": "stream",
310321
"text": [
311322
"Model score...\n",
312-
"Sparsity with L1 penalty: 81.12%\n",
313-
"Test score with L1 penalty: 0.8317\n",
323+
"Sparsity with L1 penalty: 80.33%\n",
324+
"Test score with L1 penalty: 0.8401\n",
314325
"Model fit...\n",
315326
"Model score...\n",
316-
"Sparsity with L1 penalty: 45.46%\n",
317-
"Test score with L1 penalty: 0.7301\n",
327+
"Sparsity with L1 penalty: 44.13%\n",
328+
"Test score with L1 penalty: 0.7286\n",
318329
"\n",
319330
"['../data/sklearn-mnist-1/__meta',\n",
320331
" '../data/sklearn-mnist-1/__fragment_meta',\n",
@@ -327,12 +338,12 @@
327338
"number of fragments: 2\n",
328339
"\n",
329340
"===== FRAGMENT NUMBER 0 =====\n",
330-
"timestamp range: (1670425614775, 1670425614775)\n",
341+
"timestamp range: (1675169831049, 1675169831049)\n",
331342
"number of unconsolidated metadata: 2\n",
332343
"version: 16\n",
333344
"\n",
334345
"===== FRAGMENT NUMBER 1 =====\n",
335-
"timestamp range: (1670425618384, 1670425618384)\n",
346+
"timestamp range: (1675169834215, 1675169834215)\n",
336347
"number of unconsolidated metadata: 2\n",
337348
"version: 16\n"
338349
]
@@ -363,8 +374,7 @@
363374
"\n",
364375
"\n",
365376
"tiledb_model_1 = SklearnTileDBModel(uri=uri, model=loaded_clf)\n",
366-
"tiledb_model_1.save(update=True,\n",
367-
" meta={'Sparsity_with_L1_penalty': sparsity,\n",
377+
"tiledb_model_1.save(meta={'Sparsity_with_L1_penalty': sparsity,\n",
368378
" 'score': score})\n",
369379
"\n",
370380
"# Check array directory\n",
@@ -418,7 +428,7 @@
418428
"output_type": "stream",
419429
"text": [
420430
"Fit...\n",
421-
"Test score: 0.7654\n"
431+
"Test score: 0.7741\n"
422432
]
423433
}
424434
],
@@ -526,4 +536,4 @@
526536
},
527537
"nbformat": 4,
528538
"nbformat_minor": 4
529-
}
539+
}

0 commit comments

Comments
 (0)