Skip to content

Commit 278e425

Browse files
committed
Exchange vmap for map
1 parent e95281f commit 278e425

File tree

2 files changed

+87
-73
lines changed

2 files changed

+87
-73
lines changed

src/lcm/dispatchers.py

Lines changed: 80 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,90 +1,80 @@
11
import inspect
22
from collections.abc import Callable
33
from typing import Literal, TypeVar
4-
54
from jax import Array, vmap
6-
5+
from jax import numpy as jnp
6+
from jax.lax import map
77
from lcm.functools import allow_args, allow_only_kwargs
8+
from lcm.typing import ParamsDict
89

910
F = TypeVar("F", bound=Callable[..., Array])
1011

1112

1213
def spacemap(
1314
func: F,
14-
dense_vars: list[str],
15-
sparse_vars: list[str],
16-
*,
17-
put_dense_first: bool,
15+
state_and_discrete_vars: dict[str:Array],
16+
continous_vars: dict[str:Array],
17+
memory_restriction: int,
18+
params: ParamsDict,
19+
vf_arr: Array,
1820
) -> F:
19-
"""Apply vmap such that func is evaluated on a space of dense and sparse variables.
21+
"""
22+
Evaluate func along all state and discrete choice axes in a way that reduces the memory usage
23+
below a preset value, trying to find a balance between speed and memory usage.
24+
25+
If the dimension of the state-choice space fits into memory this will be the same as vmapping along all axes.
26+
Otherwise the computation along the outermost state axes will be serialized until the remeining problem fits
27+
into memory.
2028
21-
This is achieved by applying _base_productmap for all dense variables and vmap_1d
22-
for the sparse variables.
29+
This only works if the continous part of the model fits into memory, as this part has already been vmapped in func.
30+
To serialize parts of the continous axes one would have to write a more complicated function using scan that
31+
replicates the behaviour of map with the batch size parameter. For models with many continous axes there
32+
might be better ways to find the maximum along those axes than evaluating func on all points.
2333
24-
spacemap preserves the function signature and allows the function to be called with
25-
keyword arguments.
2634
2735
Args:
2836
func: The function to be dispatched.
29-
dense_vars: Names of the dense variables, i.e. those that are stored as arrays
30-
of possible values in the grid.
31-
sparse_vars: Names of the sparse variables, i.e. those that are stored as arrays
32-
of possible combinations of variables in the grid.
33-
put_dense_first: Whether the dense or sparse dimensions should come first in the
34-
output of the dispatched function.
37+
state_and_discrete_vars: Dict of names and values for each discrete choice axis and state axis.
38+
continous_vars: Dict of names and values for each continous choice axis.
39+
memory_restriction: Maximum allowed memory usage of the vmap in Bytes. Maybe the user should be able to set this.
40+
Could also be grabbed through jax, but then we would have to set the limit very low
41+
and users would not be able to overwrite it for better performance.
42+
params: Parameters of the Model.
43+
vf_arr: Discretized Value Function from previous period.
3544
3645
3746
Returns:
38-
A callable with the same arguments as func (but with an additional leading
39-
dimension) that returns a jax.numpy.ndarray or pytree of arrays. If ``func``
40-
returns a scalar, the dispatched function returns a jax.numpy.ndarray with 1
41-
jax.numpy.ndarray with k + 1 dimensions, where k is the length of ``dense_vars``
42-
and the additional dimension corresponds to the ``sparse_vars``. The order of
43-
the dimensions is determined by the order of ``dense_vars`` as well as the
44-
``put_dense_first`` argument. If the output of ``func`` is a jax pytree, the
45-
usual jax behavior applies, i.e. the leading dimensions of all arrays in the
46-
pytree are as described above but there might be additional dimensions.
47+
A callable that evaluates func along the provided dicrete choice and state axes.
4748
4849
"""
4950
# Check inputs and prepare function
5051
# ==================================================================================
51-
overlap = set(dense_vars).intersection(sparse_vars)
52-
if overlap:
53-
raise ValueError(
54-
f"Dense and sparse variables must be disjoint. Overlap: {overlap}",
55-
)
56-
57-
duplicates = {v for v in dense_vars if dense_vars.count(v) > 1}
58-
if duplicates:
59-
raise ValueError(
60-
f"Same argument provided more than once in dense variables: {duplicates}",
61-
)
62-
63-
duplicates = {v for v in sparse_vars if sparse_vars.count(v) > 1}
64-
if duplicates:
65-
raise ValueError(
66-
f"Same argument provided more than once in sparse variables: {duplicates}",
67-
)
68-
52+
6953
# jax.vmap cannot deal with keyword-only arguments
7054
func = allow_args(func)
71-
72-
# Apply vmap_1d for sparse and _base_productmap for dense variables
73-
# ==================================================================================
74-
if not sparse_vars:
75-
vmapped = _base_productmap(func, dense_vars)
76-
elif put_dense_first:
77-
vmapped = vmap_1d(func, variables=sparse_vars, callable_with="only_args")
78-
vmapped = _base_productmap(vmapped, dense_vars)
79-
else:
80-
vmapped = _base_productmap(func, dense_vars)
81-
vmapped = vmap_1d(vmapped, variables=sparse_vars, callable_with="only_args")
82-
83-
# This raises a mypy error but is perfectly fine to do. See
84-
# https://github.com/python/mypy/issues/12472
85-
vmapped.__signature__ = inspect.signature(func) # type: ignore[attr-defined]
86-
87-
return allow_only_kwargs(vmapped)
55+
56+
# I removed duplicate and overlap checks because we are passing dicts now
57+
# and overlap between state+dicrte and continous seems unlikely
58+
59+
# Set the batch size parameters for the stacked maps, controlling the degree of serialization.
60+
# Checks if vmapping along given axis is possible, starting from the innermost discrete choice axis.
61+
# If vmapping is possible batch size is set to len(axis) because vamp=map if batch size=len(axis),
62+
# otherwise batchsize is set to 1 serializing the evaluation of func along this axis.
63+
64+
memory_strat = {}
65+
memory_restriction = (memory_restriction/4)/2
66+
for key in continous_vars.keys():
67+
memory_restriction = memory_restriction / jnp.size(continous_vars[key])
68+
for key in state_and_discrete_vars.keys():
69+
memory_restriction = memory_restriction/jnp.size(state_and_discrete_vars[key])
70+
if memory_restriction > 1:
71+
memory_strat[key] = jnp.size(state_and_discrete_vars[key])
72+
else:
73+
memory_strat[key] = 1
74+
75+
mapped = _base_productmap_map(func, state_and_discrete_vars, continous_vars, memory_strat,params, vf_arr)
76+
77+
return mapped
8878

8979

9080
def vmap_1d(
@@ -199,7 +189,35 @@ def productmap(func: F, variables: list[str]) -> F:
199189
return allow_only_kwargs(vmapped)
200190

201191

202-
def _base_productmap(func: F, product_axes: list[str]) -> F:
192+
def _base_productmap_map(func: F, state_and_discrete_vars: dict[str:Array], continous_vars: dict[str:Array], strat, params, vf_arr) -> F:
193+
"""Map func over the Cartesian product of state_and_discrete_vars.
194+
195+
All arguments needed for the evaluation of func are passed via keyword args upon creation,
196+
the returned callable takes no arguments.
197+
198+
Args:
199+
func: The function to be dispatched.
200+
state_and_discrete_vars: Dict of names and values for each discrete choice axis and state axis.
201+
continous_vars: Dict of names and values for each continous choice axis.
202+
params: Parameters of the Model.
203+
vf_arr: Discretized Value Function from previous period.
204+
205+
Returns:
206+
A callable that maps func over the provided axes.
207+
208+
"""
209+
mapped = lambda **vals : func(**continous_vars, vf_arr = vf_arr, params = params, **vals)
210+
def stack_maps(func, var, axis):
211+
def one_more(**xs):
212+
return map(lambda x_i: func(**xs, **{var:x_i}), axis, batch_size=strat[var])
213+
return one_more
214+
for key,value in reversed(state_and_discrete_vars.items()):
215+
mapped = stack_maps(mapped,key,value)
216+
217+
218+
return mapped
219+
220+
def _base_productmap(func: F, product_axes: list[Array]) -> F:
203221
"""Map func over the Cartesian product of product_axes.
204222
205223
Like vmap, this function does not preserve the function signature and does not allow

src/lcm/solve_brute.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def solve(
5454
vf_arr = None
5555

5656
logger.info("Starting solution")
57+
5758

5859
# backwards induction loop
5960
for period in reversed(range(n_periods)):
@@ -112,17 +113,12 @@ def solve_continuous_problem(
112113
"""
113114
_gridmapped = spacemap(
114115
func=compute_ccv,
115-
dense_vars=list(state_choice_space.dense_vars),
116-
sparse_vars=list(state_choice_space.sparse_vars),
117-
put_dense_first=False,
116+
state_and_discrete_vars=state_choice_space.dense_vars,
117+
continous_vars=continuous_choice_grids,
118+
memory_restriction=8*(10**9),
119+
vf_arr=vf_arr,
120+
params=params
118121
)
119122
gridmapped = jax.jit(_gridmapped)
120123

121-
return gridmapped(
122-
**state_choice_space.dense_vars,
123-
**continuous_choice_grids,
124-
**state_choice_space.sparse_vars,
125-
**state_indexers,
126-
vf_arr=vf_arr,
127-
params=params,
128-
)
124+
return gridmapped()

0 commit comments

Comments
 (0)