Skip to content

Commit ab9eb2d

Browse files
authored
[Enhancement] Support Subclassed Model API and custom objects (#74)
* Saving custom models functionality
1 parent 145df95 commit ab9eb2d

File tree

3 files changed

+453
-192
lines changed

3 files changed

+453
-192
lines changed

examples/models/tensorflow_keras_tiledb_models_example.ipynb

+131-4
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,6 @@
526526
"execution_count": null,
527527
"metadata": {
528528
"pycharm": {
529-
"is_executing": true,
530529
"name": "#%%\n"
531530
}
532531
},
@@ -535,12 +534,140 @@
535534
"tiledb.ls('MNIST_Group', lambda obj_path, obj_type: print(obj_path, obj_type))"
536535
]
537536
},
537+
{
538+
"cell_type": "markdown",
539+
"source": [
540+
"## Model Subclassing\n",
541+
"\n",
542+
"Apart from being able to store models, which have been created with Symbolic APIs\n",
543+
"(Sequential, Functional) someone can store models that are being designed based on\n",
544+
"Imperative API (aka. Model Subclassing).\n",
545+
"\n",
546+
"Let's first design a simple model:"
547+
],
548+
"metadata": {
549+
"collapsed": false,
550+
"pycharm": {
551+
"name": "#%% md\n"
552+
}
553+
}
554+
},
555+
{
556+
"cell_type": "code",
557+
"execution_count": null,
558+
"outputs": [],
559+
"source": [
560+
"from tensorflow import keras\n",
561+
"\n",
562+
"class CustomModel(keras.Model):\n",
563+
" def __init__(self, hidden_units):\n",
564+
" super(CustomModel, self).__init__()\n",
565+
" self.hidden_units = hidden_units\n",
566+
" self.dense_layers = [keras.layers.Dense(u) for u in hidden_units]\n",
567+
"\n",
568+
" def call(self, inputs):\n",
569+
" x = inputs\n",
570+
" for layer in self.dense_layers:\n",
571+
" x = layer(x)\n",
572+
" return x\n",
573+
"\n",
574+
" def get_config(self):\n",
575+
" return {\"hidden_units\": self.hidden_units}\n",
576+
"\n",
577+
" @classmethod\n",
578+
" def from_config(cls, config):\n",
579+
" return cls(**config)"
580+
],
581+
"metadata": {
582+
"collapsed": false,
583+
"pycharm": {
584+
"name": "#%%\n"
585+
}
586+
}
587+
},
588+
{
589+
"cell_type": "markdown",
590+
"source": [
591+
"Then we can create a trivial input dataset for testing the model. Remember that\n",
592+
"for custom models to be initialised they need to be called on data."
593+
],
594+
"metadata": {
595+
"collapsed": false,
596+
"pycharm": {
597+
"name": "#%% md\n"
598+
}
599+
}
600+
},
538601
{
539602
"cell_type": "code",
540603
"execution_count": null,
541-
"metadata": {},
542604
"outputs": [],
543-
"source": []
605+
"source": [
606+
"model = CustomModel([16, 16, 10])\n",
607+
"# Build the model by calling it\n",
608+
"input_arr = tf.random.uniform((1, 5))\n",
609+
"outputs = model(input_arr)"
610+
],
611+
"metadata": {
612+
"collapsed": false,
613+
"pycharm": {
614+
"name": "#%%\n"
615+
}
616+
}
617+
},
618+
{
619+
"cell_type": "markdown",
620+
"source": [
621+
"We then can save the model as a TileDB array."
622+
],
623+
"metadata": {
624+
"collapsed": false,
625+
"pycharm": {
626+
"name": "#%% md\n"
627+
}
628+
}
629+
},
630+
{
631+
"cell_type": "code",
632+
"execution_count": null,
633+
"outputs": [],
634+
"source": [
635+
"tiledb_model_custom = TensorflowKerasTileDBModel(uri='tiledb-keras-custom-model', model=model)\n",
636+
"tiledb_model_custom.save(include_optimizer=True, update=False)\n"
637+
],
638+
"metadata": {
639+
"collapsed": false,
640+
"pycharm": {
641+
"name": "#%%\n"
642+
}
643+
}
644+
},
645+
{
646+
"cell_type": "markdown",
647+
"source": [
648+
"Loading the subclassed model requires `custom_objects` to be passed as an argument\n",
649+
"and the `input_shape` of the model so it can be built. The output of two models are\n",
650+
"exactly the same"
651+
],
652+
"metadata": {
653+
"collapsed": false
654+
}
655+
},
656+
{
657+
"cell_type": "code",
658+
"execution_count": null,
659+
"outputs": [],
660+
"source": [
661+
"loaded_custom = tiledb_model_custom.load(custom_objects={\"CustomModel\": CustomModel}, input_shape=(1, 5))\n",
662+
"outputs_loaded = loaded_custom((1, 5))\n",
663+
"outputs == outputs_loaded"
664+
],
665+
"metadata": {
666+
"collapsed": false,
667+
"pycharm": {
668+
"name": "#%%\n"
669+
}
670+
}
544671
}
545672
],
546673
"metadata": {
@@ -564,4 +691,4 @@
564691
},
565692
"nbformat": 4,
566693
"nbformat_minor": 1
567-
}
694+
}

0 commit comments

Comments
 (0)