Skip to content

Commit 4ec5e59

Browse files
committed
Survival contrast sets running without crash.
1 parent dfe0a91 commit 4ec5e59

12 files changed

+137
-21
lines changed

adaa.analytics.rules/build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ codeQuality {
2727
}
2828

2929
sourceCompatibility = 1.8
30-
version = '1.7.15'
30+
version = '1.7.16'
3131

3232

3333
jar {

adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ContrastClassificationFinder.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@
22

33
import adaa.analytics.rules.logic.quality.IQualityMeasure;
44
import adaa.analytics.rules.logic.quality.NegativeControlledMeasure;
5-
import adaa.analytics.rules.logic.representation.ContrastRule;
6-
import adaa.analytics.rules.logic.representation.MultiSet;
7-
import adaa.analytics.rules.logic.representation.Rule;
8-
import adaa.analytics.rules.logic.representation.SingletonSet;
5+
import adaa.analytics.rules.logic.representation.*;
96
import com.rapidminer.example.Attribute;
107
import com.rapidminer.example.ExampleSet;
8+
import com.rapidminer.example.set.SortedExampleSet;
119
import com.rapidminer.example.table.NominalMapping;
1210

1311
import java.util.Map;
@@ -34,6 +32,11 @@ public ContrastClassificationFinder(InductionParameters params) {
3432
params.setPruningMeasure(new NegativeControlledMeasure(params.getPruningMeasure(), params.getMaxcovNegative()));
3533
}
3634

35+
public ExampleSet preprocess(ExampleSet trainSet) {
36+
super.preprocess(trainSet);
37+
return trainSet; // return original one
38+
}
39+
3740

3841
/**
3942
* Invokes grow method from the super class and verifies negative coverage requirement.

adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ContrastRegressionFinder.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ public String getName() {
3232
@Override
3333
public double calculate(ExampleSet dataset, ContingencyTable ct) {
3434

35-
ContrastRegressionExampleSet cer = (dataset instanceof ContrastExampleSet) ? (ContrastRegressionExampleSet)dataset : null;
35+
ContrastRegressionExampleSet cer = (dataset instanceof ContrastRegressionExampleSet) ? (ContrastRegressionExampleSet)dataset : null;
3636
if (cer == null) {
37-
throw new InvalidParameterException("ContrastSurvivalRuleSet supports only ContrastRegressionExampleSet instances");
37+
throw new InvalidParameterException("ContrastRegressionRuleSet supports only ContrastRegressionExampleSet instances");
3838
}
3939

4040
Covering cov = (Covering)ct;
@@ -78,6 +78,11 @@ public ContrastRegressionFinder(InductionParameters params) {
7878
params.setMeanBasedRegression(false);
7979
}
8080

81+
public ExampleSet preprocess(ExampleSet trainSet) {
82+
super.preprocess(trainSet);
83+
return trainSet; // return original one
84+
}
85+
8186
/**
8287
* Invokes grow method from the super class and verifies negative coverage requirement.
8388
*

adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ContrastSnC.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public ContrastSnC(AbstractFinder finder, InductionParameters params) {
4040
public RuleSetBase run(ExampleSet dataset) {
4141

4242
// make a contrast dataset
43-
ContrastExampleSet ces;
43+
IContrastExampleSet ces;
4444

4545
if (factory.getType() == RuleFactory.CONTRAST_REGRESSION) {
4646
ces = new ContrastRegressionExampleSet((SimpleExampleSet) dataset);
@@ -84,7 +84,7 @@ public RuleSetBase run(ExampleSet dataset) {
8484
* @param dataset Training data set.
8585
* @return Rule set.
8686
*/
87-
protected void run(ContrastExampleSet dataset, ContrastRuleSet crs) {
87+
protected void run(IContrastExampleSet dataset, ContrastRuleSet crs) {
8888
Logger.log("ContrastSnC.run()\n", Level.FINE);
8989

9090
// try to get contrast attribute (use label if not specified)

adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ContrastSurvivalFinder.java

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import java.security.InvalidParameterException;
1212
import java.util.HashSet;
1313
import java.util.Set;
14+
import java.util.logging.Level;
1415

1516
public class ContrastSurvivalFinder extends SurvivalLogRankFinder implements IPenalizedFinder {
1617

@@ -27,7 +28,7 @@ public String getName() {
2728
@Override
2829
public double calculate(ExampleSet dataset, ContingencyTable ct) {
2930

30-
ContrastSurvivalExampleSet ces = (dataset instanceof ContrastExampleSet) ? (ContrastSurvivalExampleSet)dataset : null;
31+
ContrastSurvivalExampleSet ces = (dataset instanceof ContrastSurvivalExampleSet) ? (ContrastSurvivalExampleSet)dataset : null;
3132
if (ces == null) {
3233
throw new InvalidParameterException("ContrastSurvivalRuleSet supports only ContrastSurvivalExampleSet instances");
3334
}
@@ -71,6 +72,11 @@ public ContrastSurvivalFinder(InductionParameters params) {
7172
params.setVotingMeasure(m);
7273
}
7374

75+
public ExampleSet preprocess(ExampleSet trainSet) {
76+
super.preprocess(trainSet);
77+
return trainSet; // return original one
78+
}
79+
7480
/**
7581
* Invokes grow method from the super class and verifies negative coverage requirement.
7682
*
@@ -123,6 +129,79 @@ public void postprocess(
123129
notifyRuleReady(rule);
124130
}
125131

132+
protected boolean checkCandidate(
133+
ExampleSet dataset,
134+
Rule rule,
135+
ConditionBase candidate,
136+
Set<Integer> uncovered,
137+
Set<Integer> covered,
138+
ConditionEvaluation currentBest) {
139+
140+
try {
141+
142+
CompoundCondition newPremise = new CompoundCondition();
143+
newPremise.getSubconditions().addAll(rule.getPremise().getSubconditions());
144+
newPremise.addSubcondition(candidate);
145+
146+
Rule newRule = (Rule) rule.clone();
147+
newRule.setPremise(newPremise);
148+
149+
150+
Covering cov = new Covering();
151+
newRule.covers(dataset, cov, cov.positives, cov.negatives);
152+
153+
double new_p = 0, new_n = 0;
154+
155+
if (dataset.getAttributes().getWeight() == null) {
156+
// unweighted examples
157+
new_p = SetHelper.intersectionSize(uncovered, cov.positives);
158+
new_n = SetHelper.intersectionSize(uncovered, cov.negatives);
159+
} else {
160+
// calculate weights of newly covered examples
161+
for (int id : cov.positives) {
162+
new_p += uncovered.contains(id) ? dataset.getExample(id).getWeight() : 0;
163+
}
164+
for (int id : cov.negatives) {
165+
new_n += uncovered.contains(id) ? dataset.getExample(id).getWeight() : 0;
166+
}
167+
}
168+
169+
if (checkCoverage(cov.weighted_p, cov.weighted_n, new_p, new_n, dataset.size(), 0, uncovered.size(), rule.getRuleOrderNum())) {
170+
171+
double quality = params.getInductionMeasure().calculate(dataset, cov);
172+
173+
if (candidate instanceof ElementaryCondition) {
174+
ElementaryCondition ec = (ElementaryCondition) candidate;
175+
quality = modifier.modifyQuality(quality, ec.getAttribute(), cov.weighted_p, new_p);
176+
}
177+
178+
if (quality > currentBest.quality ||
179+
(quality == currentBest.quality && (new_p > currentBest.covered || currentBest.opposite))) {
180+
181+
Logger.log("\t\tCurrent best: " + candidate + " (p=" + cov.weighted_p +
182+
", new_p=" + (double) new_p +
183+
", P=" + cov.weighted_P +
184+
", mean_y=" + cov.mean_y + ", mean_y2=" + cov.mean_y2 + ", stddev_y=" + cov.stddev_y +
185+
", quality=" + quality + "\n", Level.FINEST);
186+
187+
currentBest.quality = quality;
188+
currentBest.condition = candidate;
189+
currentBest.covered = new_p;
190+
currentBest.covering = cov;
191+
currentBest.opposite = (candidate instanceof ElementaryCondition) &&
192+
(((ElementaryCondition) candidate).getValueSet() instanceof SingletonSetComplement);
193+
194+
//rule.setWeight(quality);
195+
return true;
196+
}
197+
}
198+
199+
} catch (Exception e) {
200+
e.printStackTrace();
201+
}
202+
return false;
203+
}
204+
126205

127206
boolean checkCoverage(double p, double n, double new_p, double new_n, double P, double N) {
128207
return ((new_p) >= params.getAbsoluteMinimumCovered(P)) &&

adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/ContrastExampleSet.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import java.util.ArrayList;
1212
import java.util.List;
1313

14-
public class ContrastExampleSet extends SimpleExampleSet {
14+
public class ContrastExampleSet extends SimpleExampleSet implements IContrastExampleSet {
1515

1616
protected Attribute contrastAttribute;
1717

adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/ContrastRegressionExampleSet.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,30 @@
1212
import java.util.ArrayList;
1313
import java.util.List;
1414

15-
public class ContrastRegressionExampleSet extends ContrastExampleSet {
15+
public class ContrastRegressionExampleSet extends SortedExampleSetEx implements IContrastExampleSet {
16+
17+
protected Attribute contrastAttribute;
1618

1719
/** Training set estimator. */
1820
protected double trainingEstimator;
1921

2022
/** Collection of Kaplan-Meier estimators for contrast groups. */
2123
protected List<Double> groupEstimators = new ArrayList<Double>();
2224

25+
public Attribute getContrastAttribute() { return contrastAttribute; }
26+
2327
/** Gets {@link #groupEstimators} */
2428
public List<Double> getGroupEstimators() { return groupEstimators; }
2529

2630
/** Gets {@link #trainingEstimator}}. */
2731
public double getTrainingEstimator() { return trainingEstimator; }
2832

2933
public ContrastRegressionExampleSet(SimpleExampleSet exampleSet) {
30-
super(exampleSet);
34+
super(exampleSet, exampleSet.getAttributes().getLabel(), SortedExampleSetEx.INCREASING);
35+
36+
contrastAttribute = (exampleSet.getAttributes().getSpecial(ContrastRule.CONTRAST_ATTRIBUTE_ROLE) == null)
37+
? exampleSet.getAttributes().getLabel()
38+
: exampleSet.getAttributes().getSpecial(ContrastRule.CONTRAST_ATTRIBUTE_ROLE);
3139

3240
String averageName = (exampleSet.getAttributes().getWeight() != null)
3341
? Statistics.AVERAGE_WEIGHTED : Statistics.AVERAGE;
@@ -57,6 +65,7 @@ public ContrastRegressionExampleSet(SimpleExampleSet exampleSet) {
5765

5866
public ContrastRegressionExampleSet(ContrastRegressionExampleSet rhs) {
5967
super(rhs);
68+
this.contrastAttribute = rhs.contrastAttribute;
6069
this.trainingEstimator = rhs.trainingEstimator;
6170
this.groupEstimators = rhs.groupEstimators;
6271
}

adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/ContrastRegressionRuleSet.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public class ContrastRegressionRuleSet extends ContrastRuleSet {
3939
public ContrastRegressionRuleSet(ExampleSet exampleSet, boolean isVoting, InductionParameters params, Knowledge knowledge) {
4040
super(exampleSet, isVoting, params, knowledge);
4141

42-
ContrastRegressionExampleSet cer = (exampleSet instanceof ContrastExampleSet) ? (ContrastRegressionExampleSet)exampleSet : null;
42+
ContrastRegressionExampleSet cer = (exampleSet instanceof ContrastRegressionExampleSet) ? (ContrastRegressionExampleSet)exampleSet : null;
4343
if (cer == null) {
4444
throw new InvalidParameterException("ContrastRegressionRuleSet supports only ContrastRegressionExampleSet instances");
4545
}

adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/ContrastSurvivalExampleSet.java

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,46 @@
11
package adaa.analytics.rules.logic.representation;
22

3+
import com.rapidminer.example.Attribute;
34
import com.rapidminer.example.ExampleSet;
45
import com.rapidminer.example.set.AttributeValueFilterSingleCondition;
56
import com.rapidminer.example.set.ConditionedExampleSet;
67
import com.rapidminer.example.set.SimpleExampleSet;
8+
import com.rapidminer.example.set.SortedExampleSet;
79
import com.rapidminer.example.table.NominalMapping;
810
import com.rapidminer.operator.tools.ExpressionEvaluationException;
911

1012
import java.util.ArrayList;
1113
import java.util.List;
1214

13-
public class ContrastSurvivalExampleSet extends ContrastExampleSet {
15+
public class ContrastSurvivalExampleSet extends SortedExampleSetEx implements IContrastExampleSet {
16+
17+
protected Attribute contrastAttribute;
1418

1519
/** Training set estimator. */
1620
protected KaplanMeierEstimator trainingEstimator;
1721

1822
/** Collection of Kaplan-Meier estimators for contrast groups. */
1923
protected List<KaplanMeierEstimator> groupEstimators = new ArrayList<KaplanMeierEstimator>();
2024

25+
public Attribute getContrastAttribute() { return contrastAttribute; }
26+
2127
/** Gets {@link #groupEstimators} */
2228
public List<KaplanMeierEstimator> getGroupEstimators() { return groupEstimators; }
2329

2430
/** Gets {@link #trainingEstimator}}. */
2531
public KaplanMeierEstimator getTrainingEstimator() { return trainingEstimator; }
2632

2733
public ContrastSurvivalExampleSet(SimpleExampleSet exampleSet) {
28-
super(exampleSet);
34+
super(exampleSet, exampleSet.getAttributes().getSpecial(SurvivalRule.SURVIVAL_TIME_ROLE), SortedExampleSetEx.INCREASING);
35+
36+
contrastAttribute = (exampleSet.getAttributes().getSpecial(ContrastRule.CONTRAST_ATTRIBUTE_ROLE) == null)
37+
? exampleSet.getAttributes().getLabel()
38+
: exampleSet.getAttributes().getSpecial(ContrastRule.CONTRAST_ATTRIBUTE_ROLE);
39+
40+
Attribute survTime = exampleSet.getAttributes().getSpecial(SurvivalRule.SURVIVAL_TIME_ROLE);
2941

3042
// establish training survival estimator
31-
trainingEstimator = new KaplanMeierEstimator(exampleSet);
43+
trainingEstimator = new KaplanMeierEstimator(this);
3244

3345
// establish contrast groups survival estimator
3446
try {
@@ -39,7 +51,9 @@ public ContrastSurvivalExampleSet(SimpleExampleSet exampleSet) {
3951
contrastAttribute, AttributeValueFilterSingleCondition.EQUALS, mapping.mapIndex(i));
4052

4153
ExampleSet conditionedSet = new ConditionedExampleSet(exampleSet, cnd);
42-
groupEstimators.add(new KaplanMeierEstimator(conditionedSet));
54+
SortedExampleSetEx cses = new SortedExampleSetEx(conditionedSet, survTime, SortedExampleSet.INCREASING);
55+
56+
groupEstimators.add(new KaplanMeierEstimator(cses));
4357
}
4458

4559
} catch (ExpressionEvaluationException e) {
@@ -49,6 +63,7 @@ public ContrastSurvivalExampleSet(SimpleExampleSet exampleSet) {
4963

5064
public ContrastSurvivalExampleSet(ContrastSurvivalExampleSet rhs) {
5165
super(rhs);
66+
this.contrastAttribute = rhs.contrastAttribute;
5267
this.trainingEstimator = rhs.trainingEstimator;
5368
this.groupEstimators = rhs.groupEstimators;
5469
}

adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/ContrastSurvivalRuleSet.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ public class ContrastSurvivalRuleSet extends ContrastRuleSet {
4242
public ContrastSurvivalRuleSet(ExampleSet exampleSet, boolean isVoting, InductionParameters params, Knowledge knowledge) {
4343
super(exampleSet, isVoting, params, knowledge);
4444

45-
ContrastSurvivalExampleSet ces = (exampleSet instanceof ContrastExampleSet) ? (ContrastSurvivalExampleSet)exampleSet : null;
45+
ContrastSurvivalExampleSet ces = (exampleSet instanceof ContrastSurvivalExampleSet) ? (ContrastSurvivalExampleSet)exampleSet : null;
4646
if (ces == null) {
47-
throw new InvalidParameterException("ContrastSurvivalRuleSet supports only ContrastExampleSet instances");
47+
throw new InvalidParameterException("ContrastSurvivalRuleSet supports only ContrastSurvivalExampleSet instances");
4848
}
4949

5050
trainingEstimator = ces.getTrainingEstimator();

0 commit comments

Comments
 (0)