Skip to content

Commit 72768b6

Browse files
committed
Add continuous scan and parallelization
1 parent 278e425 commit 72768b6

File tree

3 files changed

+122
-33
lines changed

3 files changed

+122
-33
lines changed

src/lcm/dispatchers.py

Lines changed: 112 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
import inspect
2+
import functools
3+
import jax
24
from collections.abc import Callable
5+
from functools import partial
36
from typing import Literal, TypeVar
47
from jax import Array, vmap
58
from jax import numpy as jnp
6-
from jax.lax import map
9+
from jax.lax import map, concatenate, scan
710
from lcm.functools import allow_args, allow_only_kwargs
811
from lcm.typing import ParamsDict
12+
from jax.sharding import Mesh, PartitionSpec as P
13+
from jax.experimental.shard_map import shard_map
914

1015
F = TypeVar("F", bound=Callable[..., Array])
1116

@@ -15,8 +20,6 @@ def spacemap(
1520
state_and_discrete_vars: dict[str:Array],
1621
continous_vars: dict[str:Array],
1722
memory_restriction: int,
18-
params: ParamsDict,
19-
vf_arr: Array,
2023
) -> F:
2124
"""
2225
Evaluate func along all state and discrete choice axes in a way that reduces the memory usage
@@ -47,11 +50,6 @@ def spacemap(
4750
A callable that evaluates func along the provided discrete choice and state axes.
4851
4952
"""
50-
# Check inputs and prepare function
51-
# ==================================================================================
52-
53-
# jax.vmap cannot deal with keyword-only arguments
54-
func = allow_args(func)
5553

5654
# I removed duplicate and overlap checks because we are passing dicts now
5755
# and overlap between state+discrete and continuous seems unlikely
@@ -71,10 +69,24 @@ def spacemap(
7169
memory_strat[key] = jnp.size(state_and_discrete_vars[key])
7270
else:
7371
memory_strat[key] = 1
74-
75-
mapped = _base_productmap_map(func, state_and_discrete_vars, continous_vars, memory_strat,params, vf_arr)
7672

77-
return mapped
73+
# Get the last entry, this should be a cont. state var.
74+
# This will be the axis we parallelize over multiple devices.
75+
*_, paralell_axis_name = state_and_discrete_vars.keys()
76+
paralell_axis = state_and_discrete_vars[paralell_axis_name]
77+
state_and_discrete_vars.pop(paralell_axis_name)
78+
# Construct function for axes that are not paralellized over devices
79+
mapped = _base_productmap_map(func, state_and_discrete_vars, memory_strat)
80+
# This is number of devices, could be settable by user
81+
# Does not work if len(axis) is not divisible by the number of devices
82+
num_of_devices = jax.device_count()
83+
mesh = jax.make_mesh((num_of_devices,), ('x',))
84+
# Create function so we can use shard_map
85+
def mapped_paralell(x, params,vf_arr ):
86+
res = map(lambda x_i: mapped(**{'params':params, 'vf_arr': vf_arr,paralell_axis_name:x_i}), x, batch_size=memory_strat[paralell_axis_name])
87+
return res
88+
sharded_mapped = shard_map(mapped_paralell,mesh,in_specs=(P('x'),P(),P()),out_specs=P('x'))
89+
return partial(sharded_mapped, paralell_axis)
7890

7991

8092
def vmap_1d(
@@ -146,7 +158,6 @@ def vmap_1d(
146158

147159
return out
148160

149-
150161
def productmap(func: F, variables: list[str]) -> F:
151162
"""Apply vmap such that func is evaluated on the Cartesian product of variables.
152163
@@ -187,9 +198,57 @@ def productmap(func: F, variables: list[str]) -> F:
187198
vmapped.__signature__ = signature # type: ignore[attr-defined]
188199

189200
return allow_only_kwargs(vmapped)
201+
def productmap_scan(
    func: Callable[..., Array],
    variables: dict[str, Array],
    *,
    memory_restriction: float = 8 * (10**9),
) -> Callable[..., Array]:
    """Maximize func over the Cartesian product of variables.

    The variables are split greedily, in dict order, into two groups. The
    budget ``memory_restriction`` is divided by the size of each variable's
    grid in turn; while the remaining budget stays above 1 the variable is
    vmapped, otherwise it is scanned. Because the budget only ever shrinks,
    every variable after the first scanned one is scanned as well.

    Args:
        func: The function to be dispatched. It must return a pair
            ``(u, f)``; the maximum of ``u`` is taken where ``f`` is true.
        variables: Dict mapping the names of the arguments over which the
            Cartesian product is formed to their grids.
        memory_restriction: Initial budget for the greedy split. Presumably
            a number of array elements rather than bytes -- TODO confirm.

    Returns:
        A callable that takes the remaining arguments of ``func`` and returns
        the maximum of ``u`` (restricted to entries where ``f`` holds, with
        ``-inf`` as the initial value) over the product of ``variables``.
        The grids themselves are already bound and must not be passed again.
        If at least one variable is scanned, the callable is the one built
        by ``_base_productmap_scan``.

    """
    # Greedy split of the variables into a vmapped and a scanned group.
    budget = memory_restriction
    vmapped_vars = {}
    scanned_vars = {}
    for name, grid in variables.items():
        budget = budget / jnp.size(grid)
        if budget > 1:
            vmapped_vars[name] = grid
        else:
            scanned_vars[name] = grid

    # Apply as many vmaps as possible.
    vmapped = _base_productmap(func=func, product_axes=list(vmapped_vars))
    vmapped = allow_only_kwargs(vmapped)

    # Reduce the vmapped axes right away so that only the maximum is kept.
    @functools.wraps(vmapped)
    def compute_ccv(*args, **kwargs):
        u, f = vmapped(*args, **kwargs)
        return u.max(where=f, initial=-jnp.inf)

    def partialled_compute_ccv(*args, **kwargs):
        # Bind the vmapped grids; callers only supply the remaining arguments.
        return compute_ccv(*args, **kwargs, **vmapped_vars)

    # Scan over all variables that could not be vmapped.
    if scanned_vars:
        return _base_productmap_scan(partialled_compute_ccv, scanned_vars)

    return partialled_compute_ccv
249+
250+
251+
def _base_productmap_map(func: F, state_and_discrete_vars: dict[str:Array], strat) -> F:
193252
"""Map func over the Cartesian product of state_and_discrete_vars.
194253
195254
All arguments needed for the evaluation of func are passed via keyword args upon creation,
@@ -206,17 +265,55 @@ def _base_productmap_map(func: F, state_and_discrete_vars: dict[str:Array], cont
206265
A callable that maps func over the provided axes.
207266
208267
"""
209-
mapped = lambda **vals : func(**continous_vars, vf_arr = vf_arr, params = params, **vals)
268+
mapped = func
210269
def stack_maps(func, var, axis):
211270
def one_more(**xs):
212271
return map(lambda x_i: func(**xs, **{var:x_i}), axis, batch_size=strat[var])
213272
return one_more
273+
214274
for key,value in reversed(state_and_discrete_vars.items()):
215275
mapped = stack_maps(mapped,key,value)
216276

217-
218277
return mapped
219278

279+
280+
def _base_productmap_scan(func: F, continous_vars: dict[str:Array]) -> F:
281+
"""Map func over the Cartesian product of state_and_discrete_vars.
282+
283+
All arguments needed for the evaluation of func are passed via keyword args upon creation,
284+
the returned callable takes no arguments.
285+
286+
Args:
287+
func: The function to be dispatched.
288+
state_and_discrete_vars: Dict of names and values for each discrete choice axis and state axis.
289+
continous_vars: Dict of names and values for each continous choice axis.
290+
params: Parameters of the Model.
291+
vf_arr: Discretized Value Function from previous period.
292+
293+
Returns:
294+
A callable that maps func over the provided axes.
295+
296+
"""
297+
# We want the maximum of all the values in the scan
298+
def loop(s, **x_is):
299+
u = func(**x_is)
300+
return jnp.maximum(u, s)
301+
302+
# induction case: scan over one argument, eliminating it
303+
def scan_one_more(loop, var,axis):
304+
def new_loop(s, **xs):
305+
s, _ = scan(lambda s, x_i: (loop(s, **xs, **{var:x_i}), None), s, axis)
306+
return s
307+
return new_loop
308+
309+
# compose
310+
for key,value in continous_vars.items():
311+
loop = scan_one_more(loop, key,value)
312+
313+
return partial(loop, s=-jnp.inf)
314+
315+
316+
220317
def _base_productmap(func: F, product_axes: list[Array]) -> F:
221318
"""Map func over the Cartesian product of product_axes.
222319
@@ -235,7 +332,6 @@ def _base_productmap(func: F, product_axes: list[Array]) -> F:
235332
parameters = list(signature.parameters)
236333

237334
positions = [parameters.index(ax) for ax in product_axes]
238-
239335
vmap_specs = []
240336
# We iterate in reverse order such that the output dimensions are in the same order
241337
# as the input dimensions.

src/lcm/entry_point.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from lcm.argmax import argmax
1010
from lcm.discrete_problem import get_solve_discrete_problem
11-
from lcm.dispatchers import productmap
11+
from lcm.dispatchers import productmap_scan, productmap
1212
from lcm.input_processing import process_model
1313
from lcm.logging import get_logger
1414
from lcm.model_functions import (
@@ -134,7 +134,7 @@ def get_lcm_function(
134134

135135
compute_ccv = create_compute_conditional_continuation_value(
136136
utility_and_feasibility=u_and_f,
137-
continuous_choice_variables=list(_choice_grids),
137+
continuous_choice_variables=_choice_grids,
138138
)
139139
compute_ccv_functions.append(compute_ccv)
140140

@@ -214,17 +214,11 @@ def create_compute_conditional_continuation_value(
214214
215215
"""
216216
if continuous_choice_variables:
217-
utility_and_feasibility = productmap(
217+
utility_and_feasibility = productmap_scan(
218218
func=utility_and_feasibility,
219219
variables=continuous_choice_variables,
220220
)
221-
222-
@functools.wraps(utility_and_feasibility)
223-
def compute_ccv(*args, **kwargs):
224-
u, f = utility_and_feasibility(*args, **kwargs)
225-
return u.max(where=f, initial=-jnp.inf)
226-
227-
return compute_ccv
221+
return utility_and_feasibility
228222

229223

230224
def create_compute_conditional_continuation_policy(

src/lcm/solve_brute.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import jax
2-
2+
import nvtx
33
from lcm.dispatchers import spacemap
44

55

@@ -54,8 +54,8 @@ def solve(
5454
vf_arr = None
5555

5656
logger.info("Starting solution")
57-
5857

58+
5959
# backwards induction loop
6060
for period in reversed(range(n_periods)):
6161
# solve continuous problem, conditional on discrete choices
@@ -67,7 +67,6 @@ def solve(
6767
state_indexers=state_indexers[period],
6868
params=params,
6969
)
70-
7170
# solve discrete problem by calculating expected maximum over discrete choices
7271
calculate_emax = emax_calculators[period]
7372
vf_arr = calculate_emax(conditional_continuation_values, params=params)
@@ -116,9 +115,9 @@ def solve_continuous_problem(
116115
state_and_discrete_vars=state_choice_space.dense_vars,
117116
continous_vars=continuous_choice_grids,
118117
memory_restriction=8*(10**9),
119-
vf_arr=vf_arr,
120-
params=params
118+
121119
)
122120
gridmapped = jax.jit(_gridmapped)
123-
124-
return gridmapped()
121+
res = gridmapped(params, vf_arr)
122+
res = jax.numpy.moveaxis(res,(0),(-1))
123+
return res

0 commit comments

Comments
 (0)