diff --git a/src/commons/Parameters.cpp b/src/commons/Parameters.cpp index 34e42981..fe6c7921 100644 --- a/src/commons/Parameters.cpp +++ b/src/commons/Parameters.cpp @@ -229,7 +229,7 @@ Parameters::Parameters(): PARAM_EXTRACT_LINES(PARAM_EXTRACT_LINES_ID, "--extract-lines", "Extract N lines", "Extract n lines of each entry", typeid(int), (void *) &extractLines, "^[1-9]{1}[0-9]*$"), PARAM_COMP_OPERATOR(PARAM_COMP_OPERATOR_ID, "--comparison-operator", "Numerical comparison operator", "Filter by comparing each entry row numerically by using the le) less-than-equal, ge) greater-than-equal or e) equal operator", typeid(std::string), (void *) &compOperator, ""), PARAM_COMP_VALUE(PARAM_COMP_VALUE_ID, "--comparison-value", "Numerical comparison value", "Filter by comparing each entry to this value", typeid(double), (void *) &compValue, "^.*$"), - PARAM_SORT_ENTRIES(PARAM_SORT_ENTRIES_ID, "--sort-entries", "Sort entries", "Sort column set by --filter-column, by 0: no sorting, 1: increasing, 2: decreasing, 3: random shuffle", typeid(int), (void *) &sortEntries, "^[1-9]{1}[0-9]*$"), + PARAM_SORT_ENTRIES(PARAM_SORT_ENTRIES_ID, "--sort-entries", "Sort entries", "Sort column set by --filter-column, by 0: no sorting, 1: increasing, 2: decreasing, 3: random shuffle, 4: priority", typeid(int), (void *) &sortEntries, "^[0-4]{1}$"), PARAM_BEATS_FIRST(PARAM_BEATS_FIRST_ID, "--beats-first", "Beats first", "Filter by comparing each entry to the first entry", typeid(bool), (void *) &beatsFirst, ""), PARAM_JOIN_DB(PARAM_JOIN_DB_ID, "--join-db", "join to DB", "Join another database entry with respect to the database identifier in the chosen column", typeid(std::string), (void *) &joinDB, ""), // besthitperset @@ -866,6 +866,7 @@ Parameters::Parameters(): filterDb.push_back(&PARAM_FILTER_FILE); filterDb.push_back(&PARAM_BEATS_FIRST); filterDb.push_back(&PARAM_MAPPING_FILE); + filterDb.push_back(&PARAM_WEIGHT_FILE); filterDb.push_back(&PARAM_TRIM_TO_ONE_COL); filterDb.push_back(&PARAM_EXTRACT_LINES); filterDb.push_back(&PARAM_COMP_OPERATOR); diff --git a/src/util/filterdb.cpp b/src/util/filterdb.cpp index 3ab883d6..b645d2db 100644 --- a/src/util/filterdb.cpp +++ b/src/util/filterdb.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include @@ -51,6 +52,8 @@ ComparisonOperator mapOperator(const std::string& op) { #define INCREASING 1 #define DECREASING 2 #define SHUFFLE 3 +#define PRIORITY 4 + struct compareString { bool operator() (const std::string& lhs, const std::string& rhs) const{ @@ -84,6 +87,8 @@ struct compareFirstEntryDecreasing { int filterdb(int argc, const char **argv, const Command &command) { Parameters &par = Parameters::getInstance(); + par.PARAM_WEIGHT_FILE.replaceCategory(MMseqsParameter::COMMAND_MISC); + par.parseParameters(argc, argv, command, true, 0, 0); const size_t column = static_cast(par.filterColumn); @@ -108,7 +113,7 @@ int filterdb(int argc, const char **argv, const Command &command) { // JOIN_DB DBReader* helper = NULL; - + std::unordered_map weights; // REGEX_FILTERING regex_t regex; std::random_device rng; @@ -117,6 +122,32 @@ int filterdb(int argc, const char **argv, const Command &command) { if (par.sortEntries != 0) { mode = SORT_ENTRIES; Debug(Debug::INFO) << "Filtering by sorting entries\n"; + if (par.sortEntries == PRIORITY) { + if (par.weightFile.empty()) { + Debug(Debug::ERROR) << "Weights file (--weights) must be specified for priority sorting.\n"; + EXIT(EXIT_FAILURE); + } + Debug(Debug::INFO) << "Sorting entries by priority\n"; + // Read the weights + std::ifstream weightsFile(par.weightFile); + if (!weightsFile) { + Debug(Debug::ERROR) << "Cannot open weights file " << par.weightFile << "\n"; + EXIT(EXIT_FAILURE); + } + + std::string line; + while (std::getline(weightsFile, line)) { + std::istringstream iss(line); + unsigned int key; + float weight; + if (!(iss >> key >> weight)) { + Debug(Debug::WARNING) << "Invalid line in weights file: " << line << "\n"; + continue; + } + weights[key] = weight; + } + weightsFile.close(); + } } else if (par.filteringFile.empty() == false) { mode = FILE_FILTERING; Debug(Debug::INFO) << "Filtering using file(s)\n"; @@ -453,8 +484,19 @@ int filterdb(int argc, const char **argv, const Command &command) { memcpy(lineBuffer, newLineBuffer, newLineBufferIndex + 1); } } else if (mode == SORT_ENTRIES) { - toSort.emplace_back(std::strtod(columnValue, NULL), lineBuffer); - // do not put anything in the output buffer + if (par.sortEntries == PRIORITY) { + unsigned int key = static_cast(strtoul(columnPointer[column - 1], NULL, 10)); + float weight = 0.0f; + auto it = weights.find(key); + if (it != weights.end()) { + weight = it->second; + } + toSort.emplace_back(weight, std::string(lineBuffer)); + } else { + // Existing code + toSort.emplace_back(std::strtod(columnValue, NULL), lineBuffer); + } + // Do not put anything in the output buffer nomatch = 1; } else { // Unknown filtering mode, keep all entries @@ -482,7 +524,7 @@ int filterdb(int argc, const char **argv, const Command &command) { if (mode == SORT_ENTRIES) { if (par.sortEntries == INCREASING) { std::stable_sort(toSort.begin(), toSort.end(), compareFirstEntry()); - } else if (par.sortEntries == DECREASING) { + } else if (par.sortEntries == DECREASING || par.sortEntries == PRIORITY) { std::stable_sort(toSort.begin(), toSort.end(), compareFirstEntryDecreasing()); } else if (par.sortEntries == SHUFFLE) { std::shuffle(toSort.begin(), toSort.end(), urng);