Skip to content

Commit bc323fc

Browse files
committed
Numpy actx: cache execuctor
1 parent 510dc1b commit bc323fc

File tree

2 files changed

+18
-11
lines changed

2 files changed

+18
-11
lines changed

arraycontext/context.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ def to_numpy(self,
339339

340340
@abstractmethod
341341
def call_loopy(self,
342-
program: "loopy.TranslationUnit",
342+
t_unit: "loopy.TranslationUnit",
343343
**kwargs: Any) -> Dict[str, Array]:
344344
"""Execute the :mod:`loopy` program *program* on the arguments
345345
*kwargs*.

arraycontext/impl/numpy/__init__.py

+17-10
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from __future__ import annotations
2+
3+
14
"""
25
.. currentmodule:: arraycontext
36
@@ -30,7 +33,7 @@
3033
THE SOFTWARE.
3134
"""
3235

33-
from typing import Any, Dict
36+
from typing import Any
3437

3538
import numpy as np
3639

@@ -39,6 +42,7 @@
3942

4043
from arraycontext.container.traversal import rec_map_array_container, with_array_context
4144
from arraycontext.context import (
45+
Array,
4246
ArrayContext,
4347
ArrayOrContainerOrScalar,
4448
ArrayOrContainerOrScalarT,
@@ -62,10 +66,12 @@ class NumpyArrayContext(ArrayContext):
6266
6367
.. automethod:: __init__
6468
"""
69+
70+
_loopy_transform_cache: dict[lp.TranslationUnit, lp.ExecutorBase]
71+
6572
def __init__(self) -> None:
6673
super().__init__()
67-
self._loopy_transform_cache: \
68-
Dict[lp.TranslationUnit, lp.TranslationUnit] = {}
74+
self._loopy_transform_cache = {}
6975

7076
array_types = (NumpyNonObjectArray,)
7177

@@ -88,17 +94,18 @@ def to_numpy(self,
8894
) -> NumpyOrContainerOrScalar:
8995
return array
9096

91-
def call_loopy(self, t_unit, **kwargs):
97+
def call_loopy(
98+
self,
99+
t_unit: lp.TranslationUnit, **kwargs: Any
100+
) -> dict[str, Array]:
92101
t_unit = t_unit.copy(target=lp.ExecutableCTarget())
93102
try:
94-
t_unit = self._loopy_transform_cache[t_unit]
103+
executor = self._loopy_transform_cache[t_unit]
95104
except KeyError:
96-
orig_t_unit = t_unit
97-
t_unit = self.transform_loopy_program(t_unit)
98-
self._loopy_transform_cache[orig_t_unit] = t_unit
99-
del orig_t_unit
105+
executor = self.transform_loopy_program(t_unit).executor()
106+
self._loopy_transform_cache[t_unit] = executor
100107

101-
_, result = t_unit(**kwargs)
108+
_, result = executor(**kwargs)
102109

103110
return result
104111

0 commit comments

Comments
 (0)