Skip to content

Commit dab0cd8

Browse files
authored
Merge pull request #503 from moorepants/cache-binary
Enables caching of the compiled code if tmp_dir is set.
2 parents 40246b6 + 4909f30 commit dab0cd8

File tree

3 files changed

+166
-48
lines changed

3 files changed

+166
-48
lines changed

opty/direct_collocation.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1149,7 +1149,12 @@ def __init__(self, equations_of_motion, state_symbols,
11491149
tmp_dir : string, optional
11501150
If you want to see the generated Cython and C code for the
11511151
constraint and constraint Jacobian evaluations, pass in a path to a
1152-
directory here.
1152+
directory here. Additionally, if this temporary directory is set to
1153+
an existing populated directory and the equations of motion have
1154+
not changed relative to prior instantiations of this class, the
1155+
compilation step will be skipped if equivalent compiled modules are
1156+
already present and cached. This may save significant computational
1157+
time when repeatedly using the same set of equations of motion.
11531158
integration_method : string, optional
11541159
The integration method to use, either ``backward euler`` or
11551160
``midpoint``.

opty/tests/test_direct_collocation.py

Lines changed: 71 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import os
2+
import shutil
3+
import tempfile
14
from collections import OrderedDict
25

36
import numpy as np
@@ -408,43 +411,81 @@ def test_pendulum():
408411
np.testing.assert_allclose(prob._upp_con_bounds, expected_upp_con_bounds)
409412

410413

411-
def test_Problem():
414+
def TestProblem():
412415

413-
m, c, k, t = sym.symbols('m, c, k, t')
414-
x, v, f = [s(t) for s in sym.symbols('x, v, f', cls=sym.Function)]
416+
def setup_method(self):
417+
self.tmp_dir = tempfile.mkdtemp("opty_cache_test")
418+
if os.path.exists(self.tmp_dir):
419+
shutil.rmtree(self.tmp_dir)
420+
os.mkdir(self.tmp_dir)
421+
422+
def teardown_method(self):
423+
if os.path.exists(self.tmp_dir):
424+
shutil.rmtree(self.tmp_dir)
425+
426+
def test_problem(self):
415427

416-
state_symbols = (x, v)
428+
m, c, k, t = sym.symbols('m, c, k, t')
429+
x, v, f = [s(t) for s in sym.symbols('x, v, f', cls=sym.Function)]
417430

418-
interval_value = 0.01
431+
state_symbols = (x, v)
419432

420-
eom = sym.Matrix([x.diff() - v,
421-
m * v.diff() + c * v + k * x - f])
433+
interval_value = 0.01
422434

423-
prob = Problem(lambda x: 1.0,
424-
lambda x: x,
425-
eom,
426-
state_symbols,
427-
2,
428-
interval_value,
429-
time_symbol=t,
430-
bounds={x: (-10.0, 10.0),
431-
f: (-8.0, 8.0),
432-
m: (-1.0, 1.0),
433-
c: (-0.5, 0.5)})
435+
eom = sym.Matrix([x.diff() - v,
436+
m * v.diff() + c * v + k * x - f])
434437

435-
INF = 10e19
436-
expected_lower = np.array([-10.0, -10.0,
437-
-INF, -INF,
438-
-8.0, -8.0,
439-
-0.5, -INF, -1.0])
440-
np.testing.assert_allclose(prob.lower_bound, expected_lower)
441-
expected_upper = np.array([10.0, 10.0,
442-
INF, INF,
443-
8.0, 8.0,
444-
0.5, INF, 1.0])
445-
np.testing.assert_allclose(prob.upper_bound, expected_upper)
438+
prob = Problem(
439+
lambda x: 1.0,
440+
lambda x: x,
441+
eom,
442+
state_symbols,
443+
2,
444+
interval_value,
445+
time_symbol=t,
446+
bounds={
447+
x: (-10.0, 10.0),
448+
f: (-8.0, 8.0),
449+
m: (-1.0, 1.0),
450+
c: (-0.5, 0.5),
451+
},
452+
tmp_dir=self.tmp_dir)
453+
454+
# Only two modules should be generated
455+
c_file_list = [f for f in os.listdir(self.tmp_dir) if
456+
f.endswith('_c.c')]
457+
assert len(c_file_list) == 2
458+
459+
INF = 10e19
460+
expected_lower = np.array([-10.0, -10.0,
461+
-INF, -INF,
462+
-8.0, -8.0,
463+
-0.5, -INF, -1.0])
464+
np.testing.assert_allclose(prob.lower_bound, expected_lower)
465+
expected_upper = np.array([10.0, 10.0,
466+
INF, INF,
467+
8.0, 8.0,
468+
0.5, INF, 1.0])
469+
np.testing.assert_allclose(prob.upper_bound, expected_upper)
470+
471+
assert prob.collocator.num_instance_constraints == 0
472+
473+
# run Problem again to see if the cache worked.
474+
prob = Problem(
475+
lambda x: 1.0,
476+
lambda x: x,
477+
eom,
478+
state_symbols,
479+
4,
480+
interval_value,
481+
time_symbol=t,
482+
tmp_dir=self.tmp_dir,
483+
)
446484

447-
assert prob.collocator.num_instance_constraints == 0
485+
# no more C files should have been generated
486+
c_file_list = [f for f in os.listdir(self.tmp_dir) if
487+
f.endswith('_c.c')]
488+
assert len(c_file_list) == 2
448489

449490

450491
class TestConstraintCollocator():

opty/utils.py

Lines changed: 89 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from timeit import default_timer as timer
1616
import logging
1717
import locale
18+
import hashlib
1819

1920
import numpy as np
2021
import sympy as sm
@@ -79,6 +80,11 @@ def ccode(expr, assign_to=None, **settings):
7980

8081
def _forward_jacobian(expr, wrt):
8182

83+
# NOTE : free_symbols are sets and are not guaranteed to be in the same
84+
# order, so sympy.ordered() is used throughout to ensure a deterministic
85+
# behavior. This is important for the binary caching to work as it hashes
86+
# the generated code string.
87+
8288
def add_to_cache(node):
8389
if node in expr_to_replacement_cache:
8490
replacement_symbol = expr_to_replacement_cache[node]
@@ -120,6 +126,8 @@ def add_to_cache(node):
120126
replacement_symbols = numbered_symbols(
121127
prefix='z',
122128
cls=sm.Symbol,
129+
# TODO : free symbols should be able to be passed in to save time in
130+
# recomputing
123131
exclude=expr.free_symbols,
124132
real=True,
125133
)
@@ -158,7 +166,7 @@ def add_to_cache(node):
158166
start = timer()
159167
zeros = sm.ImmutableDenseMatrix.zeros(1, len(wrt))
160168
for symbol, subexpr in replacements:
161-
free_symbols = subexpr.free_symbols
169+
free_symbols = sm.ordered(subexpr.free_symbols)
162170
absolute_derivative = zeros
163171
for free_symbol in free_symbols:
164172
replacement_symbol, partial_derivative = add_to_cache(
@@ -182,8 +190,8 @@ def add_to_cache(node):
182190
entry = stack.pop()
183191
if entry in required_replacement_symbols or entry in wrt:
184192
continue
185-
children = list(
186-
replacement_to_reduced_expr_cache.get(entry, entry).free_symbols)
193+
children = list(sm.ordered(
194+
replacement_to_reduced_expr_cache.get(entry, entry).free_symbols))
187195
for child in children:
188196
if child not in required_replacement_symbols and child not in wrt:
189197
stack.append(child)
@@ -198,9 +206,9 @@ def add_to_cache(node):
198206
if replacement_symbol in required_replacement_symbols
199207
}
200208

201-
counter = Counter(replaced_jacobian.free_symbols)
209+
counter = Counter(sm.ordered(replaced_jacobian.free_symbols))
202210
for replaced_subexpr in required_replacements_dense.values():
203-
counter.update(replaced_subexpr.free_symbols)
211+
counter.update(sm.ordered(replaced_subexpr.free_symbols))
204212

205213
logger.info('Substituting required replacements...')
206214
required_replacements = {}
@@ -459,6 +467,7 @@ def sort_sympy(seq):
459467

460468

461469
_c_template = """\
470+
// opty_code_hash={eval_code_hash}
462471
{win_math_def}
463472
#include <math.h>
464473
#include "{file_prefix}_h.h"
@@ -620,34 +629,38 @@ def ufuncify_matrix(args, expr, const=None, tmp_dir=None, parallel=False,
620629
args : iterable of sympy.Symbol
621630
A list of all symbols in expr in the desired order for the output
622631
function.
623-
expr : sympy.Matrix
624-
A matrix of expressions.
632+
expr : sympy.Matrix or 2-tuple
633+
A matrix of expressions or the output of ``cse()`` of a matrix of
634+
expressions.
625635
const : tuple, optional
626-
This should include any of the symbols in args that should be
627-
constant with respect to the loop.
636+
This should include any of the symbols in args that should be constant
637+
with respect to the evaluation loop.
628638
tmp_dir : string, optional
629-
The path to a directory in which to store the generated files. If
630-
None then the files will be not be retained after the function is
631-
compiled.
639+
The path to a directory in which to store the generated files. If None
640+
then the files will not be retained after the function is compiled.
641+
If this temporary directory is set to an existing populated directory
642+
and ``expr`` has not changed relative to prior executions of this
643+
function, the compilation step will be skipped if equivalent compiled
644+
modules are already present and cached.
632645
parallel : boolean, optional
633646
If True and openmp is installed, the generated code will be
634-
parallelized across threads. This is only useful when expr are
647+
parallelized across threads. This is only useful when ``expr`` is
635648
extremely large.
636649
show_compile_output : boolean, optional
637650
If True, STDOUT and STDERR of the Cython compilation call will be
638651
shown.
639652
640653
"""
641654

642-
# TODO : This is my first ever global variable in Python. It'd probably
643-
# be better if this was a class attribute of a Ufuncifier class. And I'm
644-
# not sure if this current version counts sequentially.
655+
# TODO : This is my first ever global variable in Python. It'd probably be
656+
# better if this was a class attribute of a Ufuncifier class. And I'm not
657+
# sure if this current version counts sequentially.
645658
global module_counter
646659

647660
if hasattr(expr, 'shape'):
648661
num_rows = expr.shape[0]
649662
num_cols = expr.shape[1]
650-
else:
663+
else: # output of cse()
651664
num_rows = expr[1][0].shape[0]
652665
num_cols = expr[1][0].shape[1]
653666

@@ -675,6 +688,8 @@ def ufuncify_matrix(args, expr, const=None, tmp_dir=None, parallel=False,
675688
file_prefix = '{}_{}'.format(file_prefix_base, module_counter)
676689
module_counter += 1
677690

691+
prior_module_number = module_counter - 1
692+
678693
d = {'routine_name': 'eval_matrix',
679694
'file_prefix': file_prefix,
680695
'matrix_output_size': matrix_size,
@@ -723,6 +738,19 @@ def ufuncify_matrix(args, expr, const=None, tmp_dir=None, parallel=False,
723738
d['eval_code'] = ' ' + '\n '.join((sub_expr_code + '\n' +
724739
matrix_code).split('\n'))
725740

741+
# NOTE : It is very unlikely that the contents of evaluation code can be
742+
# identical for two different sets of differential equations, so we hash it
743+
# and store the hash value in the C file that contains the evaluation code.
744+
# TODO : Maybe we should only do this if tmp_dir is not None, as it could
745+
# have an undesired computational cost.
746+
logger.debug('Calculating cache hash.')
747+
hasher = hashlib.new('sha256')
748+
const_str = 'const=None' if const is None else 'const={}'.format(const)
749+
parallel_str = 'parallel={}'.format(parallel)
750+
hasher.update((const_str + parallel_str + d['eval_code']).encode())
751+
d['eval_code_hash'] = str(hasher.hexdigest())
752+
logger.debug('Done calculating cache hash: {}'.format(d['eval_code_hash']))
753+
726754
c_indent = len('void {routine_name}('.format(**d))
727755
c_arg_spacer = ',\n' + ' ' * c_indent
728756

@@ -734,6 +762,7 @@ def ufuncify_matrix(args, expr, const=None, tmp_dir=None, parallel=False,
734762
memory_views = []
735763
for a in args:
736764
if const is not None and a in const:
765+
# TODO : Should these be declared const in C?
737766
typ = 'double'
738767
idexy = '{}'
739768
cython_input_args.append('{} {}'.format(typ, ccode(a)))
@@ -772,6 +801,42 @@ def ufuncify_matrix(args, expr, const=None, tmp_dir=None, parallel=False,
772801

773802
workingdir = os.getcwd()
774803
os.chdir(codedir)
804+
logger.info('Changed directory to {}'.format(codedir))
805+
806+
# NOTE : If there are other files present in the directory (will only occur
807+
# if a tmp_dir is set) then search through them starting with the most
808+
# recent and see if any has a hash matching the evaluation code generated
809+
# here. If a match is found, store the module number.
810+
matching_module_num = None
811+
for prior_num in reversed(range(prior_module_number + 1)):
812+
old_file_prefix = '{}_{}'.format(file_prefix_base, prior_num)
813+
logger.info(f'Checking {old_file_prefix} for cached code.')
814+
try:
815+
with open(old_file_prefix + '_c.c', 'r') as f:
816+
hash_line = f.readline()
817+
logger.debug(hash_line.strip())
818+
if 'opty_code_hash={}'.format(d['eval_code_hash']) in hash_line:
819+
matching_module_num = prior_num
820+
logger.info(f'{old_file_prefix} matches!')
821+
break
822+
except FileNotFoundError:
823+
logger.debug(f'{old_file_prefix} not found.')
824+
pass
825+
826+
# NOTE : If we found a matching C file, then try to simply load that module
827+
# instead of compiling a new one. This lets us skip the compile step if we
828+
# have not changed anything about the model.
829+
if matching_module_num is not None:
830+
try:
831+
cython_module = importlib.import_module(old_file_prefix)
832+
except ImportError: # false positive, so compile a new one
833+
logger.info(f'Failed to import {old_file_prefix}.')
834+
pass
835+
else:
836+
logger.info(f'Skipped compile, {old_file_prefix} module loaded.')
837+
os.chdir(workingdir)
838+
logger.info(f'Changed directory to {workingdir}.')
839+
return getattr(cython_module, d['routine_name'] + '_loop')
775840

776841
try:
777842
sys.path.append(codedir)
@@ -797,6 +862,7 @@ def ufuncify_matrix(args, expr, const=None, tmp_dir=None, parallel=False,
797862
else:
798863
encoding = None
799864
try:
865+
logger.info('Compiling matrix evaluation.')
800866
proc = subprocess.run(cmd, capture_output=True, text=True,
801867
encoding=encoding)
802868
# On Windows this can raise a UnicodeDecodeError, but only in the
@@ -811,8 +877,13 @@ def ufuncify_matrix(args, expr, const=None, tmp_dir=None, parallel=False,
811877
if show_compile_output:
812878
print(stdout)
813879
print(stderr)
880+
else:
881+
logger.debug(stdout)
882+
logger.debug(stderr)
883+
814884
try:
815885
cython_module = importlib.import_module(d['file_prefix'])
886+
logger.info("Loaded {} module".format(d['file_prefix']))
816887
except ImportError as error:
817888
msg = ('Unable to import the compiled Cython module {}, '
818889
'compilation likely failed. STDERR output from '
@@ -827,6 +898,7 @@ def ufuncify_matrix(args, expr, const=None, tmp_dir=None, parallel=False,
827898
# so I don't delete the directory on Windows.
828899
if sys.platform != "win32":
829900
shutil.rmtree(codedir)
901+
logger.info('Removed directory {}'.format(codedir))
830902

831903
return getattr(cython_module, d['routine_name'] + '_loop')
832904

0 commit comments

Comments
 (0)