Skip to content

Commit acd799e

Browse files
committed
RUL-107: Bugfix in parsing expert survival rules when survival status attribute is nominal.
1 parent c2a3b1d commit acd799e

File tree

7 files changed

+94
-49
lines changed

7 files changed

+94
-49
lines changed

adaa.analytics.rules/build.gradle

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ plugins {
55
id 'java'
66
}
77

8-
version = '2.1.20'
8+
version = '2.1.21'
99
java {
1010
sourceCompatibility = JavaVersion.VERSION_1_8
1111
}

adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/Knowledge.java

+9-10
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919
import adaa.analytics.rules.logic.representation.condition.ConditionBase;
2020
import adaa.analytics.rules.logic.representation.condition.ElementaryCondition;
2121
import adaa.analytics.rules.logic.representation.rule.Rule;
22+
import adaa.analytics.rules.logic.representation.rule.SurvivalRule;
2223
import adaa.analytics.rules.logic.representation.valueset.IValueSet;
2324
import adaa.analytics.rules.logic.representation.valueset.SingletonSet;
25+
import adaa.analytics.rules.logic.representation.valueset.UndefinedSet;
2426
import adaa.analytics.rules.logic.representation.valueset.Universum;
2527

2628
import java.io.Serializable;
@@ -60,10 +62,7 @@ public class Knowledge implements Serializable {
6062

6163
/** Maximum number of preferred attributes per rule. */
6264
protected int preferredAttributesPerRule;
63-
64-
/** Auxiliary files indicating whether the knowledge concerns regression problem. */
65-
protected boolean isRegression;
66-
65+
6766
/** Auxiliary files indicating number classes (classification problems only). */
6867
protected int numClasses;
6968

@@ -166,7 +165,7 @@ public void setPreferredAttributesPerRule(int v) {
166165
*/
167166
public Knowledge(IExampleSet dataset, MultiSet<Rule> rules, MultiSet<Rule> preferredConditions, MultiSet<Rule> forbiddenConditions) {
168167

169-
this.isRegression = dataset.getAttributes().getLabel().isNumerical();
168+
boolean isSurvival = dataset.getAttributes().getColumnByRole(SurvivalRule.SURVIVAL_TIME_ROLE) != null;
170169

171170
this.extendUsingPreferred = false;
172171
this.extendUsingAutomatic = false;
@@ -176,7 +175,7 @@ public Knowledge(IExampleSet dataset, MultiSet<Rule> rules, MultiSet<Rule> prefe
176175
this.preferredConditionsPerRule = Integer.MAX_VALUE;
177176
this.preferredAttributesPerRule = Integer.MAX_VALUE;
178177

179-
int numClasses = (dataset.getAttributes().getLabel().isNominal())
178+
this.numClasses = (dataset.getAttributes().getLabel().isNominal() && !isSurvival)
180179
? dataset.getAttributes().getLabel().getMapping().size() : 1;
181180

182181
for (int i = 0; i < numClasses; ++i) {
@@ -188,14 +187,14 @@ public Knowledge(IExampleSet dataset, MultiSet<Rule> rules, MultiSet<Rule> prefe
188187
}
189188

190189
for (Rule r : rules) {
191-
SingletonSet set = (SingletonSet)r.getConsequence().getValueSet();
192-
int c = (int)set.getValue();
190+
SingletonSet set = (SingletonSet)r.getConsequence().getValueSet();
191+
int c = (set instanceof UndefinedSet) ? 0 : (int)set.getValue();
193192
this.rules.get(c).add(r);
194193
}
195194

196195
for (Rule r : preferredConditions) {
197196
SingletonSet set = (SingletonSet)r.getConsequence().getValueSet();
198-
int c = (int)set.getValue();
197+
int c = (set instanceof UndefinedSet) ? 0 : (int)set.getValue();
199198

200199
ElementaryCondition ec = (ElementaryCondition) r.getPremise().getSubconditions().get(0);
201200
if (ec.getValueSet() instanceof Universum) {
@@ -207,7 +206,7 @@ public Knowledge(IExampleSet dataset, MultiSet<Rule> rules, MultiSet<Rule> prefe
207206

208207
for (Rule r : forbiddenConditions) {
209208
SingletonSet set = (SingletonSet)r.getConsequence().getValueSet();
210-
int c = (int)set.getValue();
209+
int c = (set instanceof UndefinedSet) ? 0 : (int)set.getValue();
211210

212211
ElementaryCondition ec = (ElementaryCondition) r.getPremise().getSubconditions().get(0);
213212
if (ec.getValueSet() instanceof Universum) {

adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/RuleParser.java

+39-29
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,12 @@
2323
import adaa.analytics.rules.logic.representation.rule.RegressionRule;
2424
import adaa.analytics.rules.logic.representation.rule.Rule;
2525
import adaa.analytics.rules.logic.representation.rule.SurvivalRule;
26-
import adaa.analytics.rules.logic.representation.valueset.IValueSet;
27-
import adaa.analytics.rules.logic.representation.valueset.Interval;
28-
import adaa.analytics.rules.logic.representation.valueset.SingletonSet;
29-
import adaa.analytics.rules.logic.representation.valueset.Universum;
26+
import adaa.analytics.rules.logic.representation.valueset.*;
3027
import adaa.analytics.rules.utils.Logger;
3128
import org.apache.commons.lang3.math.NumberUtils;
3229

3330
import java.util.ArrayList;
31+
import java.util.Collections;
3432
import java.util.List;
3533
import java.util.logging.Level;
3634
import java.util.regex.Matcher;
@@ -57,10 +55,7 @@ public static Rule parseRule(String s, IAttributes meta) {
5755
Pattern pattern = Pattern.compile("IF\\s+(?<premise>.+)\\s+THEN(?<consequence>\\s+.*|\\s*)");
5856
Matcher matcher = pattern.matcher(s);
5957

60-
boolean isSurvival = false;
61-
if (meta.getColumnByRoleUnsafe(SurvivalRule.SURVIVAL_TIME_ROLE) != null) {
62-
isSurvival = true;
63-
}
58+
boolean isSurvival = (meta.getColumnByRoleUnsafe(SurvivalRule.SURVIVAL_TIME_ROLE) != null);
6459

6560
if (matcher.find()) {
6661
String pre = matcher.group("premise");
@@ -70,24 +65,28 @@ public static Rule parseRule(String s, IAttributes meta) {
7065
CompoundCondition premise = parseCompoundCondition(pre, meta);
7166

7267
if (con == null || con.trim().length() == 0) {
73-
if (!meta.getLabelUnsafe().isNumerical()) {
74-
Logger.log("Empty conclusion for nominal label"+ "\n", Level.WARNING);
75-
} else {
76-
consequence = new ElementaryCondition(meta.getLabelUnsafe().getName(), new SingletonSet(NaN, null));
68+
if (isSurvival) {
69+
consequence = new ElementaryCondition(meta.getLabelUnsafe().getName(), new UndefinedSet());
70+
} else if (meta.getLabelUnsafe().isNumerical()) {
71+
consequence = new ElementaryCondition(meta.getLabelUnsafe().getName(), new SingletonSet(NaN, null));
7772
consequence.setAdjustable(false);
78-
consequence.setDisabled( false);
73+
consequence.setDisabled(false);
74+
} else{
75+
Logger.log("Empty conclusion for nominal label"+ "\n", Level.WARNING);
7976
}
8077
} else {
8178
consequence = parseElementaryCondition(con, meta);
8279
}
8380

8481
if (premise != null && consequence != null) {
8582

86-
rule = meta.getLabelUnsafe().isNominal()
87-
? new ClassificationRule(premise, consequence)
88-
: (isSurvival
89-
? new SurvivalRule(premise, consequence)
90-
: new RegressionRule(premise, consequence));
83+
if (isSurvival) {
84+
rule = new SurvivalRule(premise, consequence);
85+
} else {
86+
rule = meta.getLabelUnsafe().isNominal()
87+
? new ClassificationRule(premise, consequence)
88+
: new RegressionRule(premise, consequence);
89+
}
9190
}
9291
}
9392

@@ -158,9 +157,11 @@ public static ElementaryCondition parseElementaryCondition(String s, IAttributes
158157
}
159158

160159
IValueSet valueSet = null;
160+
IAttribute attributeMeta = meta.get(attribute);
161+
162+
boolean isSurvival = (meta.getColumnByRole(SurvivalRule.SURVIVAL_TIME_ROLE) != null) && (meta.getLabel() == attributeMeta);
161163

162-
IAttribute attributeMeta = meta.get(attribute);
163-
if (attributeMeta == null) {
164+
if (attributeMeta == null) {
164165
Logger.log("Attribute <" + attribute + "> not found"+ "\n", Level.WARNING);
165166
return null;
166167
}
@@ -176,13 +177,19 @@ public static ElementaryCondition parseElementaryCondition(String s, IAttributes
176177
matcher = regex.matcher(valueString);
177178
if (matcher.find()) {
178179
String value = matcher.group("discrete");
179-
List<String> mapping = new ArrayList<String>(attributeMeta.getMapping().getValues());
180-
double v = mapping.indexOf(value);
181-
if (v == -1) {
182-
Logger.log("Invalid value <" + value + "> of the nominal attribute <" + attribute + ">"+ "\n", Level.WARNING);
183-
return null;
184-
}
185-
valueSet = new SingletonSet(v, mapping);
180+
181+
if (value.equals("NaN") && isSurvival) {
182+
valueSet = new UndefinedSet();
183+
} else {
184+
185+
List<String> mapping = new ArrayList<String>(attributeMeta.getMapping().getValues());
186+
double v = mapping.indexOf(value);
187+
if (v == -1) {
188+
Logger.log("Invalid value <" + value + "> of the nominal attribute <" + attribute + ">" + "\n", Level.WARNING);
189+
return null;
190+
}
191+
valueSet = new SingletonSet(v, mapping);
192+
}
186193

187194
}
188195
} else if (attributeMeta.isNumerical()) {
@@ -191,8 +198,11 @@ public static ElementaryCondition parseElementaryCondition(String s, IAttributes
191198
//
192199
if (matcher.find()) {
193200
String value = matcher.group("discrete");
194-
double v = value.equals("NaN") ? Double.NaN : Double.parseDouble(value);
195-
valueSet = new SingletonSet(v, null);
201+
if (value.equals("NaN")) {
202+
valueSet = new UndefinedSet();
203+
} else {
204+
valueSet = new SingletonSet(Double.parseDouble(value), null);
205+
}
196206
} else {
197207
boolean leftClosed = Pattern.compile("\\<.+").matcher(valueString).find();
198208
boolean rightClosed = Pattern.compile(".+\\>").matcher(valueString).find();

adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/rule/Rule.java

+2-8
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import adaa.analytics.rules.logic.representation.condition.ElementaryCondition;
2525
import adaa.analytics.rules.logic.representation.IntegerBitSet;
2626
import adaa.analytics.rules.logic.representation.valueset.SingletonSet;
27+
import adaa.analytics.rules.logic.representation.valueset.UndefinedSet;
2728
import org.jetbrains.annotations.NotNull;
2829

2930
import java.io.Serializable;
@@ -285,14 +286,7 @@ public void covers(IExampleSet set, ContingencyTable ct, Set<Integer> positives,
285286
* @return Text representation.
286287
*/
287288
public String toString() {
288-
String consequenceString;
289-
if (consequence.getValueSet() instanceof SingletonSet &&
290-
Double.isNaN(((SingletonSet) consequence.getValueSet()).getValue()) && ((SingletonSet) consequence.getValueSet()).getMapping() == null) {
291-
consequenceString = "";
292-
} else {
293-
consequenceString = consequence.toString();
294-
}
295-
String s = "IF " + premise.toString() + " THEN " + consequenceString;
289+
String s = "IF " + premise.toString() + " THEN " + consequence.toString();
296290
return s;
297291
}
298292

adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/rule/SurvivalRule.java

+7
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import adaa.analytics.rules.logic.representation.condition.CompoundCondition;
2323
import adaa.analytics.rules.logic.representation.condition.ElementaryCondition;
2424
import adaa.analytics.rules.logic.representation.KaplanMeierEstimator;
25+
import adaa.analytics.rules.logic.representation.valueset.UndefinedSet;
2526
import org.jetbrains.annotations.NotNull;
2627
import tech.tablesaw.api.DoubleColumn;
2728

@@ -157,5 +158,11 @@ public Covering covers(IExampleSet set) {
157158
}
158159
return covered;
159160
}
161+
162+
@Override
163+
public String toString() {
164+
String s = "IF " + premise.toString() + " THEN ";
165+
return s;
166+
}
160167

161168
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package adaa.analytics.rules.logic.representation.valueset;
2+
3+
import adaa.analytics.rules.utils.DoubleFormatter;
4+
5+
import java.util.List;
6+
7+
public class UndefinedSet extends SingletonSet{
8+
9+
/** Gets {@link #value} */
10+
public double getValue() { throw new RuntimeException("Illegal call for UndefinedSet: getValue"); }
11+
/** Sets {@link #value} */
12+
public void setValue(double v) { throw new RuntimeException("Illegal call for UndefinedSet: setValue"); }
13+
14+
/** Gets {@link #value} as string */
15+
public String getValueAsString() { throw new RuntimeException("Illegal call for UndefinedSet: getValueAsString"); }
16+
17+
/** Gets {@link #mapping} */
18+
public List<String> getMapping() { throw new RuntimeException("Illegal call for UndefinedSet: getMapping" ); }
19+
/** Sets {@link #mapping} */
20+
public void setMapping(List<String> v) { throw new RuntimeException("Illegal call for UndefinedSet: setMapping"); }
21+
22+
public UndefinedSet() {
23+
super(Double.NaN, null);
24+
}
25+
26+
@Override
27+
public String toString() {
28+
return "";
29+
}
30+
}

adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/rulegenerator/ExpertRule.java

+6-1
Original file line numberDiff line numberDiff line change
@@ -202,10 +202,15 @@ public RuleSetBase learnWithExpert(IExampleSet exampleSet) {
202202
*/
203203
private void fixMappings(Iterable<Rule> rules, IExampleSet set) {
204204

205+
boolean isSurvival = (set.getAttributes().getColumnByRole(SurvivalRule.SURVIVAL_TIME_ROLE) != null);
206+
205207
for (Rule r : rules) {
206208
List<ConditionBase> toCheck = new ArrayList<ConditionBase>(); // list of elementary conditions to check
207209
toCheck.addAll(r.getPremise().getSubconditions());
208-
toCheck.add(r.getConsequence());
210+
211+
if (!isSurvival) {
212+
toCheck.add(r.getConsequence());
213+
}
209214

210215
for (ConditionBase c : toCheck) {
211216
ElementaryCondition ec = (c instanceof ElementaryCondition) ? (ElementaryCondition) c : null;

0 commit comments

Comments
 (0)