diff --git a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java
index cd419e672..a85ceb552 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java
@@ -50,15 +50,22 @@
import opennlp.tools.util.featuregen.StringPattern;
/**
- * A {@link POSTagger part-of-speech tagger} that uses maximum entropy.
+ * A {@link POSTagger part-of-speech tagger} implementation that uses maximum entropy.
*
- * Tries to predict whether words are nouns, verbs, or any of 70 other POS tags
+ * Tries to predict whether words are nouns, verbs, or any other kind of {@link POSTagFormat POS tag}
* depending on their surrounding context.
+ *
+ * @see POSModel
+ * @see POSTagFormat
+ * @see POSTagger
*/
public class POSTaggerME implements POSTagger {
private static final Logger logger = LoggerFactory.getLogger(POSTaggerME.class);
+ /**
+ * The default beam size value is 3.
+ */
public static final int DEFAULT_BEAM_SIZE = 3;
private final POSModel modelPackage;
@@ -66,7 +73,7 @@ public class POSTaggerME implements POSTagger {
/**
* The {@link POSContextGenerator feature context generator}.
*/
- protected final POSContextGenerator contextGen;
+ protected final POSContextGenerator cg;
/**
* {@link TagDictionary} used for restricting words to a fixed set of tags.
@@ -140,7 +147,7 @@ public POSTaggerME(POSModel model, POSTagFormat format) {
modelPackage = model;
- contextGen = factory.getPOSContextGenerator(beamSize);
+ cg = factory.getPOSContextGenerator(beamSize);
tagDictionary = factory.getTagDictionary();
size = beamSize;
@@ -165,14 +172,20 @@ public String[] getAllPosTags() {
return model.getOutcomes();
}
+ /**
+ * {@inheritDoc} Delegates to {@link #tag(String[], Object[])} with a {@code null} additional context.
+ */
@Override
public String[] tag(String[] sentence) {
return this.tag(sentence, null);
}
+ /**
+ * {@inheritDoc} The sequence returned by the model is also stored in {@code bestSequence}.
+ */
@Override
public String[] tag(String[] sentence, Object[] additionalContext) {
- bestSequence = model.bestSequence(sentence, additionalContext, contextGen, sequenceValidator);
+ bestSequence = model.bestSequence(sentence, additionalContext, cg, sequenceValidator);
final List t = bestSequence.getOutcomes();
return convertTags(t);
}
@@ -186,7 +199,7 @@ public String[] tag(String[] sentence, Object[] additionalContext) {
*/
public String[][] tag(int numTaggings, String[] sentence) {
Sequence[] bestSequences = model.bestSequences(numTaggings, sentence, null,
- contextGen, sequenceValidator);
+ cg, sequenceValidator);
String[][] tags = new String[bestSequences.length][];
for (int si = 0; si < tags.length; si++) {
List t = bestSequences[si].getOutcomes();
@@ -204,18 +217,25 @@ private String[] convertTags(List t) {
}
}
+ /**
+ * {@inheritDoc} Delegates to {@link #topKSequences(String[], Object[])} with a {@code null} additional context.
+ */
@Override
public Sequence[] topKSequences(String[] sentence) {
return this.topKSequences(sentence, null);
}
+ /**
+ * {@inheritDoc} Requests {@code size} best sequences from the underlying model.
+ */
@Override
public Sequence[] topKSequences(String[] sentence, Object[] additionalContext) {
- return model.bestSequences(size, sentence, additionalContext, contextGen, sequenceValidator);
+ return model.bestSequences(size, sentence, additionalContext, cg, sequenceValidator);
}
/**
- * Populates the specified array with the probabilities for each tag of the last tagged sentence.
+ * Populates the specified {@code probs} array with the probabilities
+ * for each tag of the last tagged sentence.
*
* @param probs An array to put the probabilities into.
*/
@@ -239,7 +259,7 @@ public String[] getOrderedTags(List words, List tags, int index,
MaxentModel posModel = modelPackage.getArtifact(POSModel.POS_MODEL_ENTRY_NAME);
if (posModel != null) {
- double[] probs = posModel.eval(contextGen.getContext(index, words.toArray(new String[0]),
+ double[] probs = posModel.eval(cg.getContext(index, words.toArray(new String[0]),
tags.toArray(new String[0]), null));
String[] orderedTags = new String[probs.length];
@@ -263,34 +283,44 @@ public String[] getOrderedTags(List words, List tags, int index,
}
}
- public static POSModel train(String languageCode,
- ObjectStream samples, TrainingParameters trainParams,
- POSTaggerFactory posFactory) throws IOException {
-
- int beamSize = trainParams.getIntParameter(BeamSearch.BEAM_SIZE_PARAMETER, POSTaggerME.DEFAULT_BEAM_SIZE);
-
- POSContextGenerator contextGenerator = posFactory.getPOSContextGenerator();
+ /**
+ * Trains a {@link POSModel} with the given parameters.
+ *
+ * @param languageCode The ISO language code of the model to train. Must not be {@code null}.
+ * @param samples The {@link ObjectStream} of {@link POSSample} used as input for training.
+ * @param mlParams The {@link TrainingParameters} for the context of the training process.
+ * @param posFactory The {@link POSTaggerFactory} for creating related objects as defined
+ * via {@code mlParams}.
+ *
+ * @return A valid, trained {@link POSModel} instance.
+ * @throws IOException Thrown if IO errors occur.
+ */
+ public static POSModel train(String languageCode, ObjectStream samples,
+ TrainingParameters mlParams, POSTaggerFactory posFactory)
+ throws IOException {
- Map manifestInfoEntries = new HashMap<>();
+ final int beamSize = mlParams.getIntParameter(
+ BeamSearch.BEAM_SIZE_PARAMETER, POSTaggerME.DEFAULT_BEAM_SIZE);
- TrainerType trainerType = TrainerFactory.getTrainerType(trainParams);
+ final POSContextGenerator contextGenerator = posFactory.getPOSContextGenerator();
+ final TrainerType trainerType = TrainerFactory.getTrainerType(mlParams);
+ final Map manifestInfoEntries = new HashMap<>();
MaxentModel posModel = null;
SequenceClassificationModel seqPosModel = null;
if (TrainerType.EVENT_MODEL_TRAINER.equals(trainerType)) {
ObjectStream es = new POSSampleEventStream(samples, contextGenerator);
- EventTrainer trainer = TrainerFactory.getEventTrainer(trainParams,
- manifestInfoEntries);
+ EventTrainer trainer = TrainerFactory.getEventTrainer(mlParams, manifestInfoEntries);
posModel = trainer.train(es);
} else if (TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType)) {
POSSampleSequenceStream ss = new POSSampleSequenceStream(samples, contextGenerator);
EventModelSequenceTrainer trainer =
- TrainerFactory.getEventModelSequenceTrainer(trainParams, manifestInfoEntries);
+ TrainerFactory.getEventModelSequenceTrainer(mlParams, manifestInfoEntries);
posModel = trainer.train(ss);
} else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) {
SequenceTrainer trainer = TrainerFactory.getSequenceModelTrainer(
- trainParams, manifestInfoEntries);
+ mlParams, manifestInfoEntries);
// TODO: This will probably cause issue, since the feature generator uses the outcomes array