diff --git a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java
index cd419e672..a85ceb552 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java
@@ -50,15 +50,22 @@ import opennlp.tools.util.featuregen.StringPattern;
 /**
- * A {@link POSTagger part-of-speech tagger} that uses maximum entropy.
+ * A {@link POSTagger part-of-speech tagger} implementation that uses maximum entropy.
  * <p>
- * Tries to predict whether words are nouns, verbs, or any of 70 other POS tags
+ * Tries to predict whether words are nouns, verbs, or any of the other {@link POSTagFormat POS tags}
  * depending on their surrounding context.
+ *
+ * @see POSModel
+ * @see POSTagFormat
+ * @see POSTagger
  */
 public class POSTaggerME implements POSTagger {
 
   private static final Logger logger = LoggerFactory.getLogger(POSTaggerME.class);
 
+  /**
+   * The default beam size value is 3.
+   */
   public static final int DEFAULT_BEAM_SIZE = 3;
 
   private final POSModel modelPackage;
 
@@ -66,7 +73,7 @@ public class POSTaggerME implements POSTagger {
   /**
    * The {@link POSContextGenerator feature context generator}.
    */
-  protected final POSContextGenerator contextGen;
+  protected final POSContextGenerator cg;
 
   /**
    * {@link TagDictionary} used for restricting words to a fixed set of tags.
@@ -140,7 +147,7 @@ public POSTaggerME(POSModel model, POSTagFormat format) {
 
     modelPackage = model;
 
-    contextGen = factory.getPOSContextGenerator(beamSize);
+    cg = factory.getPOSContextGenerator(beamSize);
     tagDictionary = factory.getTagDictionary();
     size = beamSize;
 
@@ -165,14 +172,20 @@ public String[] getAllPosTags() {
     return model.getOutcomes();
   }
 
+  /**
+   * {@inheritDoc}
+   */
   @Override
   public String[] tag(String[] sentence) {
     return this.tag(sentence, null);
   }
 
+  /**
+   * {@inheritDoc}
+   */
   @Override
   public String[] tag(String[] sentence, Object[] additionalContext) {
-    bestSequence = model.bestSequence(sentence, additionalContext, contextGen, sequenceValidator);
+    bestSequence = model.bestSequence(sentence, additionalContext, cg, sequenceValidator);
     final List<String> t = bestSequence.getOutcomes();
     return convertTags(t);
   }
@@ -186,7 +199,7 @@ public String[] tag(String[] sentence, Object[] additionalContext) {
    */
   public String[][] tag(int numTaggings, String[] sentence) {
     Sequence[] bestSequences = model.bestSequences(numTaggings, sentence, null,
-        contextGen, sequenceValidator);
+        cg, sequenceValidator);
     String[][] tags = new String[bestSequences.length][];
     for (int si = 0; si < tags.length; si++) {
       List<String> t = bestSequences[si].getOutcomes();
@@ -204,18 +217,25 @@ private String[] convertTags(List<String> t) {
     }
   }
 
+  /**
+   * {@inheritDoc}
+   */
   @Override
   public Sequence[] topKSequences(String[] sentence) {
     return this.topKSequences(sentence, null);
   }
 
+  /**
+   * {@inheritDoc}
+   */
   @Override
   public Sequence[] topKSequences(String[] sentence, Object[] additionalContext) {
-    return model.bestSequences(size, sentence, additionalContext, contextGen, sequenceValidator);
+    return model.bestSequences(size, sentence, additionalContext, cg, sequenceValidator);
   }
 
   /**
-   * Populates the specified array with the probabilities for each tag of the last tagged sentence.
+   * Populates the specified {@code probs} array with the probabilities
+   * for each tag of the last tagged sentence.
    *
    * @param probs An array to put the probabilities into.
    */
@@ -239,7 +259,7 @@ public String[] getOrderedTags(List<String> words, List<String> tags, int index,
     MaxentModel posModel = modelPackage.getArtifact(POSModel.POS_MODEL_ENTRY_NAME);
 
     if (posModel != null) {
-      double[] probs = posModel.eval(contextGen.getContext(index, words.toArray(new String[0]),
+      double[] probs = posModel.eval(cg.getContext(index, words.toArray(new String[0]),
           tags.toArray(new String[0]), null));
 
       String[] orderedTags = new String[probs.length];
@@ -263,34 +283,44 @@
     }
   }
 
-  public static POSModel train(String languageCode,
-      ObjectStream<POSSample> samples, TrainingParameters trainParams,
-      POSTaggerFactory posFactory) throws IOException {
-
-    int beamSize = trainParams.getIntParameter(BeamSearch.BEAM_SIZE_PARAMETER, POSTaggerME.DEFAULT_BEAM_SIZE);
-
-    POSContextGenerator contextGenerator = posFactory.getPOSContextGenerator();
+  /**
+   * Starts a training of a {@link POSModel} with the given parameters.
+   *
+   * @param languageCode The ISO language code to train the model. Must not be {@code null}.
+   * @param samples The {@link ObjectStream} of {@link POSSample} used as input for training.
+   * @param mlParams The {@link TrainingParameters} for the context of the training process.
+   * @param posFactory The {@link POSTaggerFactory} for creating related objects as defined
+   *                   via {@code mlParams}.
+   *
+   * @return A valid, trained {@link POSModel} instance.
+   * @throws IOException Thrown if IO errors occurred.
+   */
+  public static POSModel train(String languageCode, ObjectStream<POSSample> samples,
+                               TrainingParameters mlParams, POSTaggerFactory posFactory)
+          throws IOException {
 
-    Map<String, String> manifestInfoEntries = new HashMap<>();
+    final int beamSize = mlParams.getIntParameter(
+        BeamSearch.BEAM_SIZE_PARAMETER, POSTaggerME.DEFAULT_BEAM_SIZE);
 
-    TrainerType trainerType = TrainerFactory.getTrainerType(trainParams);
+    final POSContextGenerator contextGenerator = posFactory.getPOSContextGenerator();
+    final TrainerType trainerType = TrainerFactory.getTrainerType(mlParams);
+    final Map<String, String> manifestInfoEntries = new HashMap<>();
 
     MaxentModel posModel = null;
     SequenceClassificationModel seqPosModel = null;
     if (TrainerType.EVENT_MODEL_TRAINER.equals(trainerType)) {
       ObjectStream<Event> es = new POSSampleEventStream(samples, contextGenerator);
 
-      EventTrainer trainer = TrainerFactory.getEventTrainer(trainParams,
-          manifestInfoEntries);
+      EventTrainer trainer = TrainerFactory.getEventTrainer(mlParams, manifestInfoEntries);
       posModel = trainer.train(es);
     } else if (TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType)) {
       POSSampleSequenceStream ss = new POSSampleSequenceStream(samples, contextGenerator);
       EventModelSequenceTrainer trainer =
-          TrainerFactory.getEventModelSequenceTrainer(trainParams, manifestInfoEntries);
+          TrainerFactory.getEventModelSequenceTrainer(mlParams, manifestInfoEntries);
       posModel = trainer.train(ss);
     } else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) {
       SequenceTrainer trainer = TrainerFactory.getSequenceModelTrainer(
-          trainParams, manifestInfoEntries);
+          mlParams, manifestInfoEntries);
 
       // TODO: This will probably cause issue, since the feature generator uses the outcomes array
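
For reviewers, a minimal usage sketch of the API touched by this patch: training a POSModel via POSTaggerME.train(...) and then tagging with the resulting POSTaggerME. This sketch is not part of the change; the training file name "en-pos.train", its word_tag sample format, and the example sentence are assumptions for illustration only.

// Usage sketch only (not part of this patch). Assumes a training file
// "en-pos.train" with one sentence per line in the word_tag format read by
// WordTagSampleStream, e.g.: The_DT driver_NN got_VBD injured_VBN ._.
import java.io.File;
import java.nio.charset.StandardCharsets;

import opennlp.tools.postag.POSModel;
import opennlp.tools.postag.POSSample;
import opennlp.tools.postag.POSTaggerFactory;
import opennlp.tools.postag.POSTaggerME;
import opennlp.tools.postag.WordTagSampleStream;
import opennlp.tools.util.MarkableFileInputStreamFactory;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.PlainTextByLineStream;
import opennlp.tools.util.Sequence;
import opennlp.tools.util.TrainingParameters;

public class POSTaggerMEUsageSketch {

  public static void main(String[] args) throws Exception {
    try (ObjectStream<POSSample> samples = new WordTagSampleStream(
        new PlainTextByLineStream(
            new MarkableFileInputStreamFactory(new File("en-pos.train")),
            StandardCharsets.UTF_8))) {

      // Train with default parameters; the beam size falls back to
      // POSTaggerME.DEFAULT_BEAM_SIZE (3) when not set explicitly.
      POSModel model = POSTaggerME.train("en", samples,
          TrainingParameters.defaultParams(), new POSTaggerFactory());

      POSTaggerME tagger = new POSTaggerME(model);

      String[] sentence = {"The", "driver", "got", "badly", "injured", "."};
      String[] tags = tagger.tag(sentence);             // best tag sequence
      double[] probs = tagger.probs();                  // per-token probabilities of that sequence
      Sequence[] topK = tagger.topKSequences(sentence); // n-best tag sequences

      for (int i = 0; i < sentence.length; i++) {
        System.out.printf("%s/%s (%.3f)%n", sentence[i], tags[i], probs[i]);
      }
      System.out.println("Number of alternative sequences: " + topK.length);
    }
  }
}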