|
17 | 17 | import math
|
18 | 18 | from typing import Optional, Tuple
|
19 | 19 |
|
20 |
| -from absl import logging |
21 | 20 | import lingvo.compat as tf
|
22 | 21 | from lingvo.core import base_layer
|
23 | 22 | from lingvo.core import gshard_utils
|
@@ -550,11 +549,11 @@ def AttenProbs(inputs: py_utils.NestedMap) -> tf.Tensor:
|
550 | 549 | )
|
551 | 550 | else:
|
552 | 551 | if hasattr(inputs, 'source_segment_id'):
|
553 |
| - logging.warning( |
| 552 | + tf.logging.warning( |
554 | 553 | 'packed_input is False but source_segment_id is passed.'
|
555 | 554 | )
|
556 | 555 | if hasattr(inputs, 'query_segment_id'):
|
557 |
| - logging.warning( |
| 556 | + tf.logging.warning( |
558 | 557 | 'packed_input is False but query_segment_id is passed.'
|
559 | 558 | )
|
560 | 559 | # Reshape logits to a matrix of shape [target_batch, source_length] and
|
@@ -715,11 +714,11 @@ def AttenProbs(inputs: py_utils.NestedMap) -> tf.Tensor:
|
715 | 714 | source_padding = tf.squeeze(source_padding, 1)
|
716 | 715 | else:
|
717 | 716 | if hasattr(inputs, 'source_segment_id'):
|
718 |
| - logging.warning( |
| 717 | + tf.logging.warning( |
719 | 718 | 'packed_input is False but source_segment_id is passed.'
|
720 | 719 | )
|
721 | 720 | if hasattr(inputs, 'query_segment_id'):
|
722 |
| - logging.warning( |
| 721 | + tf.logging.warning( |
723 | 722 | 'packed_input is False but query_segment_id is passed.'
|
724 | 723 | )
|
725 | 724 | # [b, sl]
|
@@ -1067,11 +1066,11 @@ def AttenProbs(
|
1067 | 1066 | source_padding = tf.transpose(source_padding, [1, 2, 0])
|
1068 | 1067 | else:
|
1069 | 1068 | if hasattr(inputs, 'source_segment_id'):
|
1070 |
| - logging.warning( |
| 1069 | + tf.logging.warning( |
1071 | 1070 | 'packed_input is False but source_segment_id is passed.'
|
1072 | 1071 | )
|
1073 | 1072 | if hasattr(inputs, 'query_segment_id'):
|
1074 |
| - logging.warning( |
| 1073 | + tf.logging.warning( |
1075 | 1074 | 'packed_input is False but query_segment_id is passed.'
|
1076 | 1075 | )
|
1077 | 1076 | source_padding = tf.transpose(source_padding, [2, 0, 1])
|
|
0 commit comments