Skip to content

Commit

Permalink
Reset timepoint and buffer in finalize
Browse files Browse the repository at this point in the history
  • Loading branch information
kleag committed May 12, 2024
1 parent c45653d commit a7786ff
Showing 1 changed file with 21 additions and 15 deletions.
36 changes: 21 additions & 15 deletions deeplima/include/deeplima/token_sequence_analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ class TokenSequenceAnalyzer : public ITokenSequenceAnalyzer

inline void reset(size_t position = 0)
{
std::cerr << "TokenSequenceAnalyzer::reset" << std::endl;
m_current = position;
}

Expand Down Expand Up @@ -234,19 +235,21 @@ class TokenSequenceAnalyzer : public ITokenSequenceAnalyzer
// enriched_token_buffer_t,
// typename enriched_token_buffer_t::token_t> FeaturesVectorizer;

typedef tagging::impl::EntityTaggingClassifier<TaggingAuxScalar> Classifier;
// typedef tagging::impl::EntityTaggingClassifier<TaggingAuxScalar> Classifier;

typedef tagging::impl::FeaturesVectorizerWithPrecomputing<
Classifier,
enriched_token_buffer_t,
typename enriched_token_buffer_t::token_t> FeaturesVectorizer;
// typedef tagging::impl::FeaturesVectorizerWithPrecomputing<
// Classifier,
// enriched_token_buffer_t,
// typename enriched_token_buffer_t::token_t> FeaturesVectorizer;

// typedef tagging::impl::TaggingImpl< Classifier,
// FeaturesVectorizer,
// Matrix > EntityTaggingModule;

typedef DictEmbdVectorizer<EmbdUInt64FloatHolder, EmbdUInt64Float, eigen_wrp::EigenMatrixXf> EmbdVectorizer;
// typedef lemmatization::impl::LemmatizationImpl< RnnSeq2Seq, EmbdVectorizer, Matrix> LemmatizationModule;
// typedef DictEmbdVectorizer<EmbdUInt64FloatHolder, EmbdUInt64Float,
// eigen_wrp::EigenMatrixXf> EmbdVectorizer;
// typedef lemmatization::impl::LemmatizationImpl< RnnSeq2Seq, EmbdVectorizer,
// Matrix> LemmatizationModule;

public:
TokenSequenceAnalyzer() :
Expand Down Expand Up @@ -275,11 +278,11 @@ class TokenSequenceAnalyzer : public ITokenSequenceAnalyzer
m_stridx(*m_stridx_ptr),
m_classes(std::make_shared<StdMatrix<uint8_t>>())
{
// std::cerr << "TokenSequenceAnalyzer::TokenSequenceAnalyzer " << model_fn << ", "
// << lemm_model_fn << ", " << lemm_dict_fn << ", "
// << fixed_ini_fn << ", " << lower_ini_fn << ", "
// << fixed_lemm_fn
// << std::endl;
std::cerr << "TokenSequenceAnalyzer::TokenSequenceAnalyzer " << model_fn << ", "
<< lemm_model_fn << ", " << lemm_dict_fn << ", "
<< fixed_ini_fn << ", " << lower_ini_fn << ", "
<< fixed_lemm_fn
<< std::endl;
assert(m_buffer_size > 0);
assert(num_buffers > 0);
m_buffers.resize(num_buffers);
Expand Down Expand Up @@ -321,7 +324,7 @@ class TokenSequenceAnalyzer : public ITokenSequenceAnalyzer
m_cls.register_handler([this](
std::shared_ptr< StdMatrix<uint8_t> > classes,
size_t begin, size_t end, size_t slot_idx){
// std::cerr << "handler called: " << slot_idx << std::endl;
std::cerr << "handler called: " << slot_idx << std::endl;

lemmatize(m_buffers[slot_idx], m_lemm_buffers[slot_idx], classes, begin, end);

Expand All @@ -341,7 +344,7 @@ class TokenSequenceAnalyzer : public ITokenSequenceAnalyzer
std::shared_ptr< StdMatrix<uint8_t> > classes,
size_t begin, size_t end, size_t slot_idx)
{
// std::cerr << "handler called: " << slot_idx << std::endl;
std::cerr << "handler called: " << slot_idx << std::endl;
m_classes = classes;
m_output_callback(m_stridx_ptr,
m_buffers[slot_idx],
Expand Down Expand Up @@ -425,6 +428,8 @@ class TokenSequenceAnalyzer : public ITokenSequenceAnalyzer
}

m_cls.send_all_results();
m_current_timepoint = 0;
m_current_buffer = 0;
}

virtual void operator()(const std::vector<deeplima::segmentation::token_pos>& tokens, uint32_t len) override
Expand Down Expand Up @@ -463,6 +468,7 @@ class TokenSequenceAnalyzer : public ITokenSequenceAnalyzer

void acquire_buffer()
{
std::cerr << "acquire_buffer" << std::endl;
size_t next_buffer_idx = (m_current_buffer + 1 < m_buffers.size()) ? (m_current_buffer + 1) : 0;
const token_buffer_t<>& next_buffer = m_buffers[next_buffer_idx];

Expand All @@ -479,7 +485,7 @@ class TokenSequenceAnalyzer : public ITokenSequenceAnalyzer

void start_analysis(size_t buffer_idx, int count = -1)
{
// std::cerr << "TokenSequenceAnalyzer::start_analysis " << buffer_idx << ", " << count << std::endl;
std::cerr << "TokenSequenceAnalyzer::start_analysis " << buffer_idx << ", " << count << std::endl;
assert(!m_buffers[buffer_idx].locked());
m_buffers[buffer_idx].lock();

Expand Down

0 comments on commit a7786ff

Please sign in to comment.