 import inspect
+import functools
+import jax
 from collections.abc import Callable
+from functools import partial
 from typing import Literal, TypeVar
 from jax import Array, vmap
 from jax import numpy as jnp
-from jax.lax import map
+from jax.lax import map, concatenate, scan
 from lcm.functools import allow_args, allow_only_kwargs
 from lcm.typing import ParamsDict
+from jax.sharding import Mesh, PartitionSpec as P
+from jax.experimental.shard_map import shard_map

 F = TypeVar("F", bound=Callable[..., Array])

@@ -15,8 +20,6 @@ def spacemap(
     state_and_discrete_vars: dict[str, Array],
     continous_vars: dict[str, Array],
     memory_restriction: int,
-    params: ParamsDict,
-    vf_arr: Array,
 ) -> F:
2124 """
2225 Evaluate func along all state and discrete choice axes in a way that reduces the memory usage
@@ -47,11 +50,6 @@ def spacemap(
        A callable that evaluates func along the provided discrete choice and state axes.

    """
-    # Check inputs and prepare function
-    # ==================================================================================
-
-    # jax.vmap cannot deal with keyword-only arguments
-    func = allow_args(func)

    # Duplicate and overlap checks were removed because we now pass dicts,
    # and overlap between the state/discrete and continuous variables is unlikely.
@@ -71,10 +69,24 @@ def spacemap(
             memory_strat[key] = jnp.size(state_and_discrete_vars[key])
         else:
             memory_strat[key] = 1
-
-    mapped = _base_productmap_map(func, state_and_discrete_vars, continous_vars, memory_strat, params, vf_arr)

-    return mapped
+    # Take the last entry, which should be a continuous state variable.
+    # This is the axis we parallelize over multiple devices.
+    *_, parallel_axis_name = state_and_discrete_vars.keys()
+    parallel_axis = state_and_discrete_vars[parallel_axis_name]
+    state_and_discrete_vars.pop(parallel_axis_name)
+    # Construct the function that maps over the axes that are not parallelized
+    # across devices.
+    mapped = _base_productmap_map(func, state_and_discrete_vars, memory_strat)
+    # Number of devices; this could be made settable by the user.
+    # Note: this does not work if the parallel axis length is not divisible by
+    # the number of devices.
+    num_of_devices = jax.device_count()
+    mesh = jax.make_mesh((num_of_devices,), ("x",))
+    # Wrap the mapped function so we can apply shard_map.
+    def mapped_parallel(x, params, vf_arr):
+        res = map(
+            lambda x_i: mapped(**{"params": params, "vf_arr": vf_arr, parallel_axis_name: x_i}),
+            x,
+            batch_size=memory_strat[parallel_axis_name],
+        )
+        return res
+
+    sharded_mapped = shard_map(mapped_parallel, mesh, in_specs=(P("x"), P(), P()), out_specs=P("x"))
+    return partial(sharded_mapped, parallel_axis)


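For orientation, the sharding pattern used above can be exercised in isolation. The following is a minimal sketch, assuming a hypothetical inner function `f`, made-up shapes, and an array stand-in for `params`; it shows only the shard_map-plus-lax.map combination, not the library's actual call path.

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P

def f(x_i, params, vf_arr):
    # Stand-in for the mapped inner function.
    return x_i * params + vf_arr.sum()

mesh = jax.make_mesh((jax.device_count(),), ("x",))

def mapped(x, params, vf_arr):
    # batch_size bounds how many evaluations are materialized at once.
    return jax.lax.map(lambda x_i: f(x_i, params, vf_arr), x, batch_size=2)

sharded = shard_map(mapped, mesh, in_specs=(P("x"), P(), P()), out_specs=P("x"))

# The parallel axis length must be divisible by the number of devices.
x = jnp.arange(8.0)
out = sharded(x, jnp.asarray(2.0), jnp.ones(3))  # shape (8,)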
 def vmap_1d(
@@ -146,7 +158,6 @@ def vmap_1d(
146158
147159 return out
148160
149-
150161def productmap (func : F , variables : list [str ]) -> F :
151162 """Apply vmap such that func is evaluated on the Cartesian product of variables.
152163
@@ -187,9 +198,57 @@ def productmap(func: F, variables: list[str]) -> F:
     vmapped.__signature__ = signature  # type: ignore[attr-defined]

     return allow_only_kwargs(vmapped)
+
+
+def productmap_scan(func: F, variables: dict[str, Array]) -> F:
+    """Apply func to the Cartesian product of variables under a memory budget.
+
+    If it is possible to apply vmap while respecting the memory restriction,
+    stacked vmaps are used; the remaining variable axes are scanned over instead.
+
+    Args:
+        func: The function to be dispatched. It must return a tuple of values ``u``
+            and a feasibility mask ``f``.
+        variables: Dict mapping the names of the arguments over which the
+            Cartesian product should be formed to their value grids.
+
+    Returns:
+        A callable with the same arguments as func except for ``variables``, which
+        are fixed to the provided grids. It returns the maximum of ``u`` over the
+        Cartesian product of ``variables``, restricted to the points where the
+        feasibility mask ``f`` is True.

+    """

-def _base_productmap_map(func: F, state_and_discrete_vars: dict[str:Array], continous_vars: dict[str:Array], strat, params, vf_arr) -> F:
+    # Split the variables against a fixed memory budget of 8 GB: divide the budget
+    # by each axis size; axes that still fit are vmapped, the rest are scanned.
+    memory_restriction = 8 * (10**9)
+    vmapped_vars = {}
+    scanned_vars = {}
+    for key in variables.keys():
+        memory_restriction = memory_restriction / jnp.size(variables[key])
+        if memory_restriction > 1:
+            vmapped_vars[key] = variables[key]
+        else:
+            scanned_vars[key] = variables[key]
+    # Apply as many vmaps as possible.
+    func = _base_productmap(func=func, product_axes=list(vmapped_vars))
+    func = allow_only_kwargs(func)
+
+    # Create the conditional continuation value (ccv) function, which takes the
+    # feasibility-masked max over the vmapped vars.
+    @functools.wraps(func)
+    def compute_ccv(*args, **kwargs):
+        u, f = func(*args, **kwargs)
+        return u.max(where=f, initial=-jnp.inf)
+
+    partialled_compute_ccv = lambda *args, **kwargs: compute_ccv(*args, **kwargs, **vmapped_vars)
+    # Apply scan for all vars that could not be vmapped.
+    if scanned_vars:
+        mapped = _base_productmap_scan(partialled_compute_ccv, scanned_vars)
+        return mapped
+
+    return partialled_compute_ccv
+
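To make the budget split in `productmap_scan` concrete, here is a small worked example with made-up grid sizes; the running quotient mirrors the loop above:

import jax.numpy as jnp

variables = {
    "wealth": jnp.linspace(0.0, 1.0, 1_000),
    "health": jnp.linspace(0.0, 1.0, 100),
    "consumption": jnp.linspace(0.0, 1.0, 100_000),
}
budget = 8 * (10**9)
vmapped_vars, scanned_vars = {}, {}
for name, grid in variables.items():
    budget = budget / jnp.size(grid)
    if budget > 1:
        vmapped_vars[name] = grid   # wealth: 8e6 left; health: 8e4 left
    else:
        scanned_vars[name] = grid   # consumption: 0.8 left -> scanned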
+
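And a tiny demo of the feasibility-masked maximum used in `compute_ccv`; `where` restricts the max to feasible entries and `initial` guards the all-infeasible case:

import jax.numpy as jnp

u = jnp.array([1.0, 5.0, 3.0])
f = jnp.array([True, False, True])
u.max(where=f, initial=-jnp.inf)                    # 3.0: the 5.0 is masked out
u.max(where=jnp.zeros(3, bool), initial=-jnp.inf)   # -inf: nothing is feasible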
+def _base_productmap_map(func: F, state_and_discrete_vars: dict[str, Array], strat) -> F:
193252 """Map func over the Cartesian product of state_and_discrete_vars.
194253
195254 All arguments needed for the evaluation of func are passed via keyword args upon creation,
@@ -206,17 +265,55 @@ def _base_productmap_map(func: F, state_and_discrete_vars: dict[str:Array], cont
         A callable that maps func over the provided axes.

     """
-    mapped = lambda **vals: func(**continous_vars, vf_arr=vf_arr, params=params, **vals)
+    mapped = func
     def stack_maps(func, var, axis):
         def one_more(**xs):
             return map(lambda x_i: func(**xs, **{var: x_i}), axis, batch_size=strat[var])
         return one_more
+
     for key, value in reversed(state_and_discrete_vars.items()):
         mapped = stack_maps(mapped, key, value)

-
     return mapped

+
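The closure stacking in `_base_productmap_map` can be checked on a toy function: each layer adds one `jax.lax.map` with its own batch size, and the reversed iteration keeps the output axes in dict order. A minimal sketch with made-up values:

import jax.numpy as jnp
from jax.lax import map

def f(a, b):
    return a + b

axes = {"a": jnp.arange(2.0), "b": jnp.arange(3.0)}
strat = {"a": 1, "b": 3}  # per-axis batch sizes

mapped = lambda **vals: f(**vals)

def stack_maps(func, var, axis):
    def one_more(**xs):
        return map(lambda x_i: func(**xs, **{var: x_i}), axis, batch_size=strat[var])
    return one_more

for key, value in reversed(axes.items()):
    mapped = stack_maps(mapped, key, value)

out = mapped()  # shape (2, 3), out[i, j] == a[i] + b[j]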
+def _base_productmap_scan(func: F, continous_vars: dict[str, Array]) -> F:
281+ """Map func over the Cartesian product of state_and_discrete_vars.
282+
283+ All arguments needed for the evaluation of func are passed via keyword args upon creation,
284+ the returned callable takes no arguments.
285+
286+ Args:
287+ func: The function to be dispatched.
288+ state_and_discrete_vars: Dict of names and values for each discrete choice axis and state axis.
289+ continous_vars: Dict of names and values for each continous choice axis.
290+ params: Parameters of the Model.
291+ vf_arr: Discretized Value Function from previous period.
292+
293+ Returns:
294+ A callable that maps func over the provided axes.
295+
296+ """
+    # Base case: fold one evaluation of func into the running maximum s.
+    def loop(s, **x_is):
+        u = func(**x_is)
+        return jnp.maximum(u, s)
+
+    # Induction case: scan over one argument, eliminating it.
+    def scan_one_more(loop, var, axis):
+        def new_loop(s, **xs):
+            s, _ = scan(lambda s, x_i: (loop(s, **xs, **{var: x_i}), None), s, axis)
+            return s
+        return new_loop
+
+    # Compose one scan per axis.
+    for key, value in continous_vars.items():
+        loop = scan_one_more(loop, key, value)
+
+    return partial(loop, s=-jnp.inf)
+
+
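The scan composition above can likewise be run standalone. A toy example with a made-up objective and grids, showing that the folded scans recover the grid maximum without materializing the full product:

import jax.numpy as jnp
from functools import partial
from jax.lax import scan

def f(a, b):
    # Made-up objective with its maximum at (0.3, 0.7).
    return -((a - 0.3) ** 2) - (b - 0.7) ** 2

def loop(s, **x_is):
    return jnp.maximum(f(**x_is), s)

def scan_one_more(loop, var, axis):
    def new_loop(s, **xs):
        s, _ = scan(lambda s, x_i: (loop(s, **xs, **{var: x_i}), None), s, axis)
        return s
    return new_loop

grids = {"a": jnp.linspace(0.0, 1.0, 11), "b": jnp.linspace(0.0, 1.0, 11)}
for key, value in grids.items():
    loop = scan_one_more(loop, key, value)

best = partial(loop, s=-jnp.inf)()  # ~0.0, attained at a=0.3, b=0.7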
 def _base_productmap(func: F, product_axes: list[Array]) -> F:
     """Map func over the Cartesian product of product_axes.

@@ -235,7 +332,6 @@ def _base_productmap(func: F, product_axes: list[Array]) -> F:
     parameters = list(signature.parameters)

     positions = [parameters.index(ax) for ax in product_axes]
-
     vmap_specs = []
     # We iterate in reverse order such that the output dimensions are in the same order
     # as the input dimensions.