@@ -32,8 +32,6 @@ public ConditionCandidate(String attribute, IValueSet valueSet) {
32
32
}
33
33
}
34
34
35
- protected static final int MAX_BINS = 100 ;
36
-
37
35
// Example description:
38
36
// [0-31] - example id (32 bits)
39
37
// [32-47] - block id (16 bits)
@@ -73,27 +71,36 @@ public ApproximateClassificationFinder(InductionParameters params) {
73
71
}
74
72
75
73
@ Override
76
- public void preprocess (ExampleSet dataset ) {
74
+ public ExampleSet preprocess (ExampleSet dataset ) {
77
75
int n_examples = dataset .size ();
78
76
int n_attributes = dataset .getAttributes ().size ();
79
77
80
78
trainSet = dataset ;
81
79
descriptions = new long [n_attributes ][n_examples ];
82
80
mappings = new int [n_attributes ][n_examples ];
83
81
84
- bins_positives = new int [n_attributes ][MAX_BINS ];
85
- bins_negatives = new int [n_attributes ][MAX_BINS ];
86
- bins_newPositives = new int [n_attributes ][MAX_BINS ];
87
- bins_begins = new int [n_attributes ][MAX_BINS ];
82
+ bins_positives = new int [n_attributes ][];
83
+ bins_negatives = new int [n_attributes ][];
84
+ bins_newPositives = new int [n_attributes ][];
85
+ bins_begins = new int [n_attributes ][];
88
86
89
87
ruleRanges = new int [n_attributes ][2 ];
90
88
91
89
for (Attribute attr : dataset .getAttributes ()) {
92
90
int ia = attr .getTableIndex ();
91
+ int n_vals = attr .isNominal () ? attr .getMapping ().size () : params .getApproximateBinsCount ();
92
+
93
+ bins_positives [ia ] = new int [n_vals ];
94
+ bins_negatives [ia ] = new int [n_vals ];
95
+ bins_newPositives [ia ] = new int [n_vals ];
96
+ bins_begins [ia ] = new int [n_vals ];
93
97
94
98
determineBins (dataset , attr , descriptions [ia ], mappings [ia ], bins_begins [ia ], ruleRanges [ia ]);
99
+
95
100
arrayCopies .put ("ruleRanges" , (Object )Arrays .stream (ruleRanges ).map (int []::clone ).toArray (int [][]::new ));
96
101
}
102
+
103
+ return dataset ;
97
104
}
98
105
99
106
/**
@@ -293,13 +300,14 @@ protected ElementaryCondition induceCondition(
293
300
int covered_n = 0 ;
294
301
int covered_new_p = 0 ;
295
302
296
- // use first attribute to establish number of covered elements
303
+ // use first attribute to establish number of covered elements
297
304
for (int bid = ruleRanges [0 ][0 ]; bid < ruleRanges [0 ][1 ]; ++bid ) {
298
305
covered_p += bins_positives [0 ][bid ];
299
306
covered_n += bins_negatives [0 ][bid ];
300
307
covered_new_p += bins_newPositives [0 ][bid ];
301
308
}
302
309
310
+
303
311
// iterate over all allowed decision attributes
304
312
for (Attribute attr : dataset .getAttributes ()) {
305
313
@@ -462,7 +470,10 @@ class Stats {
462
470
463
471
if (current != null && current .getAttribute () != null ) {
464
472
Logger .log ("\t Attribute best: " + current + ", quality=" + current .quality , Level .FINEST );
465
- updateMidpoint (dataset , current );
473
+ Attribute attr = dataset .getAttributes ().get (current .getAttribute ());
474
+ if (attr .isNumerical ()) {
475
+ updateMidpoint (dataset , current );
476
+ }
466
477
Logger .log (", adjusted: " + current + "\n " , Level .FINEST );
467
478
}
468
479
@@ -482,13 +493,13 @@ class Stats {
482
493
return null ; // empty condition - discard
483
494
}
484
495
485
- updateMidpoint (dataset , best );
486
-
487
- Logger .log ("\t Final best: " + best + ", quality=" + best .quality + "\n " , Level .FINEST );
488
-
489
- if (bestAttr .isNominal ()) {
496
+ if (bestAttr .isNumerical ()) {
497
+ updateMidpoint (dataset , best );
498
+ } else {
490
499
allowedAttributes .remove (bestAttr );
491
500
}
501
+
502
+ Logger .log ("\t Final best: " + best + ", quality=" + best .quality + "\n " , Level .FINEST );
492
503
}
493
504
494
505
return best ;
@@ -508,7 +519,7 @@ protected void notifyConditionAdded(ConditionBase cnd) {
508
519
ruleRanges [aid ][0 ] = blockId + 1 ;
509
520
ruleRanges [aid ][1 ] = blockId ;
510
521
} else {
511
- excludeExamplesFromArrays (trainSet , attr , ruleRanges [aid ][0 ], candidate .blockId + 1 );
522
+ excludeExamplesFromArrays (trainSet , attr , ruleRanges [aid ][0 ], candidate .blockId );
512
523
excludeExamplesFromArrays (trainSet , attr , candidate .blockId + 1 , ruleRanges [aid ][1 ]);
513
524
ruleRanges [aid ][0 ] = blockId ;
514
525
ruleRanges [aid ][1 ] = blockId + 1 ;
@@ -546,6 +557,7 @@ protected void determineBins(ExampleSet dataset, Attribute attr,
546
557
vals [i ] = dataset .getExample (i ).getValue (attr );
547
558
}
548
559
560
+
549
561
/*
550
562
class ValuesComparator implements IntComparator {
551
563
double [] vals;
@@ -597,12 +609,12 @@ public int compare(Bin p, Bin q) {
597
609
}
598
610
}
599
611
600
- PriorityQueue <Bin > bins = new PriorityQueue <Bin >(100 , new SizeBinComparator ());
601
- PriorityQueue <Bin > finalBins = new PriorityQueue <Bin >(100 , new IndexBinComparator ());
612
+ PriorityQueue <Bin > bins = new PriorityQueue <Bin >(binsBegins . length , new SizeBinComparator ());
613
+ PriorityQueue <Bin > finalBins = new PriorityQueue <Bin >(binsBegins . length , new IndexBinComparator ());
602
614
603
615
bins .add (new Bin (0 , mappings .length ));
604
616
605
- while (bins .size () > 0 && (bins .size () + finalBins .size ()) < MAX_BINS ) {
617
+ while (bins .size () > 0 && (bins .size () + finalBins .size ()) < binsBegins . length ) {
606
618
Bin b = bins .poll ();
607
619
608
620
int id = (b .end + b .begin ) / 2 ;
@@ -611,9 +623,13 @@ public int compare(Bin p, Bin q) {
611
623
// decide direction
612
624
if (vals [b .begin ] == midval ) {
613
625
// go up
614
- while (vals [id ] == midval ) { ++id ; }
626
+ while (vals [id ] == midval ) {
627
+ ++id ;
628
+ }
615
629
} else {
616
- while (vals [id - 1 ] == midval ) { --id ; }
630
+ while (vals [id - 1 ] == midval ) {
631
+ --id ;
632
+ }
617
633
}
618
634
619
635
Bin leftBin = new Bin (b .begin , id );
@@ -646,17 +662,16 @@ public int compare(Bin p, Bin q) {
646
662
descriptions [i ] |= bid << OFFSET_BIN ;
647
663
}
648
664
649
- binsBegins [(int )bid ] = b .begin ;
665
+ binsBegins [(int ) bid ] = b .begin ;
650
666
++bid ;
651
667
}
652
668
653
669
ruleRanges [0 ] = 0 ;
654
- ruleRanges [1 ] = (int )bid ;
655
-
656
- // print bins
657
- for (int i = 0 ; i < bid ; ++i ) {
670
+ ruleRanges [1 ] = (int ) bid ;
671
+ // print bins
672
+ for (int i = 0 ; i < ruleRanges [1 ]; ++i ) {
658
673
int lo = binsBegins [i ];
659
- int hi = (i == bid - 1 ) ? trainSet .size () : binsBegins [i +1 ] - 1 ;
674
+ int hi = (i == ruleRanges [ 1 ] - 1 ) ? trainSet .size () : binsBegins [i +1 ] - 1 ;
660
675
Logger .log ("[" + lo + ", " + hi + "]:" + vals [lo ] + "\n " , Level .FINER );
661
676
}
662
677
}
@@ -665,6 +680,10 @@ protected void excludeExamplesFromArrays(ExampleSet dataset, Attribute attr, int
665
680
666
681
Logger .log ("Excluding examples: " + attr .getName () + " from [" + binLo + "," + binHi + "]\n " , Level .FINER );
667
682
683
+ if (binLo == binHi ) {
684
+ return ;
685
+ }
686
+
668
687
int n_examples = dataset .size ();
669
688
int src_row = attr .getTableIndex ();
670
689
long [] src_descriptions = descriptions [src_row ];
@@ -695,9 +714,11 @@ protected void excludeExamplesFromArrays(ExampleSet dataset, Attribute attr, int
695
714
int dst_row = other .getTableIndex ();
696
715
697
716
// if nominal attribute was already used
717
+ /*
698
718
if (other.isNominal() && Math.abs(ruleRanges[dst_row][1] - ruleRanges[dst_row][0]) == 1) {
699
719
continue;
700
720
}
721
+ */
701
722
702
723
Future <Object > future = pool .submit (() -> {
703
724
@@ -717,8 +738,14 @@ protected void excludeExamplesFromArrays(ExampleSet dataset, Attribute attr, int
717
738
718
739
int bid = (int ) ((desc & MASK_BIN ) >> OFFSET_BIN );
719
740
741
+ boolean opposite = dst_ranges [0 ] > dst_ranges [1 ]; // this indicates a nominal opposite condition
742
+ int dst_bin_lo = Math .min (dst_ranges [0 ], dst_ranges [1 ]);
743
+ int dst_bin_hi = Math .max (dst_ranges [0 ], dst_ranges [1 ]);
744
+
720
745
// update stats only in bins covered by the rule
721
- if (bid >= dst_ranges [0 ] && bid < dst_ranges [1 ] && ((desc & FLAG_COVERED ) != 0 )) {
746
+ boolean in_range = (bid >= dst_bin_lo && bid < dst_bin_hi ) || (opposite && (bid < dst_bin_lo || bid >= dst_bin_hi ));
747
+
748
+ if (in_range && ((desc & FLAG_COVERED ) != 0 )) {
722
749
723
750
if ((desc & FLAG_POSITIVE ) != 0 ) {
724
751
--dst_positives [bid ];
@@ -755,12 +782,16 @@ protected void resetArrays(ExampleSet dataset, int targetLabel) {
755
782
756
783
int n_examples = dataset .size ();
757
784
785
+ int [][] copy_ranges = (int [][])arrayCopies .get ("ruleRanges" );
786
+
758
787
for (Attribute attr : dataset .getAttributes ()) {
759
788
int attribute_id = attr .getTableIndex ();
760
789
761
790
Arrays .fill (bins_positives [attribute_id ], 0 );
762
791
Arrays .fill (bins_negatives [attribute_id ], 0 );
763
792
Arrays .fill (bins_newPositives [attribute_id ], 0 );
793
+ ruleRanges [attribute_id ][0 ] = 0 ;
794
+ ruleRanges [attribute_id ][1 ] = copy_ranges [attribute_id ][1 ];
764
795
765
796
long [] descriptions_row = descriptions [attribute_id ];
766
797
int [] mappings_row = mappings [attribute_id ];
@@ -792,6 +823,9 @@ protected void resetArrays(ExampleSet dataset, int targetLabel) {
792
823
}
793
824
}
794
825
826
+ // reset rule ranges
827
+
828
+
795
829
Logger .log ("Reset arrays for class " + targetLabel + "\n " , Level .FINER );
796
830
printArrays ();
797
831
@@ -816,9 +850,13 @@ protected void printArrays() {
816
850
817
851
int bin_p = 0 , bin_n = 0 , bin_new_p = 0 , bin_outside = 0 ;
818
852
819
- for (int i = 0 ; i < MAX_BINS ; ++i ) {
853
+ boolean opposite = ruleRanges [attribute_id ][0 ] > ruleRanges [attribute_id ][1 ]; // this indicates a nominal opposite condition
854
+ int lo = Math .min (ruleRanges [attribute_id ][0 ], ruleRanges [attribute_id ][1 ]);
855
+ int hi = Math .max (ruleRanges [attribute_id ][0 ], ruleRanges [attribute_id ][1 ]);
856
+
857
+ for (int i = 0 ; i < bins_positives [attribute_id ].length ; ++i ) {
820
858
821
- if (i >= ruleRanges [ attribute_id ][ 0 ] && i < ruleRanges [ attribute_id ][ 1 ] ) {
859
+ if (( i >= lo && i < hi ) || ( opposite && ( i < lo || i >= hi )) ) {
822
860
bin_p += bins_positives [attribute_id ][i ];
823
861
bin_n += bins_negatives [attribute_id ][i ];
824
862
bin_new_p += bins_newPositives [attribute_id ][i ];
0 commit comments