Skip to content

Commit eba946c

Browse files
committed
Multiple speed improvements
* Classification: * Significantly faster growing (two orders of magnitude for sets with >100k instances), faster pruning, * Added approximate mode (`approximate_induction` parameter). Regression: * Mean-based growing set as default (few times faster then median, non-significant impact on accuracy). Survival: * Faster growing and pruning (few fold improvement).
1 parent 9720962 commit eba946c

27 files changed

+744
-417
lines changed

adaa.analytics.rules/build.gradle

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ codeQuality {
2727
}
2828

2929
sourceCompatibility = 1.8
30-
version = '1.6.2'
30+
version = '1.7.0'
3131

3232

3333
jar {

adaa.analytics.rules/src/main/java/adaa/analytics/rules/consoles/ExperimentalConsole.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ private void parse(String[] args) {
150150
RapidMiner.setExecutionMode(RapidMiner.ExecutionMode.COMMAND_LINE);
151151

152152
RapidMiner.init();
153-
//System.in.read();
153+
// System.in.read();
154154
execute(argList.get(0));
155155

156156
} else {

adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/AbstractFinder.java

+41-19
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,9 @@ public void close() {
8181
/**
8282
* Can be implemented by subclasses to perform some initial processing prior growing.
8383
* @param trainSet Training set.
84+
* @return Preprocessed training set.
8485
*/
85-
public void preprocess(ExampleSet trainSet) {}
86+
public ExampleSet preprocess(ExampleSet trainSet) { return trainSet; }
8687

8788
/**
8889
* Adds elementary conditions to the rule premise until termination conditions are fulfilled.
@@ -104,11 +105,14 @@ public int grow(
104105
int initialConditionsCount = rule.getPremise().getSubconditions().size();
105106

106107
// get current covering
107-
Covering covering = new Covering();
108-
rule.covers(dataset, covering, covering.positives, covering.negatives);
109-
Set<Integer> covered = new HashSet<Integer>();
110-
covered.addAll(covering.positives);
111-
covered.addAll(covering.negatives);
108+
ContingencyTable contingencyTable = new Covering();
109+
IntegerBitSet positives = new IntegerBitSet(dataset.size());
110+
IntegerBitSet negatives = new IntegerBitSet(dataset.size());
111+
rule.covers(dataset, contingencyTable, positives, negatives);
112+
//Set<Integer> covered = new HashSet<Integer>();
113+
IntegerBitSet covered = new IntegerBitSet(dataset.size());
114+
covered.addAll(positives);
115+
covered.addAll(negatives);
112116
Set<Attribute> allowedAttributes = new TreeSet<Attribute>(new AttributeComparator());
113117
for (Attribute a: dataset.getAttributes()) {
114118
allowedAttributes.add(a);
@@ -126,18 +130,23 @@ public int grow(
126130

127131
notifyConditionAdded(condition);
128132

129-
covering = new Covering();
130-
rule.covers(dataset, covering, covering.positives, covering.negatives);
131-
covered.clear();
132-
covered.addAll(covering.positives);
133-
covered.addAll(covering.negatives);
133+
//recalculate covering only when needed
134+
if (condition.getCovering() != null) {
135+
positives.retainAll(condition.getCovering());
136+
negatives.retainAll(condition.getCovering());
137+
covered.retainAll(condition.getCovering());
138+
} else {
139+
contingencyTable.clear();
140+
positives.clear();
141+
negatives.clear();
142+
143+
rule.covers(dataset, contingencyTable, positives, negatives);
144+
covered.clear();
145+
covered.addAll(positives);
146+
covered.addAll(negatives);
147+
}
134148

135-
rule.setCoveringInformation(covering);
136-
rule.getCoveredPositives().setAll(covering.positives);
137-
rule.getCoveredNegatives().setAll(covering.negatives);
138149

139-
rule.updateWeightAndPValue(dataset, covering, params.getVotingMeasure());
140-
141150
Logger.log("Condition " + rule.getPremise().getSubconditions().size() + " added: "
142151
+ rule.toString() + ", weight=" + rule.getWeight() + "\n", Level.FINER);
143152

@@ -152,12 +161,25 @@ public int grow(
152161
carryOn = false;
153162
}
154163

155-
} while (carryOn);
156-
164+
} while (carryOn);
165+
166+
// ugly
167+
Covering covering = new Covering();
168+
covering.positives = positives;
169+
covering.negatives = negatives;
170+
171+
rule.setCoveringInformation(covering);
172+
rule.getCoveredPositives().setAll(positives);
173+
rule.getCoveredNegatives().setAll(negatives);
174+
157175
// if rule has been successfully grown
158176
int addedConditionsCount = rule.getPremise().getSubconditions().size() - initialConditionsCount;
159-
rule.setInducedContitionsCount(addedConditionsCount);
160177

178+
if (addedConditionsCount > 0) {
179+
rule.updateWeightAndPValue(dataset, covering, params.getVotingMeasure());
180+
}
181+
182+
rule.setInducedContitionsCount(addedConditionsCount);
161183
notifyGrowingFinished(rule);
162184

163185
return addedConditionsCount;

adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ActionFinder.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ public ActionFinder(ActionInductionParameters params) {
3030
classificationFinder = new ClassificationFinder(params);
3131
}
3232

33-
public void preprocess(ExampleSet trainSet) {
34-
classificationFinder.preprocess(trainSet);
33+
public ExampleSet preprocess(ExampleSet trainSet) {
34+
return classificationFinder.preprocess(trainSet);
3535
}
3636

3737
private void log(String msg, Level level) {

adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ApproximateClassificationFinder.java

+67-29
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@ public ConditionCandidate(String attribute, IValueSet valueSet) {
3232
}
3333
}
3434

35-
protected static final int MAX_BINS = 100;
36-
3735
// Example description:
3836
// [0-31] - example id (32 bits)
3937
// [32-47] - block id (16 bits)
@@ -73,27 +71,36 @@ public ApproximateClassificationFinder(InductionParameters params) {
7371
}
7472

7573
@Override
76-
public void preprocess(ExampleSet dataset) {
74+
public ExampleSet preprocess(ExampleSet dataset) {
7775
int n_examples = dataset.size();
7876
int n_attributes = dataset.getAttributes().size();
7977

8078
trainSet = dataset;
8179
descriptions = new long[n_attributes][n_examples];
8280
mappings = new int[n_attributes][n_examples];
8381

84-
bins_positives = new int[n_attributes][MAX_BINS];
85-
bins_negatives = new int[n_attributes][MAX_BINS];
86-
bins_newPositives = new int[n_attributes][MAX_BINS];
87-
bins_begins = new int[n_attributes][MAX_BINS];
82+
bins_positives = new int[n_attributes][];
83+
bins_negatives = new int[n_attributes][];
84+
bins_newPositives = new int[n_attributes][];
85+
bins_begins = new int[n_attributes][];
8886

8987
ruleRanges = new int[n_attributes][2];
9088

9189
for (Attribute attr: dataset.getAttributes()) {
9290
int ia = attr.getTableIndex();
91+
int n_vals = attr.isNominal() ? attr.getMapping().size() : params.getApproximateBinsCount();
92+
93+
bins_positives[ia] = new int [n_vals];
94+
bins_negatives[ia] = new int[n_vals];
95+
bins_newPositives[ia] = new int[n_vals];
96+
bins_begins[ia] = new int[n_vals];
9397

9498
determineBins(dataset, attr, descriptions[ia], mappings[ia], bins_begins[ia], ruleRanges[ia]);
99+
95100
arrayCopies.put("ruleRanges", (Object)Arrays.stream(ruleRanges).map(int[]::clone).toArray(int[][]::new));
96101
}
102+
103+
return dataset;
97104
}
98105

99106
/**
@@ -293,13 +300,14 @@ protected ElementaryCondition induceCondition(
293300
int covered_n = 0;
294301
int covered_new_p = 0;
295302

296-
// use first attribute to establish number of covered elements
303+
// use first attribute to establish number of covered elements
297304
for (int bid = ruleRanges[0][0]; bid < ruleRanges[0][1]; ++bid) {
298305
covered_p += bins_positives[0][bid];
299306
covered_n += bins_negatives[0][bid];
300307
covered_new_p += bins_newPositives[0][bid];
301308
}
302309

310+
303311
// iterate over all allowed decision attributes
304312
for (Attribute attr : dataset.getAttributes()) {
305313

@@ -462,7 +470,10 @@ class Stats {
462470

463471
if (current != null && current.getAttribute() != null) {
464472
Logger.log("\tAttribute best: " + current + ", quality=" + current.quality, Level.FINEST);
465-
updateMidpoint(dataset, current);
473+
Attribute attr = dataset.getAttributes().get(current.getAttribute());
474+
if (attr.isNumerical()) {
475+
updateMidpoint(dataset, current);
476+
}
466477
Logger.log(", adjusted: " + current + "\n", Level.FINEST);
467478
}
468479

@@ -482,13 +493,13 @@ class Stats {
482493
return null; // empty condition - discard
483494
}
484495

485-
updateMidpoint(dataset, best);
486-
487-
Logger.log("\tFinal best: " + best + ", quality=" + best.quality + "\n", Level.FINEST);
488-
489-
if (bestAttr.isNominal()) {
496+
if (bestAttr.isNumerical()) {
497+
updateMidpoint(dataset, best);
498+
} else {
490499
allowedAttributes.remove(bestAttr);
491500
}
501+
502+
Logger.log("\tFinal best: " + best + ", quality=" + best.quality + "\n", Level.FINEST);
492503
}
493504

494505
return best;
@@ -508,7 +519,7 @@ protected void notifyConditionAdded(ConditionBase cnd) {
508519
ruleRanges[aid][0] = blockId + 1;
509520
ruleRanges[aid][1] = blockId;
510521
} else {
511-
excludeExamplesFromArrays(trainSet, attr, ruleRanges[aid][0], candidate.blockId + 1);
522+
excludeExamplesFromArrays(trainSet, attr, ruleRanges[aid][0], candidate.blockId);
512523
excludeExamplesFromArrays(trainSet, attr, candidate.blockId + 1, ruleRanges[aid][1]);
513524
ruleRanges[aid][0] = blockId;
514525
ruleRanges[aid][1] = blockId + 1;
@@ -546,6 +557,7 @@ protected void determineBins(ExampleSet dataset, Attribute attr,
546557
vals[i] = dataset.getExample(i).getValue(attr);
547558
}
548559

560+
549561
/*
550562
class ValuesComparator implements IntComparator {
551563
double [] vals;
@@ -597,12 +609,12 @@ public int compare(Bin p, Bin q) {
597609
}
598610
}
599611

600-
PriorityQueue<Bin> bins = new PriorityQueue<Bin>(100, new SizeBinComparator());
601-
PriorityQueue<Bin> finalBins = new PriorityQueue<Bin>(100, new IndexBinComparator());
612+
PriorityQueue<Bin> bins = new PriorityQueue<Bin>(binsBegins.length, new SizeBinComparator());
613+
PriorityQueue<Bin> finalBins = new PriorityQueue<Bin>(binsBegins.length, new IndexBinComparator());
602614

603615
bins.add(new Bin(0, mappings.length));
604616

605-
while (bins.size() > 0 && (bins.size() + finalBins.size()) < MAX_BINS) {
617+
while (bins.size() > 0 && (bins.size() + finalBins.size()) < binsBegins.length) {
606618
Bin b = bins.poll();
607619

608620
int id = (b.end + b.begin) / 2;
@@ -611,9 +623,13 @@ public int compare(Bin p, Bin q) {
611623
// decide direction
612624
if (vals[b.begin] == midval) {
613625
// go up
614-
while (vals[id] == midval) { ++id; }
626+
while (vals[id] == midval) {
627+
++id;
628+
}
615629
} else {
616-
while (vals[id - 1] == midval) { --id; }
630+
while (vals[id - 1] == midval) {
631+
--id;
632+
}
617633
}
618634

619635
Bin leftBin = new Bin(b.begin, id);
@@ -646,17 +662,16 @@ public int compare(Bin p, Bin q) {
646662
descriptions[i] |= bid << OFFSET_BIN;
647663
}
648664

649-
binsBegins[(int)bid] = b.begin;
665+
binsBegins[(int) bid] = b.begin;
650666
++bid;
651667
}
652668

653669
ruleRanges[0] = 0;
654-
ruleRanges[1] = (int)bid;
655-
656-
// print bins
657-
for (int i = 0; i < bid; ++i) {
670+
ruleRanges[1] = (int) bid;
671+
// print bins
672+
for (int i = 0; i < ruleRanges[1]; ++i) {
658673
int lo = binsBegins[i];
659-
int hi = (i == bid - 1) ? trainSet.size() : binsBegins[i+1] - 1;
674+
int hi = (i == ruleRanges[1] - 1) ? trainSet.size() : binsBegins[i+1] - 1;
660675
Logger.log("[" + lo + ", " + hi + "]:" + vals[lo] + "\n", Level.FINER);
661676
}
662677
}
@@ -665,6 +680,10 @@ protected void excludeExamplesFromArrays(ExampleSet dataset, Attribute attr, int
665680

666681
Logger.log("Excluding examples: " + attr.getName() + " from [" + binLo + "," + binHi + "]\n", Level.FINER);
667682

683+
if (binLo == binHi) {
684+
return;
685+
}
686+
668687
int n_examples = dataset.size();
669688
int src_row = attr.getTableIndex();
670689
long[] src_descriptions = descriptions[src_row];
@@ -695,9 +714,11 @@ protected void excludeExamplesFromArrays(ExampleSet dataset, Attribute attr, int
695714
int dst_row = other.getTableIndex();
696715

697716
// if nominal attribute was already used
717+
/*
698718
if (other.isNominal() && Math.abs(ruleRanges[dst_row][1] - ruleRanges[dst_row][0]) == 1) {
699719
continue;
700720
}
721+
*/
701722

702723
Future<Object> future = pool.submit(() -> {
703724

@@ -717,8 +738,14 @@ protected void excludeExamplesFromArrays(ExampleSet dataset, Attribute attr, int
717738

718739
int bid = (int) ((desc & MASK_BIN) >> OFFSET_BIN);
719740

741+
boolean opposite = dst_ranges[0] > dst_ranges[1]; // this indicate nominal opposite condition
742+
int dst_bin_lo = Math.min(dst_ranges[0], dst_ranges[1]);
743+
int dst_bin_hi = Math.max(dst_ranges[0], dst_ranges[1]);
744+
720745
// update stats only in bins covered by the rule
721-
if (bid >= dst_ranges[0] && bid < dst_ranges[1] && ((desc & FLAG_COVERED) != 0)) {
746+
boolean in_range = (bid >= dst_bin_lo && bid < dst_bin_hi) || (opposite && (bid < dst_bin_lo || bid >= dst_bin_hi));
747+
748+
if (in_range && ((desc & FLAG_COVERED) != 0)) {
722749

723750
if ((desc & FLAG_POSITIVE) != 0) {
724751
--dst_positives[bid];
@@ -755,12 +782,16 @@ protected void resetArrays(ExampleSet dataset, int targetLabel) {
755782

756783
int n_examples = dataset.size();
757784

785+
int[][] copy_ranges = (int[][])arrayCopies.get("ruleRanges");
786+
758787
for (Attribute attr: dataset.getAttributes()) {
759788
int attribute_id = attr.getTableIndex();
760789

761790
Arrays.fill(bins_positives[attribute_id], 0);
762791
Arrays.fill(bins_negatives[attribute_id], 0);
763792
Arrays.fill(bins_newPositives[attribute_id], 0);
793+
ruleRanges[attribute_id][0] = 0;
794+
ruleRanges[attribute_id][1] = copy_ranges[attribute_id][1];
764795

765796
long[] descriptions_row = descriptions[attribute_id];
766797
int[] mappings_row = mappings[attribute_id];
@@ -792,6 +823,9 @@ protected void resetArrays(ExampleSet dataset, int targetLabel) {
792823
}
793824
}
794825

826+
// reset rule ranges
827+
828+
795829
Logger.log("Reset arrays for class " + targetLabel + "\n", Level.FINER);
796830
printArrays();
797831

@@ -816,9 +850,13 @@ protected void printArrays() {
816850

817851
int bin_p = 0, bin_n = 0, bin_new_p = 0, bin_outside = 0;
818852

819-
for (int i = 0; i < MAX_BINS; ++i) {
853+
boolean opposite = ruleRanges[attribute_id][0] > ruleRanges[attribute_id][1]; // this indicate nominal opposite condition
854+
int lo = Math.min(ruleRanges[attribute_id][0], ruleRanges[attribute_id][1]);
855+
int hi = Math.max(ruleRanges[attribute_id][0], ruleRanges[attribute_id][1]);
856+
857+
for (int i = 0; i < bins_positives[attribute_id].length; ++i) {
820858

821-
if (i >= ruleRanges[attribute_id][0] && i < ruleRanges[attribute_id][1]) {
859+
if ((i >= lo && i < hi) || (opposite && (i < lo || i >= hi)) ) {
822860
bin_p += bins_positives[attribute_id][i];
823861
bin_n += bins_negatives[attribute_id][i];
824862
bin_new_p += bins_newPositives[attribute_id][i];

0 commit comments

Comments
 (0)