Add continuous scan and parallelization
mj023 committed Feb 4, 2025
1 parent 278e425 commit 72768b6
Showing 3 changed files with 122 additions and 33 deletions.
128 changes: 112 additions & 16 deletions src/lcm/dispatchers.py
@@ -1,11 +1,16 @@
import inspect
import functools
import jax
from collections.abc import Callable
from functools import partial
from typing import Literal, TypeVar
from jax import Array, vmap
from jax import numpy as jnp
from jax.lax import map
from jax.lax import map, concatenate, scan
from lcm.functools import allow_args, allow_only_kwargs
from lcm.typing import ParamsDict
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

F = TypeVar("F", bound=Callable[..., Array])

@@ -15,8 +20,6 @@ def spacemap(
    state_and_discrete_vars: dict[str, Array],
    continous_vars: dict[str, Array],
memory_restriction: int,
params: ParamsDict,
vf_arr: Array,
) -> F:
"""
    Evaluate func along all state and discrete choice axes in a way that limits memory usage.
@@ -47,11 +50,6 @@
        A callable that evaluates func along the provided discrete choice and state axes.
"""
# Check inputs and prepare function
# ==================================================================================

# jax.vmap cannot deal with keyword-only arguments
func = allow_args(func)

    # Duplicate and overlap checks were removed because we are passing dicts now,
    # and overlap between state/discrete and continuous variables seems unlikely.
@@ -71,10 +69,24 @@
memory_strat[key] = jnp.size(state_and_discrete_vars[key])
else:
memory_strat[key] = 1

mapped = _base_productmap_map(func, state_and_discrete_vars, continous_vars, memory_strat,params, vf_arr)

return mapped
    # Get the last entry; this should be a continuous state variable.
    # Its axis is the one we parallelize over multiple devices.
    *_, parallel_axis_name = state_and_discrete_vars.keys()
    parallel_axis = state_and_discrete_vars[parallel_axis_name]
    state_and_discrete_vars.pop(parallel_axis_name)
    # Construct the function for the axes that are not parallelized over devices.
    mapped = _base_productmap_map(func, state_and_discrete_vars, memory_strat)
    # Use all available devices; this could be made settable by the user.
    # Note: this does not work if the axis length is not divisible by the
    # number of devices.
    num_of_devices = jax.device_count()
    mesh = jax.make_mesh((num_of_devices,), ("x",))

    # Wrap the mapped function so it can be passed to shard_map.
    def mapped_parallel(x, params, vf_arr):
        return map(
            lambda x_i: mapped(**{"params": params, "vf_arr": vf_arr, parallel_axis_name: x_i}),
            x,
            batch_size=memory_strat[parallel_axis_name],
        )

    sharded_mapped = shard_map(mapped_parallel, mesh, in_specs=(P("x"), P(), P()), out_specs=P("x"))
    return partial(sharded_mapped, parallel_axis)
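

# Hedged sketch (editorial, not part of the commit): the shard_map pattern used
# above, in isolation. A 1-d device mesh splits the leading axis of the input
# across devices and each device runs jax.lax.map over its local shard. Assumes
# the axis length is divisible by jax.device_count().
def _example_shard_map_pattern():
    mesh = jax.make_mesh((jax.device_count(),), ("x",))

    def on_shard(x):  # x is the per-device shard of the leading axis
        return map(jnp.sin, x)

    sharded = shard_map(on_shard, mesh, in_specs=(P("x"),), out_specs=P("x"))
    return sharded(jnp.arange(8.0))  # same values as jnp.sin(jnp.arange(8.0))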


def vmap_1d(
@@ -146,7 +158,6 @@ def vmap_1d(

return out


def productmap(func: F, variables: list[str]) -> F:
"""Apply vmap such that func is evaluated on the Cartesian product of variables.
@@ -187,9 +198,57 @@ def productmap(func: F, variables: list[str]) -> F:
vmapped.__signature__ = signature # type: ignore[attr-defined]

return allow_only_kwargs(vmapped)
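

# Hypothetical usage sketch (editorial, not part of the commit): productmap
# evaluates a function on the Cartesian product of the named argument grids,
# with one output dimension per variable, ordered as in ``variables``.
def _example_productmap():
    f = productmap(lambda a, b: a * b, ["a", "b"])
    return f(a=jnp.arange(2.0), b=jnp.arange(3.0))  # shape (2, 3)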
def productmap_scan(func: F, variables: dict[str, Array]) -> F:
"""Apply the function to the cartesian product of variables.
If it is possible to apply vmap while respecting the memory restriction,
use stacked vmaps, otherwise start scanning over the variable axes.
Args:
func: The function to be dispatched.
variables: List with names of arguments that over which the Cartesian product
should be formed.
Returns:
A callable with the same arguments as func (but with an additional leading
dimension) that returns a jax.numpy.ndarray or pytree of arrays. If ``func``
returns a scalar, the dispatched function returns a jax.numpy.ndarray with k
dimensions, where k is the length of ``variables``. The order of the dimensions
is determined by the order of ``variables`` which can be different to the order
of ``funcs`` arguments. If the output of ``func`` is a jax pytree, the usual jax
behavior applies, i.e. the leading dimensions of all arrays in the pytree are as
described above but there might be additional dimensions.
"""

def _base_productmap_map(func: F, state_and_discrete_vars: dict[str:Array], continous_vars: dict[str:Array], strat, params, vf_arr) -> F:
    # Memory budget in bytes (currently hard-coded to 8 GB).
    memory_restriction = 8 * (10**9)
    vmapped_vars = {}
    scanned_vars = {}
    for key in variables:
        memory_restriction = memory_restriction / jnp.size(variables[key])
        if memory_restriction > 1:
            vmapped_vars[key] = variables[key]
        else:
            scanned_vars[key] = variables[key]
    # Apply as many vmaps as possible.
    func = _base_productmap(func=func, product_axes=list(vmapped_vars))
    func = allow_only_kwargs(func)

    # Create the conditional continuation value function, taking the feasible
    # maximum of utility over the vmapped variables.
    @functools.wraps(func)
    def compute_ccv(*args, **kwargs):
        u, f = func(*args, **kwargs)
        return u.max(where=f, initial=-jnp.inf)

    partialled_compute_ccv = lambda *args, **kwargs: compute_ccv(*args, **kwargs, **vmapped_vars)

    # Apply scan for all variables that could not be vmapped.
    if scanned_vars:
        return _base_productmap_scan(partialled_compute_ccv, scanned_vars)

    return partialled_compute_ccv
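

# Hypothetical usage sketch (editorial, not part of the commit): dispatch a toy
# utility over one continuous choice grid. The grid is small, so the 8 GB budget
# is never exhausted and the axis is vmapped rather than scanned; the result is
# the feasible maximum over the grid. All names here are made up.
def _example_productmap_scan():
    def u_and_f(wealth, consumption):
        return jnp.log(consumption), consumption <= wealth

    ccv = productmap_scan(u_and_f, {"consumption": jnp.linspace(1.0, 10.0, 50)})
    return ccv(wealth=jnp.asarray(5.0))  # log of the largest feasible consumption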


def _base_productmap_map(func: F, state_and_discrete_vars: dict[str, Array], strat) -> F:
"""Map func over the Cartesian product of state_and_discrete_vars.
All arguments needed for the evaluation of func are passed via keyword args upon creation,
@@ -206,17 +265,55 @@ def _base_productmap_map(func: F, state_and_discrete_vars: dict[str:Array], cont
A callable that maps func over the provided axes.
"""
mapped = lambda **vals : func(**continous_vars, vf_arr = vf_arr, params = params, **vals)
mapped = func
    def stack_maps(func, var, axis):
        """Wrap func in one more jax.lax.map over the axis of var."""

        def one_more(**xs):
            return map(lambda x_i: func(**xs, **{var: x_i}), axis, batch_size=strat[var])

        return one_more

    # Stack one map per axis, innermost axis first, so the output dimensions
    # follow the order of state_and_discrete_vars.
    for key, value in reversed(state_and_discrete_vars.items()):
        mapped = stack_maps(mapped, key, value)

    return mapped
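

# Hedged toy example (editorial, not part of the commit): two stacked
# jax.lax.map calls, as built by stack_maps above. Peak memory scales with the
# batch sizes rather than with the full 4 x 3 grid.
def _example_stacked_maps():
    axis_a = jnp.arange(4.0)
    axis_b = jnp.arange(3.0)

    def f(a, b):
        return a + b

    inner = lambda a: map(lambda b_i: f(a, b_i), axis_b, batch_size=3)
    return map(inner, axis_a, batch_size=2)  # shape (4, 3)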


def _base_productmap_scan(func: F, continous_vars: dict[str, Array]) -> F:
"""Map func over the Cartesian product of state_and_discrete_vars.
All arguments needed for the evaluation of func are passed via keyword args upon creation,
the returned callable takes no arguments.
Args:
func: The function to be dispatched.
state_and_discrete_vars: Dict of names and values for each discrete choice axis and state axis.
continous_vars: Dict of names and values for each continous choice axis.
params: Parameters of the Model.
vf_arr: Discretized Value Function from previous period.
Returns:
A callable that maps func over the provided axes.
"""
    # We want the maximum of all values produced in the scan.
    def loop(s, **x_is):
        u = func(**x_is)
        return jnp.maximum(u, s)

    # Induction case: scan over one argument, eliminating it.
    def scan_one_more(loop, var, axis):
        def new_loop(s, **xs):
            s, _ = scan(lambda s, x_i: (loop(s, **xs, **{var: x_i}), None), s, axis)
            return s

        return new_loop

    # Compose one scan per axis.
    for key, value in continous_vars.items():
        loop = scan_one_more(loop, key, value)

    return partial(loop, s=-jnp.inf)
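

# Hedged toy example (editorial, not part of the commit): the running-maximum
# scan used above, for a single axis. The full grid of function values is never
# materialized; only the scalar carry lives in memory.
def _example_scan_max():
    grid = jnp.linspace(0.0, 1.0, 101)

    def step(s, x_i):
        return jnp.maximum(s, -((x_i - 0.3) ** 2)), None

    best, _ = scan(step, jnp.asarray(-jnp.inf), grid)
    return best  # max of -(x - 0.3)**2 over the grid, i.e. approximately 0.0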



def _base_productmap(func: F, product_axes: list[Array]) -> F:
"""Map func over the Cartesian product of product_axes.
@@ -235,7 +332,6 @@ def _base_productmap(func: F, product_axes: list[Array]) -> F:
parameters = list(signature.parameters)

positions = [parameters.index(ax) for ax in product_axes]

vmap_specs = []
# We iterate in reverse order such that the output dimensions are in the same order
# as the input dimensions.
14 changes: 4 additions & 10 deletions src/lcm/entry_point.py
@@ -8,7 +8,7 @@

from lcm.argmax import argmax
from lcm.discrete_problem import get_solve_discrete_problem
from lcm.dispatchers import productmap
from lcm.dispatchers import productmap_scan, productmap
from lcm.input_processing import process_model
from lcm.logging import get_logger
from lcm.model_functions import (
@@ -134,7 +134,7 @@ def get_lcm_function(

compute_ccv = create_compute_conditional_continuation_value(
utility_and_feasibility=u_and_f,
continuous_choice_variables=list(_choice_grids),
continuous_choice_variables=_choice_grids,
)
compute_ccv_functions.append(compute_ccv)

@@ -214,17 +214,11 @@ def create_compute_conditional_continuation_value(
"""
if continuous_choice_variables:
utility_and_feasibility = productmap(
utility_and_feasibility = productmap_scan(
func=utility_and_feasibility,
variables=continuous_choice_variables,
)
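        # Editorial note: productmap_scan already takes the feasible maximum of
        # utility over the continuous choice grids, so the result can be
        # returned directly, without a separate max wrapper.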

@functools.wraps(utility_and_feasibility)
def compute_ccv(*args, **kwargs):
u, f = utility_and_feasibility(*args, **kwargs)
return u.max(where=f, initial=-jnp.inf)

return compute_ccv
return utility_and_feasibility


def create_compute_conditional_continuation_policy(
13 changes: 6 additions & 7 deletions src/lcm/solve_brute.py
@@ -1,5 +1,5 @@
import jax

import nvtx
from lcm.dispatchers import spacemap


@@ -54,8 +54,8 @@ def solve(
vf_arr = None

logger.info("Starting solution")



# backwards induction loop
for period in reversed(range(n_periods)):
# solve continuous problem, conditional on discrete choices
@@ -67,7 +67,6 @@
state_indexers=state_indexers[period],
params=params,
)

# solve discrete problem by calculating expected maximum over discrete choices
calculate_emax = emax_calculators[period]
vf_arr = calculate_emax(conditional_continuation_values, params=params)
Expand Down Expand Up @@ -116,9 +115,9 @@ def solve_continuous_problem(
state_and_discrete_vars=state_choice_space.dense_vars,
continous_vars=continuous_choice_grids,
memory_restriction=8*(10**9),
vf_arr=vf_arr,
params=params

)
gridmapped = jax.jit(_gridmapped)

return gridmapped()
    res = gridmapped(params, vf_arr)
    # The parallelized state axis comes first after shard_map; move it back to
    # the last position so the output has the original axis order.
    res = jax.numpy.moveaxis(res, 0, -1)
    return res
