|
1 | 1 | import inspect
|
2 | 2 | from collections.abc import Callable
|
3 | 3 | from typing import Literal, TypeVar
|
4 |
| - |
5 | 4 | from jax import Array, vmap
|
6 |
| - |
| 5 | +from jax import numpy as jnp |
| 6 | +from jax.lax import map |
7 | 7 | from lcm.functools import allow_args, allow_only_kwargs
|
| 8 | +from lcm.typing import ParamsDict |
8 | 9 |
|
9 | 10 | F = TypeVar("F", bound=Callable[..., Array])
|
10 | 11 |
|
11 | 12 |
|
12 | 13 | def spacemap(
|
13 | 14 | func: F,
|
14 |
| - dense_vars: list[str], |
15 |
| - sparse_vars: list[str], |
16 |
| - *, |
17 |
| - put_dense_first: bool, |
| 15 | + state_and_discrete_vars: dict[str:Array], |
| 16 | + continous_vars: dict[str:Array], |
| 17 | + memory_restriction: int, |
| 18 | + params: ParamsDict, |
| 19 | + vf_arr: Array, |
18 | 20 | ) -> F:
|
19 |
| - """Apply vmap such that func is evaluated on a space of dense and sparse variables. |
| 21 | + """ |
| 22 | + Evaluate func along all state and discrete choice axes in a way that reduces the memory usage |
| 23 | + below a preset value, trying to find a balance between speed and memory usage. |
| 24 | +
|
| 25 | + If the dimension of the state-choice space fits into memory this will be the same as vmapping along all axes. |
| 26 | + Otherwise the computation along the outermost state axes will be serialized until the remeining problem fits |
| 27 | + into memory. |
20 | 28 |
|
21 |
| - This is achieved by applying _base_productmap for all dense variables and vmap_1d |
22 |
| - for the sparse variables. |
| 29 | + This only works if the continous part of the model fits into memory, as this part has already been vmapped in func. |
| 30 | + To serialize parts of the continous axes one would have to write a more complicated function using scan that |
| 31 | + replicates the behaviour of map with the batch size parameter. For models with many continous axes there |
| 32 | + might be better ways to find the maximum along those axes than evaluating func on all points. |
23 | 33 |
|
24 |
| - spacemap preserves the function signature and allows the function to be called with |
25 |
| - keyword arguments. |
26 | 34 |
|
27 | 35 | Args:
|
28 | 36 | func: The function to be dispatched.
|
29 |
| - dense_vars: Names of the dense variables, i.e. those that are stored as arrays |
30 |
| - of possible values in the grid. |
31 |
| - sparse_vars: Names of the sparse variables, i.e. those that are stored as arrays |
32 |
| - of possible combinations of variables in the grid. |
33 |
| - put_dense_first: Whether the dense or sparse dimensions should come first in the |
34 |
| - output of the dispatched function. |
| 37 | + state_and_discrete_vars: Dict of names and values for each discrete choice axis and state axis. |
| 38 | + continous_vars: Dict of names and values for each continous choice axis. |
| 39 | + memory_restriction: Maximum allowed memory usage of the vmap in Bytes. Maybe the user should be able to set this. |
| 40 | + Could also be grabbed through jax, but then we would have to set the limit very low |
| 41 | + and users would not be able to overwrite it for better performance. |
| 42 | + params: Parameters of the Model. |
| 43 | + vf_arr: Discretized Value Function from previous period. |
35 | 44 |
|
36 | 45 |
|
37 | 46 | Returns:
|
38 |
| - A callable with the same arguments as func (but with an additional leading |
39 |
| - dimension) that returns a jax.numpy.ndarray or pytree of arrays. If ``func`` |
40 |
| - returns a scalar, the dispatched function returns a jax.numpy.ndarray with 1 |
41 |
| - jax.numpy.ndarray with k + 1 dimensions, where k is the length of ``dense_vars`` |
42 |
| - and the additional dimension corresponds to the ``sparse_vars``. The order of |
43 |
| - the dimensions is determined by the order of ``dense_vars`` as well as the |
44 |
| - ``put_dense_first`` argument. If the output of ``func`` is a jax pytree, the |
45 |
| - usual jax behavior applies, i.e. the leading dimensions of all arrays in the |
46 |
| - pytree are as described above but there might be additional dimensions. |
| 47 | + A callable that evaluates func along the provided dicrete choice and state axes. |
47 | 48 |
|
48 | 49 | """
|
49 | 50 | # Check inputs and prepare function
|
50 | 51 | # ==================================================================================
|
51 |
| - overlap = set(dense_vars).intersection(sparse_vars) |
52 |
| - if overlap: |
53 |
| - raise ValueError( |
54 |
| - f"Dense and sparse variables must be disjoint. Overlap: {overlap}", |
55 |
| - ) |
56 |
| - |
57 |
| - duplicates = {v for v in dense_vars if dense_vars.count(v) > 1} |
58 |
| - if duplicates: |
59 |
| - raise ValueError( |
60 |
| - f"Same argument provided more than once in dense variables: {duplicates}", |
61 |
| - ) |
62 |
| - |
63 |
| - duplicates = {v for v in sparse_vars if sparse_vars.count(v) > 1} |
64 |
| - if duplicates: |
65 |
| - raise ValueError( |
66 |
| - f"Same argument provided more than once in sparse variables: {duplicates}", |
67 |
| - ) |
68 |
| - |
| 52 | + |
69 | 53 | # jax.vmap cannot deal with keyword-only arguments
|
70 | 54 | func = allow_args(func)
|
71 |
| - |
72 |
| - # Apply vmap_1d for sparse and _base_productmap for dense variables |
73 |
| - # ================================================================================== |
74 |
| - if not sparse_vars: |
75 |
| - vmapped = _base_productmap(func, dense_vars) |
76 |
| - elif put_dense_first: |
77 |
| - vmapped = vmap_1d(func, variables=sparse_vars, callable_with="only_args") |
78 |
| - vmapped = _base_productmap(vmapped, dense_vars) |
79 |
| - else: |
80 |
| - vmapped = _base_productmap(func, dense_vars) |
81 |
| - vmapped = vmap_1d(vmapped, variables=sparse_vars, callable_with="only_args") |
82 |
| - |
83 |
| - # This raises a mypy error but is perfectly fine to do. See |
84 |
| - # https://github.com/python/mypy/issues/12472 |
85 |
| - vmapped.__signature__ = inspect.signature(func) # type: ignore[attr-defined] |
86 |
| - |
87 |
| - return allow_only_kwargs(vmapped) |
| 55 | + |
| 56 | + # I removed duplicate and overlap checks because we are passing dicts now |
| 57 | + # and overlap between state+dicrte and continous seems unlikely |
| 58 | + |
| 59 | + # Set the batch size parameters for the stacked maps, controlling the degree of serialization. |
| 60 | + # Checks if vmapping along given axis is possible, starting from the innermost discrete choice axis. |
| 61 | + # If vmapping is possible batch size is set to len(axis) because vamp=map if batch size=len(axis), |
| 62 | + # otherwise batchsize is set to 1 serializing the evaluation of func along this axis. |
| 63 | + |
| 64 | + memory_strat = {} |
| 65 | + memory_restriction = (memory_restriction/4)/2 |
| 66 | + for key in continous_vars.keys(): |
| 67 | + memory_restriction = memory_restriction / jnp.size(continous_vars[key]) |
| 68 | + for key in state_and_discrete_vars.keys(): |
| 69 | + memory_restriction = memory_restriction/jnp.size(state_and_discrete_vars[key]) |
| 70 | + if memory_restriction > 1: |
| 71 | + memory_strat[key] = jnp.size(state_and_discrete_vars[key]) |
| 72 | + else: |
| 73 | + memory_strat[key] = 1 |
| 74 | + |
| 75 | + mapped = _base_productmap_map(func, state_and_discrete_vars, continous_vars, memory_strat,params, vf_arr) |
| 76 | + |
| 77 | + return mapped |
88 | 78 |
|
89 | 79 |
|
90 | 80 | def vmap_1d(
|
@@ -199,7 +189,35 @@ def productmap(func: F, variables: list[str]) -> F:
|
199 | 189 | return allow_only_kwargs(vmapped)
|
200 | 190 |
|
201 | 191 |
|
202 |
| -def _base_productmap(func: F, product_axes: list[str]) -> F: |
| 192 | +def _base_productmap_map(func: F, state_and_discrete_vars: dict[str:Array], continous_vars: dict[str:Array], strat, params, vf_arr) -> F: |
| 193 | + """Map func over the Cartesian product of state_and_discrete_vars. |
| 194 | +
|
| 195 | + All arguments needed for the evaluation of func are passed via keyword args upon creation, |
| 196 | + the returned callable takes no arguments. |
| 197 | +
|
| 198 | + Args: |
| 199 | + func: The function to be dispatched. |
| 200 | + state_and_discrete_vars: Dict of names and values for each discrete choice axis and state axis. |
| 201 | + continous_vars: Dict of names and values for each continous choice axis. |
| 202 | + params: Parameters of the Model. |
| 203 | + vf_arr: Discretized Value Function from previous period. |
| 204 | +
|
| 205 | + Returns: |
| 206 | + A callable that maps func over the provided axes. |
| 207 | +
|
| 208 | + """ |
| 209 | + mapped = lambda **vals : func(**continous_vars, vf_arr = vf_arr, params = params, **vals) |
| 210 | + def stack_maps(func, var, axis): |
| 211 | + def one_more(**xs): |
| 212 | + return map(lambda x_i: func(**xs, **{var:x_i}), axis, batch_size=strat[var]) |
| 213 | + return one_more |
| 214 | + for key,value in reversed(state_and_discrete_vars.items()): |
| 215 | + mapped = stack_maps(mapped,key,value) |
| 216 | + |
| 217 | + |
| 218 | + return mapped |
| 219 | + |
| 220 | +def _base_productmap(func: F, product_axes: list[Array]) -> F: |
203 | 221 | """Map func over the Cartesian product of product_axes.
|
204 | 222 |
|
205 | 223 | Like vmap, this function does not preserve the function signature and does not allow
|
|
0 commit comments