Skip to content

Commit 8ae2017

Browse files
authored
Merge pull request #2873 from ales-erjavec/feature-constructor-serialize
[FIX] Feature Constructor: Make FeatureFunc picklable
2 parents a00e5f0 + 95b62d2 commit 8ae2017

10 files changed

Lines changed: 108 additions & 68 deletions

File tree

Orange/__init__.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,3 @@
1-
import pickle
2-
from unittest.mock import patch
3-
# Needed because the pure-Python Unpickler that dill uses can also fail
4-
# with struct.error Exception. This seems to work, side effects unknown.
5-
with patch('pickle._Unpickler', pickle.Unpickler):
6-
import dill
7-
dill.settings['protocol'] = pickle.HIGHEST_PROTOCOL
8-
dill.settings['recurse'] = True
9-
dill.settings['byref'] = True
10-
111
from .misc.lazy_module import _LazyModule
122
from .misc.datasets import _DatasetInfo
133
from .version import \

Orange/widgets/data/owfeatureconstructor.py

Lines changed: 72 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
import random
1515
import logging
1616
import ast
17+
import types
1718

1819
from traceback import format_exception_only
1920
from collections import namedtuple, OrderedDict
2021
from itertools import chain, count
22+
from typing import List, Dict, Any # pylint: disable=unused-import
2123

2224
import numpy as np
2325

@@ -632,9 +634,6 @@ def send_report(self):
632634
report.plural("Constructed feature{s}", len(items)), items)
633635

634636

635-
636-
637-
638637
def freevars(exp, env):
639638
"""
640639
Return names of all free variables in a parsed (expression) AST.
@@ -838,46 +837,76 @@ def bind_variable(descriptor, env):
838837
(descriptor, (instance -> value) | (table -> value list))
839838
"""
840839
if not descriptor.expression.strip():
841-
return (descriptor, lambda _: float("nan"))
840+
return descriptor, FeatureFunc("nan", [], {"nan": float("nan")})
842841

843842
exp_ast = ast.parse(descriptor.expression, mode="eval")
844843
freev = unique(freevars(exp_ast, []))
845844
variables = {sanitized_name(v.name): v for v in env}
846845
source_vars = [(name, variables[name]) for name in freev
847846
if name in variables]
848847

849-
values = []
848+
values = {}
850849
if isinstance(descriptor, DiscreteDescriptor):
851850
values = [sanitized_name(v) for v in descriptor.values]
852-
return descriptor, FeatureFunc(exp_ast, source_vars, values)
851+
values = {name: i for i, name in enumerate(values)}
852+
return descriptor, FeatureFunc(descriptor.expression, source_vars, values)
853853

854854

855-
def make_lambda(expression, args, values):
856-
def make_arg(name):
857-
if sys.version_info >= (3, 0):
858-
return ast.arg(arg=name, annotation=None)
859-
else:
860-
return ast.Name(id=name, ctx=ast.Param(), lineno=1, col_offset=0)
855+
def make_lambda(expression, args, env={}):
856+
# type: (ast.Expression, List[str], Dict[str, Any]) -> types.FunctionType
857+
"""
858+
Create a lambda function from an expression AST.
861859
860+
Parameters
861+
----------
862+
expression : ast.Expression
863+
The body of the lambda.
864+
args : List[str]
865+
A list of positional argument names
866+
env : Dict[str, Any]
867+
Extra environment to capture in the lambda's closure.
868+
869+
Returns
870+
-------
871+
func : types.FunctionType
872+
"""
873+
# lambda *{args}* : EXPRESSION
862874
lambda_ = ast.Lambda(
863875
args=ast.arguments(
864-
args=[make_arg(arg) for arg in args + values],
876+
args=[ast.arg(arg=arg, annotation=None) for arg in args],
865877
varargs=None,
866878
varargannotation=None,
867879
kwonlyargs=[],
868880
kwarg=None,
869881
kwargannotation=None,
870-
defaults=[ast.Num(i) for i in range(len(values))],
882+
defaults=[],
871883
kw_defaults=[]),
872884
body=expression.body,
873885
)
874886
lambda_ = ast.copy_location(lambda_, expression.body)
875-
exp = ast.Expression(body=lambda_, lineno=1, col_offset=0)
876-
ast.dump(exp)
887+
# lambda **{env}** : lambda *{args}*: EXPRESSION
888+
outer = ast.Lambda(
889+
args=ast.arguments(
890+
args=[ast.arg(arg=name, annotation=None) for name in env],
891+
varargs=None,
892+
varargannotation=None,
893+
kwonlyargs=[],
894+
kwarg=None,
895+
kwargannotation=None,
896+
defaults=[],
897+
kw_defaults=[],
898+
),
899+
body=lambda_,
900+
)
901+
exp = ast.Expression(body=outer, lineno=1, col_offset=0)
877902
ast.fix_missing_locations(exp)
878903
GLOBALS = __GLOBALS.copy()
879904
GLOBALS["__builtins__"] = {}
880-
return eval(compile(exp, "<lambda>", "eval"), GLOBALS)
905+
fouter = eval(compile(exp, "<lambda>", "eval"), GLOBALS)
906+
assert isinstance(fouter, types.FunctionType)
907+
finner = fouter(**env)
908+
assert isinstance(finner, types.FunctionType)
909+
return finner
881910

882911

883912
__ALLOWED = [
@@ -934,11 +963,26 @@ def make_arg(name):
934963

935964

936965
class FeatureFunc:
937-
def __init__(self, expression, args, values):
966+
"""
967+
Parameters
968+
----------
969+
expression : str
970+
An expression string
971+
args : List[Tuple[str, Orange.data.Variable]]
972+
A list of (`name`, `variable`) tuples where `name` is the name of
973+
a variable as used in `expression`, and `variable` is the variable
974+
instance used to extract the corresponding column/value from a
975+
Table/Instance.
976+
extra_env : Dict[str, Any]
977+
Extra environment specifying constant values to be made available
978+
in expression. It must not shadow names in `args`.
979+
"""
980+
def __init__(self, expression, args, extra_env={}):
938981
self.expression = expression
939982
self.args = args
940-
self.values = values
941-
self.func = make_lambda(expression, [name for name, _ in args], values)
983+
self.extra_env = dict(extra_env)
984+
self.func = make_lambda(ast.parse(expression, mode="eval"),
985+
[name for name, _ in args], self.extra_env)
942986

943987
def __call__(self, instance, *_):
944988
if isinstance(instance, Orange.data.Table):
@@ -947,6 +991,12 @@ def __call__(self, instance, *_):
947991
args = [instance[var] for _, var in self.args]
948992
return self.func(*args)
949993

994+
def __reduce__(self):
995+
return type(self), (self.expression, self.args, self.extra_env)
996+
997+
def __repr__(self):
998+
return "{0.__name__}{1!r}".format(*self.__reduce__())
999+
9501000

9511001
def unique(seq):
9521002
seen = set()
@@ -958,7 +1008,7 @@ def unique(seq):
9581008
return unique_el
9591009

9601010

961-
def main(argv=None):
1011+
def main(argv=None): # pragma: no cover
9621012
from AnyQt.QtWidgets import QApplication
9631013
if argv is None:
9641014
argv = sys.argv
@@ -981,5 +1031,6 @@ def main(argv=None):
9811031
w.saveSettings()
9821032
return 0
9831033

1034+
9841035
if __name__ == "__main__":
9851036
sys.exit(main())

Orange/widgets/data/tests/test_owfeatureconstructor.py

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import ast
33
import sys
44
import math
5+
import pickle
6+
import copy
57

68
import numpy as np
79

@@ -14,9 +16,9 @@
1416
construct_variables, OWFeatureConstructor,
1517
FeatureEditor, DiscreteFeatureEditor)
1618

17-
from Orange.widgets.data.owfeatureconstructor import freevars, validate_exp
18-
19-
import dill as pickle # Import dill after Orange because patched
19+
from Orange.widgets.data.owfeatureconstructor import (
20+
freevars, validate_exp, FeatureFunc
21+
)
2022

2123

2224
class FeatureConstructorTest(unittest.TestCase):
@@ -89,27 +91,6 @@ def test_construct_numeric_names(self):
8991
ContinuousVariable._clear_all_caches()
9092

9193

92-
GLOBAL_CONST = 2
93-
94-
95-
class PicklingTest(unittest.TestCase):
96-
CLASS_CONST = 3
97-
98-
def test_lambdas_pickle(self):
99-
NONLOCAL_CONST = 5
100-
101-
lambda_func = lambda x, local_const=7: \
102-
x * local_const * NONLOCAL_CONST * self.CLASS_CONST * GLOBAL_CONST
103-
104-
def nested_func(x, local_const=7):
105-
return x * local_const * NONLOCAL_CONST * self.CLASS_CONST * GLOBAL_CONST
106-
107-
self.assertEqual(lambda_func(11),
108-
pickle.loads(pickle.dumps(lambda_func))(11))
109-
self.assertEqual(nested_func(11),
110-
pickle.loads(pickle.dumps(nested_func))(11))
111-
112-
11394
class TestTools(unittest.TestCase):
11495
def test_free_vars(self):
11596
stmt = ast.parse("foo", "", "single")
@@ -218,6 +199,30 @@ def validate_(source):
218199
validate_("{a:1 for a in s}")
219200

220201

202+
class FeatureFuncTest(unittest.TestCase):
203+
def test_reconstruct(self):
204+
f = FeatureFunc("a * x + c", [("x", "x")], {"a": 2, "c": 10})
205+
self.assertEqual(f({"x": 2}), 14)
206+
f1 = pickle.loads(pickle.dumps(f))
207+
self.assertEqual(f1({"x": 2}), 14)
208+
fc = copy.copy(f)
209+
self.assertEqual(fc({"x": 3}), 16)
210+
211+
def test_repr(self):
212+
self.assertEqual(repr(FeatureFunc("a + 1", [("a", 2)])),
213+
"FeatureFunc('a + 1', [('a', 2)], {})")
214+
215+
def test_call(self):
216+
f = FeatureFunc("a + 1", [("a", "a")])
217+
self.assertEqual(f({"a": 2}), 3)
218+
219+
iris = Table("iris")
220+
f = FeatureFunc("sepal_width + 10",
221+
[("sepal_width", iris.domain["sepal width"])])
222+
r = f(iris)
223+
np.testing.assert_array_equal(r, iris.X[:, 1] + 10)
224+
225+
221226
class OWFeatureConstructorTests(WidgetTest):
222227
def setUp(self):
223228
self.widget = OWFeatureConstructor()

Orange/widgets/model/owloadmodel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
2+
import pickle
23

3-
import dill as pickle
44
from AnyQt.QtCore import QTimer
55
from AnyQt.QtWidgets import (
66
QSizePolicy, QHBoxLayout, QComboBox, QStyle, QFileDialog

Orange/widgets/model/owsavemodel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22

3-
import dill as pickle
3+
import pickle
4+
45
from AnyQt.QtWidgets import (
56
QComboBox, QStyle, QSizePolicy, QFileDialog, QApplication
67
)

Orange/widgets/model/tests/test_owloadmodel.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
import pickle
55
from tempfile import mkstemp
66

7-
import dill # Import dill after Orange because patched
8-
97
from Orange.classification.majority import ConstantModel
108
from Orange.widgets.model.owloadmodel import OWLoadModel
119
from Orange.widgets.tests.base import WidgetTest
@@ -23,11 +21,10 @@ def test_show_error(self):
2321
fd, fname = mkstemp(suffix='.pkcls')
2422
os.close(fd)
2523
try:
26-
for pickle_impl in (pickle, dill):
27-
with open(fname, 'wb') as f:
28-
pickle_impl.dump(clsf, f)
29-
self.widget.load(fname)
30-
self.assertFalse(self.widget.Error.load_error.is_shown())
24+
with open(fname, 'wb') as f:
25+
pickle.dump(clsf, f)
26+
self.widget.load(fname)
27+
self.assertFalse(self.widget.Error.load_error.is_shown())
3128

3229
with open(fname, "w") as f:
3330
f.write("X")

conda-recipe/meta.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ requirements:
4848
- anyqt >=0.0.6
4949
- joblib
5050
- python.app # [osx]
51-
- dill # pickle anything
5251
- commonmark
5352
- serverfiles
5453

requirements-core.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,5 @@ chardet>=3.0.2
1111
joblib>=0.9.4
1212
keyring
1313
keyrings.alt # for alternative keyring implementations
14-
dill
1514
setuptools>=36.3
1615
serverfiles # for Data Sets synchronization

scripts/macos/requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
AnyQt==0.0.8
88
Bottleneck==1.2.0
99
chardet==3.0.4
10-
dill==0.2.6
1110
docutils==0.13.1
1211
joblib==0.11
1312
keyring==10.3.1

scripts/windows/specs/PY34-win32.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ AnyQt==0.0.8
1818
PyQt5==5.5.1
1919
docutils==0.13.1
2020
pip==9.0.1
21-
dill==0.2.6
2221
pyqtgraph==0.10.0
2322
six==1.10.0
2423
xlrd==1.0.0

0 commit comments

Comments
 (0)