Skip to content

Commit 4e7e541

Browse files
committed
owcalibratedlearner: Never return base_learner instance from create_learner
The instance's name is modified in place after. Replace IdentityWrapper (which did not work) with copy.deepcopy of the instance.
1 parent 9497b39 commit 4e7e541

2 files changed

Lines changed: 6 additions & 8 deletions

File tree

Orange/widgets/model/owcalibratedlearner.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import copy
2+
13
from Orange.classification import CalibratedLearner, ThresholdLearner, \
24
NaiveBayesLearner
35
from Orange.data import Table
@@ -65,7 +67,6 @@ def set_learner(self, learner):
6567
self.learner = self.model = None
6668

6769
def _set_default_name(self):
68-
6970
if self.base_learner is None:
7071
self.set_default_learner_name("")
7172
else:
@@ -80,10 +81,6 @@ def calibration_options_changed(self):
8081
self.apply()
8182

8283
def create_learner(self):
83-
class IdentityWrapper(Learner):
84-
def fit_storage(self, data):
85-
return self.base_learner.fit_storage(data)
86-
8784
if self.base_learner is None:
8885
return None
8986
learner = self.base_learner
@@ -93,10 +90,11 @@ def fit_storage(self, data):
9390
if self.threshold != self.NoThresholdOptimization:
9491
learner = ThresholdLearner(learner,
9592
self.ThresholdMap[self.threshold])
93+
if learner is self.base_learner:
94+
learner = copy.deepcopy(learner)
9695
if self.preprocessors:
97-
if learner is self.base_learner:
98-
learner = IdentityWrapper()
9996
learner.preprocessors = (self.preprocessors, )
97+
assert learner is not self.base_learner
10098
return learner
10199

102100
def get_learner_parameters(self):

Orange/widgets/model/tests/test_owcalibratedlearner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def test_create_learner(self):
9595
widget.calibration = widget.NoCalibration
9696
widget.threshold = widget.NoThresholdOptimization
9797
learner = self.widget.create_learner()
98-
self.assertIs(learner, self.widget.base_learner)
98+
self.assertIsNot(learner, self.widget.base_learner)
9999

100100
widget.calibration = widget.SigmoidCalibration
101101
widget.threshold = widget.OptimizeF1

0 commit comments

Comments
 (0)