Skip to content

Commit

Permalink
OPENNLP-1677: Extend JavaDoc of POSTaggerME (#717)
Browse files Browse the repository at this point in the history
  • Loading branch information
mawiesne authored Dec 21, 2024
1 parent 2f2f631 commit 49678c3
Showing 1 changed file with 52 additions and 22 deletions.
74 changes: 52 additions & 22 deletions opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,23 +50,30 @@
import opennlp.tools.util.featuregen.StringPattern;

/**
* A {@link POSTagger part-of-speech tagger} that uses maximum entropy.
* A {@link POSTagger part-of-speech tagger} implementation that uses maximum entropy.
* <p>
* Tries to predict whether words are nouns, verbs, or any of 70 other POS tags
* Tries to predict whether words are nouns, verbs, or any other {@link POSTagFormat POS tags}
* depending on their surrounding context.
*
* @see POSModel
* @see POSTagFormat
* @see POSTagger
*/
public class POSTaggerME implements POSTagger {

private static final Logger logger = LoggerFactory.getLogger(POSTaggerME.class);

/**
* The default beam size value is 3.
*/
public static final int DEFAULT_BEAM_SIZE = 3;

private final POSModel modelPackage;

/**
* The {@link POSContextGenerator feature context generator}.
*/
protected final POSContextGenerator contextGen;
protected final POSContextGenerator cg;

/**
* {@link TagDictionary} used for restricting words to a fixed set of tags.
Expand Down Expand Up @@ -140,7 +147,7 @@ public POSTaggerME(POSModel model, POSTagFormat format) {

modelPackage = model;

contextGen = factory.getPOSContextGenerator(beamSize);
cg = factory.getPOSContextGenerator(beamSize);
tagDictionary = factory.getTagDictionary();
size = beamSize;

Expand All @@ -165,14 +172,20 @@ public String[] getAllPosTags() {
return model.getOutcomes();
}

/**
* {@inheritDoc}
*/
@Override
public String[] tag(String[] sentence) {
return this.tag(sentence, null);
}

/**
* {@inheritDoc}
*/
@Override
public String[] tag(String[] sentence, Object[] additionalContext) {
bestSequence = model.bestSequence(sentence, additionalContext, contextGen, sequenceValidator);
bestSequence = model.bestSequence(sentence, additionalContext, cg, sequenceValidator);
final List<String> t = bestSequence.getOutcomes();
return convertTags(t);
}
Expand All @@ -186,7 +199,7 @@ public String[] tag(String[] sentence, Object[] additionalContext) {
*/
public String[][] tag(int numTaggings, String[] sentence) {
Sequence[] bestSequences = model.bestSequences(numTaggings, sentence, null,
contextGen, sequenceValidator);
cg, sequenceValidator);
String[][] tags = new String[bestSequences.length][];
for (int si = 0; si < tags.length; si++) {
List<String> t = bestSequences[si].getOutcomes();
Expand All @@ -204,18 +217,25 @@ private String[] convertTags(List<String> t) {
}
}

/**
* {@inheritDoc}
*/
@Override
public Sequence[] topKSequences(String[] sentence) {
return this.topKSequences(sentence, null);
}

/**
* {@inheritDoc}
*/
@Override
public Sequence[] topKSequences(String[] sentence, Object[] additionalContext) {
return model.bestSequences(size, sentence, additionalContext, contextGen, sequenceValidator);
return model.bestSequences(size, sentence, additionalContext, cg, sequenceValidator);
}

/**
* Populates the specified array with the probabilities for each tag of the last tagged sentence.
* Populates the specified {@code probs} array with the probabilities
* for each tag of the last tagged sentence.
*
* @param probs An array to put the probabilities into.
*/
Expand All @@ -239,7 +259,7 @@ public String[] getOrderedTags(List<String> words, List<String> tags, int index,
MaxentModel posModel = modelPackage.getArtifact(POSModel.POS_MODEL_ENTRY_NAME);
if (posModel != null) {

double[] probs = posModel.eval(contextGen.getContext(index, words.toArray(new String[0]),
double[] probs = posModel.eval(cg.getContext(index, words.toArray(new String[0]),
tags.toArray(new String[0]), null));

String[] orderedTags = new String[probs.length];
Expand All @@ -263,34 +283,44 @@ public String[] getOrderedTags(List<String> words, List<String> tags, int index,
}
}

public static POSModel train(String languageCode,
ObjectStream<POSSample> samples, TrainingParameters trainParams,
POSTaggerFactory posFactory) throws IOException {

int beamSize = trainParams.getIntParameter(BeamSearch.BEAM_SIZE_PARAMETER, POSTaggerME.DEFAULT_BEAM_SIZE);

POSContextGenerator contextGenerator = posFactory.getPOSContextGenerator();
/**
* Starts a training of a {@link POSModel} with the given parameters.
*
* @param languageCode The ISO language code to train the model. Must not be {@code null}.
* @param samples The {@link ObjectStream} of {@link POSSample} used as input for training.
* @param mlParams The {@link TrainingParameters} for the context of the training process.
* @param posFactory The {@link POSTaggerFactory} for creating related objects as defined
* via {@code mlParams}.
*
* @return A valid, trained {@link POSModel} instance.
* @throws IOException Thrown if IO errors occurred.
*/
public static POSModel train(String languageCode, ObjectStream<POSSample> samples,
TrainingParameters mlParams, POSTaggerFactory posFactory)
throws IOException {

Map<String, String> manifestInfoEntries = new HashMap<>();
final int beamSize = mlParams.getIntParameter(
BeamSearch.BEAM_SIZE_PARAMETER, POSTaggerME.DEFAULT_BEAM_SIZE);

TrainerType trainerType = TrainerFactory.getTrainerType(trainParams);
final POSContextGenerator contextGenerator = posFactory.getPOSContextGenerator();
final TrainerType trainerType = TrainerFactory.getTrainerType(mlParams);
final Map<String, String> manifestInfoEntries = new HashMap<>();

MaxentModel posModel = null;
SequenceClassificationModel seqPosModel = null;
if (TrainerType.EVENT_MODEL_TRAINER.equals(trainerType)) {
ObjectStream<Event> es = new POSSampleEventStream(samples, contextGenerator);

EventTrainer trainer = TrainerFactory.getEventTrainer(trainParams,
manifestInfoEntries);
EventTrainer trainer = TrainerFactory.getEventTrainer(mlParams, manifestInfoEntries);
posModel = trainer.train(es);
} else if (TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType)) {
POSSampleSequenceStream ss = new POSSampleSequenceStream(samples, contextGenerator);
EventModelSequenceTrainer<POSSample> trainer =
TrainerFactory.getEventModelSequenceTrainer(trainParams, manifestInfoEntries);
TrainerFactory.getEventModelSequenceTrainer(mlParams, manifestInfoEntries);
posModel = trainer.train(ss);
} else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) {
SequenceTrainer trainer = TrainerFactory.getSequenceModelTrainer(
trainParams, manifestInfoEntries);
mlParams, manifestInfoEntries);

// TODO: This will probably cause issue, since the feature generator uses the outcomes array

Expand Down

0 comments on commit 49678c3

Please sign in to comment.