526
526
"execution_count" : null ,
527
527
"metadata" : {
528
528
"pycharm" : {
529
- "is_executing" : true ,
530
529
"name" : " #%%\n "
531
530
}
532
531
},
535
534
" tiledb.ls('MNIST_Group', lambda obj_path, obj_type: print(obj_path, obj_type))"
536
535
]
537
536
},
537
+ {
538
+ "cell_type" : " markdown" ,
539
+ "source" : [
540
+ " ## Model Subclassing\n " ,
541
+ " \n " ,
542
+ " Apart from being able to store models, which have been created with Symbolic APIs\n " ,
543
+ " (Sequential, Functional) someone can store models that are being designed based on\n " ,
544
+ " Imperative API (aka. Model Subclassing).\n " ,
545
+ " \n " ,
546
+ " Let's first design a simple model:"
547
+ ],
548
+ "metadata" : {
549
+ "collapsed" : false ,
550
+ "pycharm" : {
551
+ "name" : " #%% md\n "
552
+ }
553
+ }
554
+ },
555
+ {
556
+ "cell_type" : " code" ,
557
+ "execution_count" : null ,
558
+ "outputs" : [],
559
+ "source" : [
560
+ " from tensorflow import keras\n " ,
561
+ " \n " ,
562
+ " class CustomModel(keras.Model):\n " ,
563
+ " def __init__(self, hidden_units):\n " ,
564
+ " super(CustomModel, self).__init__()\n " ,
565
+ " self.hidden_units = hidden_units\n " ,
566
+ " self.dense_layers = [keras.layers.Dense(u) for u in hidden_units]\n " ,
567
+ " \n " ,
568
+ " def call(self, inputs):\n " ,
569
+ " x = inputs\n " ,
570
+ " for layer in self.dense_layers:\n " ,
571
+ " x = layer(x)\n " ,
572
+ " return x\n " ,
573
+ " \n " ,
574
+ " def get_config(self):\n " ,
575
+ " return {\" hidden_units\" : self.hidden_units}\n " ,
576
+ " \n " ,
577
+ " @classmethod\n " ,
578
+ " def from_config(cls, config):\n " ,
579
+ " return cls(**config)"
580
+ ],
581
+ "metadata" : {
582
+ "collapsed" : false ,
583
+ "pycharm" : {
584
+ "name" : " #%%\n "
585
+ }
586
+ }
587
+ },
588
+ {
589
+ "cell_type" : " markdown" ,
590
+ "source" : [
591
+ " Then we can create a trivial input dataset for testing the model. Remember that\n " ,
592
+ " for custom models to be initialised they need to be called on data."
593
+ ],
594
+ "metadata" : {
595
+ "collapsed" : false ,
596
+ "pycharm" : {
597
+ "name" : " #%% md\n "
598
+ }
599
+ }
600
+ },
538
601
{
539
602
"cell_type" : " code" ,
540
603
"execution_count" : null ,
541
- "metadata" : {},
542
604
"outputs" : [],
543
- "source" : []
605
+ "source" : [
606
+ " model = CustomModel([16, 16, 10])\n " ,
607
+ " # Build the model by calling it\n " ,
608
+ " input_arr = tf.random.uniform((1, 5))\n " ,
609
+ " outputs = model(input_arr)"
610
+ ],
611
+ "metadata" : {
612
+ "collapsed" : false ,
613
+ "pycharm" : {
614
+ "name" : " #%%\n "
615
+ }
616
+ }
617
+ },
618
+ {
619
+ "cell_type" : " markdown" ,
620
+ "source" : [
621
+ " We then can save the model as a TileDB array."
622
+ ],
623
+ "metadata" : {
624
+ "collapsed" : false ,
625
+ "pycharm" : {
626
+ "name" : " #%% md\n "
627
+ }
628
+ }
629
+ },
630
+ {
631
+ "cell_type" : " code" ,
632
+ "execution_count" : null ,
633
+ "outputs" : [],
634
+ "source" : [
635
+ " tiledb_model_custom = TensorflowKerasTileDBModel(uri='tiledb-keras-custom-model', model=model)\n " ,
636
+ " tiledb_model_custom.save(include_optimizer=True, update=False)\n "
637
+ ],
638
+ "metadata" : {
639
+ "collapsed" : false ,
640
+ "pycharm" : {
641
+ "name" : " #%%\n "
642
+ }
643
+ }
644
+ },
645
+ {
646
+ "cell_type" : " markdown" ,
647
+ "source" : [
648
+ " Loading the subclassed model requires `custom_objects` to be passed as an argument\n " ,
649
+ " and the `input_shape` of the model so it can be built. The output of two models are\n " ,
650
+ " exactly the same"
651
+ ],
652
+ "metadata" : {
653
+ "collapsed" : false
654
+ }
655
+ },
656
+ {
657
+ "cell_type" : " code" ,
658
+ "execution_count" : null ,
659
+ "outputs" : [],
660
+ "source" : [
661
+ " loaded_custom = tiledb_model_custom.load(custom_objects={\" CustomModel\" : CustomModel}, input_shape=(1, 5))\n " ,
662
+ " outputs_loaded = loaded_custom((1, 5))\n " ,
663
+ " outputs == outputs_loaded"
664
+ ],
665
+ "metadata" : {
666
+ "collapsed" : false ,
667
+ "pycharm" : {
668
+ "name" : " #%%\n "
669
+ }
670
+ }
544
671
}
545
672
],
546
673
"metadata" : {
564
691
},
565
692
"nbformat" : 4 ,
566
693
"nbformat_minor" : 1
567
- }
694
+ }
0 commit comments