|
10 | 10 | import tensorflow as tf
|
11 | 11 |
|
12 | 12 | layers = tf.contrib.layers
|
| 13 | +arg_scope = tf.contrib.framework.arg_scope |
13 | 14 |
|
14 | 15 |
|
15 | 16 | class ConcatOpHandlerTest(tf.test.TestCase):
|
@@ -757,5 +758,157 @@ def testGetInputOutputOpSlices(self):
|
757 | 758 | self.mock_op_reg_manager))
|
758 | 759 |
|
759 | 760 |
|
| 761 | +class GroupingConcatOpHandlerTest(tf.test.TestCase): |
| 762 | + |
| 763 | + def _get_scope(self): |
| 764 | + params = { |
| 765 | + 'trainable': True, |
| 766 | + 'normalizer_fn': layers.batch_norm, |
| 767 | + 'normalizer_params': { |
| 768 | + 'scale': True, |
| 769 | + }, |
| 770 | + } |
| 771 | + |
| 772 | + with arg_scope([layers.conv2d], **params) as sc: |
| 773 | + return sc |
| 774 | + |
| 775 | + def setUp(self): |
| 776 | + tf.reset_default_graph() |
| 777 | + |
| 778 | + # This tests 3 Conv2D ops being concatenated. |
| 779 | + inputs = tf.zeros([2, 4, 4, 3]) |
| 780 | + with tf.contrib.framework.arg_scope(self._get_scope()): |
| 781 | + c1 = layers.conv2d(inputs, num_outputs=6, kernel_size=3, scope='conv1') |
| 782 | + c2 = layers.conv2d(inputs, num_outputs=6, kernel_size=3, scope='conv2') |
| 783 | + c3 = layers.conv2d(inputs, num_outputs=6, kernel_size=3, scope='conv3') |
| 784 | + net = tf.concat([c1, c2, c3], axis=2) |
| 785 | + layers.batch_norm(net) |
| 786 | + |
| 787 | + g = tf.get_default_graph() |
| 788 | + |
| 789 | + # Declare OpSlice and OpGroup for ops of interest. |
| 790 | + self.concat_op = g.get_operation_by_name('concat') |
| 791 | + self.concat_op_slice = orm.OpSlice(self.concat_op, orm.Slice(0, 6)) |
| 792 | + self.concat_op_group = orm.OpGroup( |
| 793 | + self.concat_op_slice, |
| 794 | + omit_source_op_slices=[self.concat_op_slice]) |
| 795 | + |
| 796 | + self.relu1_op = g.get_operation_by_name('conv1/Relu') |
| 797 | + self.relu1_op_slice = orm.OpSlice(self.relu1_op, orm.Slice(0, 6)) |
| 798 | + self.relu1_op_group = orm.OpGroup( |
| 799 | + self.relu1_op_slice, omit_source_op_slices=[self.relu1_op_slice]) |
| 800 | + |
| 801 | + self.relu2_op = g.get_operation_by_name('conv2/Relu') |
| 802 | + self.relu2_op_slice = orm.OpSlice(self.relu2_op, orm.Slice(0, 6)) |
| 803 | + self.relu2_op_group = orm.OpGroup( |
| 804 | + self.relu2_op_slice, omit_source_op_slices=[self.relu2_op_slice]) |
| 805 | + |
| 806 | + self.relu3_op = g.get_operation_by_name('conv3/Relu') |
| 807 | + self.relu3_op_slice = orm.OpSlice(self.relu3_op, orm.Slice(0, 6)) |
| 808 | + self.relu3_op_group = orm.OpGroup( |
| 809 | + self.relu3_op_slice, omit_source_op_slices=[self.relu3_op_slice]) |
| 810 | + |
| 811 | + self.batch_norm_op = g.get_operation_by_name('BatchNorm/FusedBatchNorm') |
| 812 | + self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 6)) |
| 813 | + self.batch_norm_op_group = orm.OpGroup( |
| 814 | + self.batch_norm_op_slice, |
| 815 | + omit_source_op_slices=[self.batch_norm_op_slice]) |
| 816 | + |
| 817 | + self.concat_group = orm.OpGroup( |
| 818 | + op_slice=None, |
| 819 | + op_groups=[ |
| 820 | + self.batch_norm_op_group, self.concat_op_group, self.relu1_op_group, |
| 821 | + self.relu2_op_group, self.relu3_op_group |
| 822 | + ]) |
| 823 | + |
| 824 | + # Create mock OpRegularizerManager with custom mapping of OpSlice and |
| 825 | + # OpGroup. |
| 826 | + self.mock_op_reg_manager = mock.create_autospec(orm.OpRegularizerManager) |
| 827 | + |
| 828 | + def get_op_slices(op): |
| 829 | + return self.op_slice_dict.get(op, []) |
| 830 | + |
| 831 | + def get_op_group(op_slice): |
| 832 | + return self.op_group_dict.get(op_slice) |
| 833 | + |
| 834 | + self.mock_op_reg_manager.get_op_slices.side_effect = get_op_slices |
| 835 | + self.mock_op_reg_manager.get_op_group.side_effect = get_op_group |
| 836 | + self.mock_op_reg_manager.is_source_op.return_value = False |
| 837 | + self.mock_op_reg_manager.is_passthrough.return_value = True |
| 838 | + self.mock_op_reg_manager.ops = [ |
| 839 | + self.concat_op, self.relu1_op, self.relu2_op, self.relu3_op, |
| 840 | + self.batch_norm_op] |
| 841 | + |
| 842 | + def test_AssignGroupingOfGroupingConcatNoSlicing(self): |
| 843 | + # In this test, the output op (batch norm) has size 6 and is not sliced. |
| 844 | + # and that input Conv2Ds are all of size 6, and are grouped. |
| 845 | + |
| 846 | + # Map ops to slices. Batch norm op is composed of multiple slices. |
| 847 | + self.op_slice_dict = { |
| 848 | + self.relu1_op: [self.relu1_op_slice], |
| 849 | + self.relu2_op: [self.relu2_op_slice], |
| 850 | + self.relu3_op: [self.relu3_op_slice], |
| 851 | + self.concat_op: [self.concat_op_slice], |
| 852 | + self.batch_norm_op: [self.batch_norm_op_slice], |
| 853 | + } |
| 854 | + |
| 855 | + # Map each slice to a group. |
| 856 | + self.op_group_dict = { |
| 857 | + self.relu1_op_slice: self.relu1_op_group, |
| 858 | + self.relu2_op_slice: self.relu2_op_group, |
| 859 | + self.relu3_op_slice: self.relu3_op_group, |
| 860 | + self.batch_norm_op_slice: self.batch_norm_op_group |
| 861 | + } |
| 862 | + |
| 863 | + # Call handler to assign grouping. |
| 864 | + handler = concat_op_handler.ConcatOpHandler() |
| 865 | + handler.assign_grouping(self.concat_op, self.mock_op_reg_manager) |
| 866 | + |
| 867 | + # Verify manager looks up OpSlice for ops of interest. |
| 868 | + self.mock_op_reg_manager.get_op_slices.assert_has_calls( |
| 869 | + # Checking for ops to process. |
| 870 | + [mock.call(self.relu1_op), |
| 871 | + mock.call(self.relu2_op), |
| 872 | + mock.call(self.relu3_op), |
| 873 | + mock.call(self.batch_norm_op), |
| 874 | + # Initial slice data. |
| 875 | + mock.call(self.concat_op), |
| 876 | + mock.call(self.relu1_op), |
| 877 | + mock.call(self.relu2_op), |
| 878 | + mock.call(self.relu3_op), |
| 879 | + mock.call(self.batch_norm_op), |
| 880 | + # Reslicing. |
| 881 | + mock.call(self.relu1_op), |
| 882 | + mock.call(self.relu2_op), |
| 883 | + mock.call(self.relu3_op), |
| 884 | + mock.call(self.concat_op), |
| 885 | + mock.call(self.batch_norm_op), |
| 886 | + # Refreshing slice data. |
| 887 | + mock.call(self.relu1_op), |
| 888 | + mock.call(self.relu2_op), |
| 889 | + mock.call(self.relu3_op), |
| 890 | + mock.call(self.batch_norm_op), |
| 891 | + # Group concat op. |
| 892 | + mock.call(self.concat_op)]) |
| 893 | + |
| 894 | + # Verify manager does not slices the concat op. |
| 895 | + self.mock_op_reg_manager.slice_op.assert_not_called() |
| 896 | + |
| 897 | + # Verify manager groups the new slices. |
| 898 | + self.mock_op_reg_manager.group_op_slices.assert_called_once_with([ |
| 899 | + self.concat_op_slice, self.relu1_op_slice, self.relu2_op_slice, |
| 900 | + self.relu3_op_slice, self.batch_norm_op_slice |
| 901 | + ]) |
| 902 | + |
| 903 | + def testGetConcatOpAxis(self): |
| 904 | + x = tf.zeros([7, 12, 12, 3]) |
| 905 | + self.assertEqual( |
| 906 | + concat_op_handler._get_concat_op_axis(tf.concat([x, x], 3).op), 3) |
| 907 | + self.assertEqual( |
| 908 | + concat_op_handler._get_concat_op_axis(tf.concat([x, x, x], 1).op), 1) |
| 909 | + self.assertEqual( |
| 910 | + concat_op_handler._get_concat_op_axis(tf.concat([x, x, x], 2).op), 2) |
| 911 | + |
| 912 | + |
760 | 913 | if __name__ == '__main__':
|
761 | 914 | tf.test.main()
|
0 commit comments