Skip to content

Commit a2de3da

Browse files
committed
Several fixes in handling nominal attributes in approximate induction.
1 parent 3e9ded6 commit a2de3da

File tree

3 files changed

+71
-53
lines changed

3 files changed

+71
-53
lines changed

adaa.analytics.rules/build.gradle

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ codeQuality {
2727
}
2828

2929
sourceCompatibility = 1.8
30-
version = '1.7.9'
30+
version = '1.7.10'
3131

3232

3333
jar {

adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ApproximateClassificationFinder.java

+64-46
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,8 @@
22

33
import adaa.analytics.rules.logic.representation.*;
44
import com.rapidminer.example.Attribute;
5-
import com.rapidminer.example.Example;
65
import com.rapidminer.example.ExampleSet;
7-
import org.apache.lucene.search.FieldComparator;
8-
import org.jetbrains.annotations.NotNull;
96

10-
import java.io.IOException;
117
import java.util.*;
128
import java.util.concurrent.ExecutionException;
139
import java.util.concurrent.Future;
@@ -21,7 +17,8 @@ public class ApproximateClassificationFinder extends ClassificationFinder {
2117
static class ConditionCandidate extends ElementaryCondition {
2218

2319
public double quality = -Double.MAX_VALUE;
24-
public double covered = 0;
20+
public double p = 0;
21+
public double n = 0;
2522
public boolean opposite = false;
2623
public int blockId = -1;
2724

@@ -98,8 +95,22 @@ public ExampleSet preprocess(ExampleSet dataset) {
9895
determineBins(dataset, attr, descriptions[ia], mappings[ia], bins_begins[ia], ruleRanges[ia]);
9996

10097
arrayCopies.put("ruleRanges", (Object)Arrays.stream(ruleRanges).map(int[]::clone).toArray(int[][]::new));
98+
99+
if (attr.isNominal()) {
100+
// get orders
101+
Integer[] valuesOrder = new Integer[attr.getMapping().size()];
102+
List<String> labels = new ArrayList<>();
103+
labels.addAll(attr.getMapping().getValues());
104+
Collections.sort(labels);
105+
for (int j = 0; j < labels.size(); ++j) {
106+
valuesOrder[j] = attr.getMapping().getIndex(labels.get(j));
107+
}
108+
attributeValuesOrder.put(attr, valuesOrder);
109+
}
101110
}
102111

112+
113+
103114
return dataset;
104115
}
105116

@@ -274,11 +285,6 @@ protected ElementaryCondition induceCondition(
274285
Set<Attribute> allowedAttributes,
275286
Object... extraParams) {
276287

277-
278-
if (rule.getPremise().getSubconditions().size() == 41) {
279-
//return null;
280-
}
281-
282288
if (allowedAttributes.size() == 0) {
283289
return null;
284290
}
@@ -342,26 +348,26 @@ class Stats {
342348
}
343349
}
344350

345-
int first_bid = ruleRanges[attribute_id][0];
346-
int last_bid = ruleRanges[attribute_id][1];
347-
348-
// omit empty bins from the beginning and from the end
349-
while (first_bid < last_bid && (cur_positives[first_bid] + cur_negatives[first_bid] == 0)) {
350-
++first_bid;
351-
}
352-
353-
while (first_bid < last_bid && (cur_positives[last_bid - 1] + cur_negatives[last_bid - 1] == 0)) {
354-
--last_bid;
355-
}
356-
357351
Stats[] stats = new Stats[2];
358-
stats[0] = new Stats(cur_positives[first_bid], cur_negatives[first_bid], cur_newPositives[first_bid]);
359-
stats[1] = new Stats(finalCovered_p - stats[0].p, finalCovered_n - stats[0].n, finalCovered_new_p - stats[0].p_new);
360352

361353
// numerical attribute
362354
if (attr.isNumerical()) {
363-
// iterate over blocks
355+
int first_bid = ruleRanges[attribute_id][0];
356+
int last_bid = ruleRanges[attribute_id][1];
357+
358+
// omit empty bins from the beginning and from the end
359+
while (first_bid < last_bid && (cur_positives[first_bid] + cur_negatives[first_bid] == 0)) {
360+
++first_bid;
361+
}
364362

363+
while (first_bid < last_bid && (cur_positives[last_bid - 1] + cur_negatives[last_bid - 1] == 0)) {
364+
--last_bid;
365+
}
366+
367+
stats[0] = new Stats(cur_positives[first_bid], cur_negatives[first_bid], cur_newPositives[first_bid]);
368+
stats[1] = new Stats(finalCovered_p - stats[0].p, finalCovered_n - stats[0].n, finalCovered_new_p - stats[0].p_new);
369+
370+
// iterate over blocks
365371
for (int bid = first_bid + 1; bid < last_bid; ++bid) {
366372
// omit conditions:
367373
// - preceding empty bins - they may appear as coverage drops
@@ -380,8 +386,8 @@ class Stats {
380386
if (prec > apriori_prec && stats[c].p_new > 0) {
381387
double quality = params.getInductionMeasure().calculate(stats[c].p, stats[c].n, P, N);
382388

383-
// better then current best
384-
if (quality > best.quality || (quality == best.quality && stats[c].p > best.covered)) {
389+
// better than current best
390+
if (quality > best.quality || (quality == best.quality && stats[c].p > best.p)) {
385391

386392
int left_id = (int) (cur_descriptions[cur_begins[bid] - 1] & MASK_IDENTIFIER);
387393
int right_id = (int) (cur_descriptions[cur_begins[bid]] & MASK_IDENTIFIER);
@@ -397,7 +403,8 @@ class Stats {
397403
//Logger.log("\tCurrent best: " + candidate + " (p=" + stats[c].p + ", n=" + stats[c].n + ", new_p=" + (double) stats[c].p_new + ", quality=" + quality + ")\n", Level.FINEST);
398404
best = candidate;
399405
best.quality = quality;
400-
best.covered = stats[c].p;
406+
best.p = stats[c].p;
407+
best.n = stats[c].n;
401408
best.opposite = (c == 1);
402409
best.blockId = bid;
403410
}
@@ -417,17 +424,36 @@ class Stats {
417424
}
418425
} else { // nominal attribute
419426

420-
for (int bid = 1; bid < cur_positives.length; ++bid) {
427+
// placeholders only: both entries are overwritten at the top of every loop iteration below
428+
stats[0] = new Stats(0, 0, 0);
429+
stats[1] = new Stats(finalCovered_p - stats[0].p, finalCovered_n - stats[0].n, finalCovered_new_p - stats[0].p_new);
430+
431+
for (int j = 0; j < attr.getMapping().size(); ++j) {
432+
int bid = attributeValuesOrder.get(attr)[j];
433+
434+
// update stats
435+
stats[0].p = cur_positives[bid];
436+
stats[0].n = cur_negatives[bid];
437+
stats[0].p_new = cur_newPositives[bid];
438+
439+
stats[1].p = finalCovered_p - stats[0].p;
440+
stats[1].n = finalCovered_n - stats[0].n;
441+
stats[1].p_new = finalCovered_new_p - stats[0].p_new;
442+
421443
// evaluate both conditions
422444
for (int c = 0; c < 2; ++c) {
423445
double prec = stats[c].p / (stats[c].p + stats[c].n);
424446

425447
if (prec > apriori_prec && stats[c].p_new > 0) {
426448
double quality = params.getInductionMeasure().calculate(stats[c].p, stats[c].n, P, N);
427449

428-
// better then current best
429-
if (quality > best.quality || (quality == best.quality && stats[c].p > best.covered)) {
430-
IValueSet interval = (c == 0)
450+
boolean opposite = (c == 1);
451+
452+
// better than current best
453+
if (quality > best.quality || (quality == best.quality && (stats[c].p > best.p ||
454+
(stats[c].p == best.p && best.opposite && !opposite)))) {
455+
456+
IValueSet interval = !opposite
431457
? new SingletonSet((double) bid, attr.getMapping().getValues())
432458
: new SingletonSetComplement((double) bid, attr.getMapping().getValues());
433459

@@ -436,23 +462,14 @@ class Stats {
436462
//Logger.log("\tCurrent best: " + candidate + " (p=" + stats[c].p + ", n=" + stats[c].n + ", new_p=" + (double) stats[c].p_new + ", quality=" + quality + ")\n", Level.FINEST);
437463
best = candidate;
438464
best.quality = quality;
439-
best.covered = stats[c].p;
440-
best.opposite = (c == 1);
465+
best.p = stats[c].p;
466+
best.n = stats[c].n;
467+
best.opposite = opposite;
441468
best.blockId = bid;
442469
}
443470
}
444471
}
445472
}
446-
447-
// update stats
448-
stats[0].p = cur_positives[bid];
449-
stats[0].n = cur_negatives[bid];
450-
stats[0].p_new = cur_newPositives[bid];
451-
452-
stats[1].p = finalCovered_p - stats[0].p;
453-
stats[1].n = finalCovered_n - stats[0].n;
454-
stats[1].p_new = finalCovered_new_p - stats[0].p_new;
455-
456473
}
457474
}
458475

@@ -469,15 +486,16 @@ class Stats {
469486
ConditionCandidate current = (ConditionCandidate)f.get();
470487

471488
if (current != null && current.getAttribute() != null) {
472-
Logger.log("\tAttribute best: " + current + ", quality=" + current.quality, Level.FINEST);
489+
Logger.log("\tAttribute best: " + current + ", quality=" +
490+
current.quality + ", p=" + current.p + ", n=" + current.n, Level.FINEST);
473491
Attribute attr = dataset.getAttributes().get(current.getAttribute());
474492
if (attr.isNumerical()) {
475493
updateMidpoint(dataset, current);
476494
}
477495
Logger.log(", adjusted: " + current + "\n", Level.FINEST);
478496
}
479497

480-
if (best == null || current.quality > best.quality || (current.quality == best.quality && current.covered > best.covered)) {
498+
if (best == null || current.quality > best.quality || (current.quality == best.quality && current.p > best.p)) {
481499
best = current;
482500
}
483501
}

adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ClassificationFinder.java

+6-6
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,14 @@ public class ClassificationFinder extends AbstractFinder {
4040
* Map of precalculated coverings (time optimization).
4141
* For each attribute there is a set of distinct values. For each value there is a bit vector of the examples covered.
4242
*/
43-
protected Map<Attribute, Map<Double, IntegerBitSet>> precalculatedCoverings;
43+
protected Map<Attribute, Map<Double, IntegerBitSet>> precalculatedCoverings
44+
= new HashMap<Attribute, Map<Double, IntegerBitSet>>();
4445

45-
protected Map<Attribute, Map<Double, IntegerBitSet>> precalculatedCoveringsComplement;
46+
protected Map<Attribute, Map<Double, IntegerBitSet>> precalculatedCoveringsComplement
47+
= new HashMap<Attribute, Map<Double, IntegerBitSet>>();
4648

47-
protected Map<Attribute, Integer[]> attributeValuesOrder;
49+
protected Map<Attribute, Integer[]> attributeValuesOrder
50+
= new HashMap<Attribute, Integer[]>();
4851

4952
/**
5053
* Initializes induction parameters.
@@ -68,9 +71,6 @@ public ExampleSet preprocess(ExampleSet trainSet) {
6871
return trainSet;
6972
}
7073

71-
attributeValuesOrder = new HashMap<Attribute, Integer[]>();
72-
precalculatedCoverings = new HashMap<Attribute, Map<Double, IntegerBitSet>>();
73-
precalculatedCoveringsComplement = new HashMap<Attribute, Map<Double, IntegerBitSet>>();
7474
Attributes attributes = trainSet.getAttributes();
7575

7676
List<Future> futures = new ArrayList<Future>();

0 commit comments

Comments
 (0)