@@ -116,9 +116,6 @@ def SampleEncodeAsIds(self, input, nbest_size, alpha):
116
116
def DecodePieces (self , pieces ):
117
117
return _sentencepiece .SentencePieceProcessor_DecodePieces (self , pieces )
118
118
119
- def DecodeIds (self , ids ):
120
- return _sentencepiece .SentencePieceProcessor_DecodeIds (self , ids )
121
-
122
119
def EncodeAsSerializedProto (self , input ):
123
120
return _sentencepiece .SentencePieceProcessor_EncodeAsSerializedProto (self , input )
124
121
@@ -131,9 +128,6 @@ def NBestEncodeAsSerializedProto(self, input, nbest_size):
131
128
def DecodePiecesAsSerializedProto (self , pieces ):
132
129
return _sentencepiece .SentencePieceProcessor_DecodePiecesAsSerializedProto (self , pieces )
133
130
134
- def DecodeIdsAsSerializedProto (self , ids ):
135
- return _sentencepiece .SentencePieceProcessor_DecodeIdsAsSerializedProto (self , ids )
136
-
137
131
def GetPieceSize (self ):
138
132
return _sentencepiece .SentencePieceProcessor_GetPieceSize (self )
139
133
@@ -176,6 +170,12 @@ def serialized_model_proto(self):
176
170
def LoadFromFile (self , arg ):
177
171
return _sentencepiece .SentencePieceProcessor_LoadFromFile (self , arg )
178
172
173
+ def DecodeIdsWithCheck (self , ids ):
174
+ return _sentencepiece .SentencePieceProcessor_DecodeIdsWithCheck (self , ids )
175
+
176
+ def DecodeIdsAsSerializedProtoWithCheck (self , ids ):
177
+ return _sentencepiece .SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck (self , ids )
178
+
179
179
def Init (self ,
180
180
model_file = None ,
181
181
model_proto = None ,
@@ -310,15 +310,15 @@ def Decode(self, input):
310
310
if not input :
311
311
return self .DecodeIds ([])
312
312
elif type (input ) is int :
313
- return self .DecodeIds ([input ])
313
+ return self .DecodeIdsWithCheck ([input ])
314
314
elif type (input ) is str :
315
315
return self .DecodePieces ([input ])
316
316
317
317
def _decode (input ):
318
318
if not input :
319
319
return self .DecodeIds ([])
320
320
if type (input [0 ]) is int :
321
- return self .DecodeIds (input )
321
+ return self .DecodeIdsWithCheck (input )
322
322
return self .DecodePieces (input )
323
323
324
324
if type (input [0 ]) is list :
@@ -508,6 +508,8 @@ def _batched_func(self, arg):
508
508
509
509
SentencePieceProcessor .Tokenize = SentencePieceProcessor .Encode
510
510
SentencePieceProcessor .Detokenize = SentencePieceProcessor .Decode
511
+ SentencePieceProcessor .DecodeIds = SentencePieceProcessor .DecodeIdsWithCheck
512
+ SentencePieceProcessor .DecodeIdsAsSerializedProto = SentencePieceProcessor .DecodeIdsAsSerializedProtoWithCheck
511
513
512
514
for m in [
513
515
'PieceToId' , 'IdToPiece' , 'GetScore' , 'IsUnknown' , 'IsControl' , 'IsUnused' ,
0 commit comments