@@ -401,18 +401,8 @@ def __init__(self,
401
401
self .bns = []
402
402
self .grad_checkpoint = grad_checkpoint
403
403
self .feature_only = feature_only
404
- if separable_conv :
405
- conv2d_layer = functools .partial (
406
- tf .keras .layers .SeparableConv2D ,
407
- depth_multiplier = 1 ,
408
- data_format = data_format ,
409
- pointwise_initializer = tf .initializers .variance_scaling (),
410
- depthwise_initializer = tf .initializers .variance_scaling ())
411
- else :
412
- conv2d_layer = functools .partial (
413
- tf .keras .layers .Conv2D ,
414
- data_format = data_format ,
415
- kernel_initializer = tf .random_normal_initializer (stddev = 0.01 ))
404
+
405
+ conv2d_layer = self .conv2d_layer (separable_conv , data_format )
416
406
for i in range (self .repeats ):
417
407
# If using SeparableConv2D
418
408
self .conv_ops .append (
@@ -435,12 +425,8 @@ def __init__(self,
435
425
))
436
426
self .bns .append (bn_per_level )
437
427
438
- self .classes = conv2d_layer (
439
- num_classes * num_anchors ,
440
- kernel_size = 3 ,
441
- bias_initializer = tf .constant_initializer (- np .log ((1 - 0.01 ) / 0.01 )),
442
- padding = 'same' ,
443
- name = 'class-predict' )
428
+ self .classes = self .classes_layer (
429
+ conv2d_layer , num_classes , num_anchors , name = 'class-predict' )
444
430
445
431
@tf .autograph .experimental .do_not_convert
446
432
def _conv_bn_act (self , image , i , level_id , training ):
@@ -476,6 +462,33 @@ def call(self, inputs, training, **kwargs):
476
462
477
463
return class_outputs
478
464
465
+ @classmethod
466
+ def conv2d_layer (cls , separable_conv , data_format ):
467
+ """Gets the conv2d layer in ClassNet class."""
468
+ if separable_conv :
469
+ conv2d_layer = functools .partial (
470
+ tf .keras .layers .SeparableConv2D ,
471
+ depth_multiplier = 1 ,
472
+ data_format = data_format ,
473
+ pointwise_initializer = tf .initializers .variance_scaling (),
474
+ depthwise_initializer = tf .initializers .variance_scaling ())
475
+ else :
476
+ conv2d_layer = functools .partial (
477
+ tf .keras .layers .Conv2D ,
478
+ data_format = data_format ,
479
+ kernel_initializer = tf .random_normal_initializer (stddev = 0.01 ))
480
+ return conv2d_layer
481
+
482
+ @classmethod
483
+ def classes_layer (cls , conv2d_layer , num_classes , num_anchors , name ):
484
+ """Gets the classes layer in ClassNet class."""
485
+ return conv2d_layer (
486
+ num_classes * num_anchors ,
487
+ kernel_size = 3 ,
488
+ bias_initializer = tf .constant_initializer (- np .log ((1 - 0.01 ) / 0.01 )),
489
+ padding = 'same' ,
490
+ name = name )
491
+
479
492
480
493
class BoxNet (tf .keras .layers .Layer ):
481
494
"""Box regression network."""
@@ -574,28 +587,8 @@ def __init__(self,
574
587
name = 'box-%d-bn-%d' % (i , level )))
575
588
self .bns .append (bn_per_level )
576
589
577
- if self .separable_conv :
578
- self .boxes = tf .keras .layers .SeparableConv2D (
579
- filters = 4 * self .num_anchors ,
580
- depth_multiplier = 1 ,
581
- pointwise_initializer = tf .initializers .variance_scaling (),
582
- depthwise_initializer = tf .initializers .variance_scaling (),
583
- data_format = self .data_format ,
584
- kernel_size = 3 ,
585
- activation = None ,
586
- bias_initializer = tf .zeros_initializer (),
587
- padding = 'same' ,
588
- name = 'box-predict' )
589
- else :
590
- self .boxes = tf .keras .layers .Conv2D (
591
- filters = 4 * self .num_anchors ,
592
- kernel_initializer = tf .random_normal_initializer (stddev = 0.01 ),
593
- data_format = self .data_format ,
594
- kernel_size = 3 ,
595
- activation = None ,
596
- bias_initializer = tf .zeros_initializer (),
597
- padding = 'same' ,
598
- name = 'box-predict' )
590
+ self .boxes = self .boxes_layer (
591
+ separable_conv , num_anchors , data_format , name = 'box-predict' )
599
592
600
593
@tf .autograph .experimental .do_not_convert
601
594
def _conv_bn_act (self , image , i , level_id , training ):
@@ -632,6 +625,32 @@ def call(self, inputs, training):
632
625
633
626
return box_outputs
634
627
628
+ @classmethod
629
+ def boxes_layer (cls , separable_conv , num_anchors , data_format , name ):
630
+ """Gets the conv2d layer in BoxNet class."""
631
+ if separable_conv :
632
+ return tf .keras .layers .SeparableConv2D (
633
+ filters = 4 * num_anchors ,
634
+ depth_multiplier = 1 ,
635
+ pointwise_initializer = tf .initializers .variance_scaling (),
636
+ depthwise_initializer = tf .initializers .variance_scaling (),
637
+ data_format = data_format ,
638
+ kernel_size = 3 ,
639
+ activation = None ,
640
+ bias_initializer = tf .zeros_initializer (),
641
+ padding = 'same' ,
642
+ name = name )
643
+ else :
644
+ return tf .keras .layers .Conv2D (
645
+ filters = 4 * num_anchors ,
646
+ kernel_initializer = tf .random_normal_initializer (stddev = 0.01 ),
647
+ data_format = data_format ,
648
+ kernel_size = 3 ,
649
+ activation = None ,
650
+ bias_initializer = tf .zeros_initializer (),
651
+ padding = 'same' ,
652
+ name = name )
653
+
635
654
636
655
class SegmentationHead (tf .keras .layers .Layer ):
637
656
"""Keras layer for semantic segmentation head."""
0 commit comments