1515from timeit import default_timer as timer
1616import logging
1717import locale
18+ import hashlib
1819
1920import numpy as np
2021import sympy as sm
@@ -79,6 +80,11 @@ def ccode(expr, assign_to=None, **settings):
7980
8081def _forward_jacobian (expr , wrt ):
8182
83+ # NOTE : free_symbols returns a set, which is not guaranteed to iterate in
84+ # the same order, so sympy.ordered() is used throughout to ensure
85+ # deterministic behavior. This is important for the binary caching to work,
86+ # as it hashes the generated code string.
87+
8288 def add_to_cache (node ):
8389 if node in expr_to_replacement_cache :
8490 replacement_symbol = expr_to_replacement_cache [node ]
@@ -120,6 +126,8 @@ def add_to_cache(node):
120126 replacement_symbols = numbered_symbols (
121127 prefix = 'z' ,
122128 cls = sm .Symbol ,
129+ # TODO : the free symbols should be able to be passed in to avoid the
130+ # cost of recomputing them here
123131 exclude = expr .free_symbols ,
124132 real = True ,
125133 )
@@ -158,7 +166,7 @@ def add_to_cache(node):
158166 start = timer ()
159167 zeros = sm .ImmutableDenseMatrix .zeros (1 , len (wrt ))
160168 for symbol , subexpr in replacements :
161- free_symbols = subexpr .free_symbols
169+ free_symbols = sm . ordered ( subexpr .free_symbols )
162170 absolute_derivative = zeros
163171 for free_symbol in free_symbols :
164172 replacement_symbol , partial_derivative = add_to_cache (
@@ -182,8 +190,8 @@ def add_to_cache(node):
182190 entry = stack .pop ()
183191 if entry in required_replacement_symbols or entry in wrt :
184192 continue
185- children = list (
186- replacement_to_reduced_expr_cache .get (entry , entry ).free_symbols )
193+ children = list (sm . ordered (
194+ replacement_to_reduced_expr_cache .get (entry , entry ).free_symbols ))
187195 for child in children :
188196 if child not in required_replacement_symbols and child not in wrt :
189197 stack .append (child )
@@ -198,9 +206,9 @@ def add_to_cache(node):
198206 if replacement_symbol in required_replacement_symbols
199207 }
200208
201- counter = Counter (replaced_jacobian .free_symbols )
209+ counter = Counter (sm . ordered ( replaced_jacobian .free_symbols ) )
202210 for replaced_subexpr in required_replacements_dense .values ():
203- counter .update (replaced_subexpr .free_symbols )
211+ counter .update (sm . ordered ( replaced_subexpr .free_symbols ) )
204212
205213 logger .info ('Substituting required replacements...' )
206214 required_replacements = {}
@@ -459,6 +467,7 @@ def sort_sympy(seq):
459467
460468
461469_c_template = """\
470+ // opty_code_hash={eval_code_hash}
462471{win_math_def}
463472#include <math.h>
464473#include "{file_prefix}_h.h"
@@ -620,34 +629,38 @@ def ufuncify_matrix(args, expr, const=None, tmp_dir=None, parallel=False,
620629 args : iterable of sympy.Symbol
621630 A list of all symbols in expr in the desired order for the output
622631 function.
623- expr : sympy.Matrix
624- A matrix of expressions.
632+ expr : sympy.Matrix or 2-tuple
633+ A matrix of expressions or the output of ``cse()`` of a matrix of
634+ expressions.
625635 const : tuple, optional
626- This should include any of the symbols in args that should be
627- constant with respect to the loop.
636+ This should include any of the symbols in args that should be constant
637+ with respect to the evaluation loop.
628638 tmp_dir : string, optional
629- The path to a directory in which to store the generated files. If
630- None then the files will be not be retained after the function is
631- compiled.
639+ The path to a directory in which to store the generated files. If None
640+ then the files will not be retained after the function is compiled.
641+ If this temporary directory is set to an existing populated directory
642+ and ``expr`` has not changed relative to prior executions of this
643+ function, the compilation step will be skipped if equivalent compiled
644+ modules are already present and cached.
632645 parallel : boolean, optional
633646 If True and openmp is installed, the generated code will be
634- parallelized across threads. This is only useful when expr are
647+ parallelized across threads. This is only useful when ``expr`` is
635648 extremely large.
636649 show_compile_output : boolean, optional
637650 If True, STDOUT and STDERR of the Cython compilation call will be
638651 shown.
639652
640653 """
641654
642- # TODO : This is my first ever global variable in Python. It'd probably
643- # be better if this was a class attribute of a Ufuncifier class. And I'm
644- # not sure if this current version counts sequentially.
655+ # TODO : This is my first ever global variable in Python. It'd probably be
656+ # better if this was a class attribute of a Ufuncifier class. And I'm not
657+ # sure if this current version counts sequentially.
645658 global module_counter
646659
647660 if hasattr (expr , 'shape' ):
648661 num_rows = expr .shape [0 ]
649662 num_cols = expr .shape [1 ]
650- else :
663+ else : # output of cse()
651664 num_rows = expr [1 ][0 ].shape [0 ]
652665 num_cols = expr [1 ][0 ].shape [1 ]
653666
@@ -675,6 +688,8 @@ def ufuncify_matrix(args, expr, const=None, tmp_dir=None, parallel=False,
675688 file_prefix = '{}_{}' .format (file_prefix_base , module_counter )
676689 module_counter += 1
677690
691+ prior_module_number = module_counter - 1
692+
678693 d = {'routine_name' : 'eval_matrix' ,
679694 'file_prefix' : file_prefix ,
680695 'matrix_output_size' : matrix_size ,
@@ -723,6 +738,19 @@ def ufuncify_matrix(args, expr, const=None, tmp_dir=None, parallel=False,
723738 d ['eval_code' ] = ' ' + '\n ' .join ((sub_expr_code + '\n ' +
724739 matrix_code ).split ('\n ' ))
725740
741+ # NOTE : It is very unlikely that the contents of evaluation code can be
742+ # identical for two different sets of differential equations, so we hash it
743+ # and store the hash value in the C file that contains the evaluation code.
744+ # TODO : Maybe we should only do this if tmp_dir is not None, as it could
745+ # have an undesired computational cost.
746+ logger .debug ('Calculating cache hash.' )
747+ hasher = hashlib .new ('sha256' )
748+ const_str = 'const=None' if const is None else 'const={}' .format (const )
749+ parallel_str = 'parallel={}' .format (parallel )
750+ hasher .update ((const_str + parallel_str + d ['eval_code' ]).encode ())
751+ d ['eval_code_hash' ] = str (hasher .hexdigest ())
752+ logger .debug ('Done calculating cache hash: {}' .format (d ['eval_code_hash' ]))
753+
726754 c_indent = len ('void {routine_name}(' .format (** d ))
727755 c_arg_spacer = ',\n ' + ' ' * c_indent
728756
@@ -734,6 +762,7 @@ def ufuncify_matrix(args, expr, const=None, tmp_dir=None, parallel=False,
734762 memory_views = []
735763 for a in args :
736764 if const is not None and a in const :
765+ # TODO : Should these be declared const in C?
737766 typ = 'double'
738767 idexy = '{}'
739768 cython_input_args .append ('{} {}' .format (typ , ccode (a )))
@@ -772,6 +801,42 @@ def ufuncify_matrix(args, expr, const=None, tmp_dir=None, parallel=False,
772801
773802 workingdir = os .getcwd ()
774803 os .chdir (codedir )
804+ logger .info ('Changed directory to {}' .format (codedir ))
805+
806+ # NOTE : If there are other files present in the directory (will only occur
807+ # if a tmp_dir is set) then search through them starting with the most
808+ # recent and see if it has a matching hash to the evaluation code generated
809+ # here. If a match is found, store the module number.
810+ matching_module_num = None
811+ for prior_num in reversed (range (prior_module_number + 1 )):
812+ old_file_prefix = '{}_{}' .format (file_prefix_base , prior_num )
813+ logger .info (f'Checking { old_file_prefix } for cached code.' )
814+ try :
815+ with open (old_file_prefix + '_c.c' , 'r' ) as f :
816+ hash_line = f .readline ()
817+ logger .debug (hash_line .strip ())
818+ if 'opty_code_hash={}' .format (d ['eval_code_hash' ]) in hash_line :
819+ matching_module_num = prior_num
820+ logger .info (f'{ old_file_prefix } matches!' )
821+ break
822+ except FileNotFoundError :
823+ logger .debug (f'{ old_file_prefix } not found.' )
824+ pass
825+
826+ # NOTE : If we found a matching C file, then try to simply load that module
827+ # instead of compiling a new one. This lets us skip the compile step if we
828+ # have not changed anything about the model.
829+ if matching_module_num is not None :
830+ try :
831+ cython_module = importlib .import_module (old_file_prefix )
832+ except ImportError : # false positive, so compile a new one
833+ logger .info (f'Failed to import { old_file_prefix } .' )
834+ pass
835+ else :
836+ logger .info (f'Skipped compile, { old_file_prefix } module loaded.' )
837+ os .chdir (workingdir )
838+ logger .info (f'Changed directory to { workingdir } .' )
839+ return getattr (cython_module , d ['routine_name' ] + '_loop' )
775840
776841 try :
777842 sys .path .append (codedir )
@@ -797,6 +862,7 @@ def ufuncify_matrix(args, expr, const=None, tmp_dir=None, parallel=False,
797862 else :
798863 encoding = None
799864 try :
865+ logger .info ('Compiling matrix evaluation.' )
800866 proc = subprocess .run (cmd , capture_output = True , text = True ,
801867 encoding = encoding )
802868 # On Windows this can raise a UnicodeDecodeError, but only in the
@@ -811,8 +877,13 @@ def ufuncify_matrix(args, expr, const=None, tmp_dir=None, parallel=False,
811877 if show_compile_output :
812878 print (stdout )
813879 print (stderr )
880+ else :
881+ logger .debug (stdout )
882+ logger .debug (stderr )
883+
814884 try :
815885 cython_module = importlib .import_module (d ['file_prefix' ])
886+ logger .info ("Loaded {} module" .format (d ['file_prefix' ]))
816887 except ImportError as error :
817888 msg = ('Unable to import the compiled Cython module {}, '
818889 'compilation likely failed. STDERR output from '
@@ -827,6 +898,7 @@ def ufuncify_matrix(args, expr, const=None, tmp_dir=None, parallel=False,
827898 # so I don't delete the directory on Windows.
828899 if sys .platform != "win32" :
829900 shutil .rmtree (codedir )
901+ logger .info ('Removed directory {}' .format (codedir ))
830902
831903 return getattr (cython_module , d ['routine_name' ] + '_loop' )
832904
0 commit comments