|
17 | 17 | import math |
18 | 18 | from typing import Optional, Tuple |
19 | 19 |
|
20 | | -from absl import logging |
21 | 20 | import lingvo.compat as tf |
22 | 21 | from lingvo.core import base_layer |
23 | 22 | from lingvo.core import gshard_utils |
@@ -550,11 +549,11 @@ def AttenProbs(inputs: py_utils.NestedMap) -> tf.Tensor: |
550 | 549 | ) |
551 | 550 | else: |
552 | 551 | if hasattr(inputs, 'source_segment_id'): |
553 | | - logging.warning( |
| 552 | + tf.logging.warning( |
554 | 553 | 'packed_input is False but source_segment_id is passed.' |
555 | 554 | ) |
556 | 555 | if hasattr(inputs, 'query_segment_id'): |
557 | | - logging.warning( |
| 556 | + tf.logging.warning( |
558 | 557 | 'packed_input is False but query_segment_id is passed.' |
559 | 558 | ) |
560 | 559 | # Reshape logits to a matrix of shape [target_batch, source_length] and |
@@ -715,11 +714,11 @@ def AttenProbs(inputs: py_utils.NestedMap) -> tf.Tensor: |
715 | 714 | source_padding = tf.squeeze(source_padding, 1) |
716 | 715 | else: |
717 | 716 | if hasattr(inputs, 'source_segment_id'): |
718 | | - logging.warning( |
| 717 | + tf.logging.warning( |
719 | 718 | 'packed_input is False but source_segment_id is passed.' |
720 | 719 | ) |
721 | 720 | if hasattr(inputs, 'query_segment_id'): |
722 | | - logging.warning( |
| 721 | + tf.logging.warning( |
723 | 722 | 'packed_input is False but query_segment_id is passed.' |
724 | 723 | ) |
725 | 724 | # [b, sl] |
@@ -1067,11 +1066,11 @@ def AttenProbs( |
1067 | 1066 | source_padding = tf.transpose(source_padding, [1, 2, 0]) |
1068 | 1067 | else: |
1069 | 1068 | if hasattr(inputs, 'source_segment_id'): |
1070 | | - logging.warning( |
| 1069 | + tf.logging.warning( |
1071 | 1070 | 'packed_input is False but source_segment_id is passed.' |
1072 | 1071 | ) |
1073 | 1072 | if hasattr(inputs, 'query_segment_id'): |
1074 | | - logging.warning( |
| 1073 | + tf.logging.warning( |
1075 | 1074 | 'packed_input is False but query_segment_id is passed.' |
1076 | 1075 | ) |
1077 | 1076 | source_padding = tf.transpose(source_padding, [2, 0, 1]) |
|
0 commit comments