Skip to content

Commit 0e6dfbf

Browse files
committed
Fixed the Python module to check the piece id range before decoding ids.
1 parent 3589bfb commit 0e6dfbf

File tree

3 files changed

+177
-131
lines changed

3 files changed

+177
-131
lines changed

Diff for: python/src/sentencepiece/__init__.py

+10-8
Original file line number | Diff line number | Diff line change
@@ -116,9 +116,6 @@ def SampleEncodeAsIds(self, input, nbest_size, alpha):
116116
def DecodePieces(self, pieces):
117117
return _sentencepiece.SentencePieceProcessor_DecodePieces(self, pieces)
118118

119-
def DecodeIds(self, ids):
120-
return _sentencepiece.SentencePieceProcessor_DecodeIds(self, ids)
121-
122119
def EncodeAsSerializedProto(self, input):
123120
return _sentencepiece.SentencePieceProcessor_EncodeAsSerializedProto(self, input)
124121

@@ -131,9 +128,6 @@ def NBestEncodeAsSerializedProto(self, input, nbest_size):
131128
def DecodePiecesAsSerializedProto(self, pieces):
132129
return _sentencepiece.SentencePieceProcessor_DecodePiecesAsSerializedProto(self, pieces)
133130

134-
def DecodeIdsAsSerializedProto(self, ids):
135-
return _sentencepiece.SentencePieceProcessor_DecodeIdsAsSerializedProto(self, ids)
136-
137131
def GetPieceSize(self):
138132
return _sentencepiece.SentencePieceProcessor_GetPieceSize(self)
139133

@@ -176,6 +170,12 @@ def serialized_model_proto(self):
176170
def LoadFromFile(self, arg):
177171
return _sentencepiece.SentencePieceProcessor_LoadFromFile(self, arg)
178172

173+
# Thin wrapper added by this commit: passes `ids` straight to the native
# `_sentencepiece.SentencePieceProcessor_DecodeIdsWithCheck` binding and
# returns whatever that binding returns.
def DecodeIdsWithCheck(self, ids):
174+
return _sentencepiece.SentencePieceProcessor_DecodeIdsWithCheck(self, ids)
175+
176+
# Thin wrapper added by this commit: passes `ids` straight to the native
# `_sentencepiece.SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck`
# binding and returns whatever that binding returns.
def DecodeIdsAsSerializedProtoWithCheck(self, ids):
177+
return _sentencepiece.SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck(self, ids)
178+
179179
def Init(self,
180180
model_file=None,
181181
model_proto=None,
@@ -310,15 +310,15 @@ def Decode(self, input):
310310
if not input:
311311
return self.DecodeIds([])
312312
elif type(input) is int:
313-
return self.DecodeIds([input])
313+
return self.DecodeIdsWithCheck([input])
314314
elif type(input) is str:
315315
return self.DecodePieces([input])
316316

317317
def _decode(input):
318318
if not input:
319319
return self.DecodeIds([])
320320
if type(input[0]) is int:
321-
return self.DecodeIds(input)
321+
return self.DecodeIdsWithCheck(input)
322322
return self.DecodePieces(input)
323323

324324
if type(input[0]) is list:
@@ -508,6 +508,8 @@ def _batched_func(self, arg):
508508

509509
SentencePieceProcessor.Tokenize = SentencePieceProcessor.Encode
510510
SentencePieceProcessor.Detokenize = SentencePieceProcessor.Decode
511+
SentencePieceProcessor.DecodeIds = SentencePieceProcessor.DecodeIdsWithCheck
512+
SentencePieceProcessor.DecodeIdsAsSerializedProto = SentencePieceProcessor.DecodeIdsAsSerializedProtoWithCheck
511513

512514
for m in [
513515
'PieceToId', 'IdToPiece', 'GetScore', 'IsUnknown', 'IsControl', 'IsUnused',

Diff for: python/src/sentencepiece/sentencepiece.i

+28-2
Original file line number | Diff line number | Diff line change
@@ -176,6 +176,8 @@ class PySentenceIterator : public sentencepiece::SentenceIterator {
176176
%ignore sentencepiece::SentencePieceProcessor::SampleEncode;
177177
%ignore sentencepiece::SentencePieceProcessor::NBestEncode;
178178
%ignore sentencepiece::SentencePieceProcessor::Decode;
179+
%ignore sentencepiece::SentencePieceProcessor::DecodeIds;
180+
%ignore sentencepiece::SentencePieceProcessor::DecodeIdsAsSerializedProto;
179181
%ignore sentencepiece::SentencePieceProcessor::model_proto;
180182
%ignore sentencepiece::SentencePieceProcessor::Load;
181183
%ignore sentencepiece::SentencePieceProcessor::LoadOrDie;
@@ -196,6 +198,28 @@ class PySentenceIterator : public sentencepiece::SentenceIterator {
196198
return $self->Load(arg);
197199
}
198200

201+
// Range-checked variant of DecodeIds().
// Every element of `ids` must satisfy 0 <= id < GetPieceSize(); the first
// element that does not causes a sentencepiece::util::Status with
// StatusCode::kOutOfRange and the message "piece id is out of range." to be
// thrown. When all ids pass, the result of DecodeIds(ids) is returned.
// NOTE(review): `$self` suggests this lives inside a SWIG %extend block, and
// the thrown Status is presumably translated into a Python exception by an
// %exception handler elsewhere in the .i file -- confirm.
std::string DecodeIdsWithCheck(
202+
const std::vector<int> &ids) const {
203+
// Exclusive upper bound for a valid piece id.
const int num_pieces = $self->GetPieceSize();
204+
for (int id : ids)
205+
if (id < 0 || id >= num_pieces)
206+
throw sentencepiece::util::Status(
207+
sentencepiece::util::StatusCode::kOutOfRange,
208+
"piece id is out of range.");
209+
// All ids validated; delegate to the unchecked decoder.
return $self->DecodeIds(ids);
210+
}
211+
212+
// Range-checked variant of DecodeIdsAsSerializedProto().
// Every element of `ids` must satisfy 0 <= id < GetPieceSize(); the first
// element that does not causes a sentencepiece::util::Status with
// StatusCode::kOutOfRange and the message "piece id is out of range." to be
// thrown. When all ids pass, the result of DecodeIdsAsSerializedProto(ids)
// is returned as util::bytes.
// NOTE(review): `$self` suggests this lives inside a SWIG %extend block, and
// the thrown Status is presumably translated into a Python exception by an
// %exception handler elsewhere in the .i file -- confirm.
util::bytes DecodeIdsAsSerializedProtoWithCheck(
213+
const std::vector<int> &ids) const {
214+
// Exclusive upper bound for a valid piece id.
const int num_pieces = $self->GetPieceSize();
215+
for (int id : ids)
216+
if (id < 0 || id >= num_pieces)
217+
throw sentencepiece::util::Status(
218+
sentencepiece::util::StatusCode::kOutOfRange,
219+
"piece id is out of range.");
220+
// All ids validated; delegate to the unchecked serializer.
return $self->DecodeIdsAsSerializedProto(ids);
221+
}
222+
199223
%pythoncode {
200224
def Init(self,
201225
model_file=None,
@@ -331,15 +355,15 @@ class PySentenceIterator : public sentencepiece::SentenceIterator {
331355
if not input:
332356
return self.DecodeIds([])
333357
elif type(input) is int:
334-
return self.DecodeIds([input])
358+
return self.DecodeIdsWithCheck([input])
335359
elif type(input) is str:
336360
return self.DecodePieces([input])
337361

338362
def _decode(input):
339363
if not input:
340364
return self.DecodeIds([])
341365
if type(input[0]) is int:
342-
return self.DecodeIds(input)
366+
return self.DecodeIdsWithCheck(input)
343367
return self.DecodePieces(input)
344368

345369
if type(input[0]) is list:
@@ -707,6 +731,8 @@ setattr(SentencePieceProcessor, '__init__', SentencePieceProcessor.Init)
707731

708732
SentencePieceProcessor.Tokenize = SentencePieceProcessor.Encode
709733
SentencePieceProcessor.Detokenize = SentencePieceProcessor.Decode
734+
SentencePieceProcessor.DecodeIds = SentencePieceProcessor.DecodeIdsWithCheck
735+
SentencePieceProcessor.DecodeIdsAsSerializedProto = SentencePieceProcessor.DecodeIdsAsSerializedProtoWithCheck
710736

711737
for m in [
712738
'PieceToId', 'IdToPiece', 'GetScore', 'IsUnknown', 'IsControl', 'IsUnused',

0 commit comments

Comments
 (0)