@@ -89,7 +89,7 @@ def __call__(self, instances: List[Instance]) -> MRCModelInputs:
8989 metadata ["positions" ] = zip (start_positions , end_positions )
9090
9191 start_positions = [(query_offset + start_position ) for start_position in start_positions ]
92- start_position_labels = torch .zeros (batch_max_length )
92+ start_position_labels = torch .zeros (batch_max_length , dtype = torch . long )
9393
9494 for start_position in start_positions :
9595 if start_position < batch_max_length - 1 :
@@ -98,17 +98,17 @@ def __call__(self, instances: List[Instance]) -> MRCModelInputs:
9898 batch_start_position_labels .append (start_position_labels )
9999
100100 end_positions = [(query_offset + end_position ) for end_position in end_positions ]
101- end_position_labels = torch .zeros (batch_max_length )
101+ end_position_labels = torch .zeros (batch_max_length , dtype = torch . long )
102102
103103 for end_position in end_positions :
104104
105105 if end_position < batch_max_length - 1 :
106106 end_position_labels [end_position ] = 1
107107
108- batch_end_position_labels .append (torch . tensor ( end_position_labels , dtype = torch . long ) )
108+ batch_end_position_labels .append (end_position_labels )
109109
110110 # match position
111- match_positions = torch .zeros (size = (batch_max_length , batch_max_length ))
111+ match_positions = torch .zeros (size = (batch_max_length , batch_max_length ), dtype = torch . long )
112112
113113 for start_position , end_position in zip (start_positions , end_positions ):
114114
0 commit comments