1
1
import inspect
2
+ import functools
3
+ import jax
2
4
from collections .abc import Callable
5
+ from functools import partial
3
6
from typing import Literal , TypeVar
4
7
from jax import Array , vmap
5
8
from jax import numpy as jnp
6
- from jax .lax import map
9
+ from jax .lax import map , concatenate , scan
7
10
from lcm .functools import allow_args , allow_only_kwargs
8
11
from lcm .typing import ParamsDict
12
+ from jax .sharding import Mesh , PartitionSpec as P
13
+ from jax .experimental .shard_map import shard_map
9
14
10
15
F = TypeVar ("F" , bound = Callable [..., Array ])
11
16
@@ -15,8 +20,6 @@ def spacemap(
15
20
state_and_discrete_vars : dict [str :Array ],
16
21
continous_vars : dict [str :Array ],
17
22
memory_restriction : int ,
18
- params : ParamsDict ,
19
- vf_arr : Array ,
20
23
) -> F :
21
24
"""
22
25
Evaluate func along all state and discrete choice axes in a way that reduces the memory usage
@@ -47,11 +50,6 @@ def spacemap(
47
50
A callable that evaluates func along the provided dicrete choice and state axes.
48
51
49
52
"""
50
- # Check inputs and prepare function
51
- # ==================================================================================
52
-
53
- # jax.vmap cannot deal with keyword-only arguments
54
- func = allow_args (func )
55
53
56
54
# I removed duplicate and overlap checks because we are passing dicts now
57
55
# and overlap between state+dicrte and continous seems unlikely
@@ -71,10 +69,24 @@ def spacemap(
71
69
memory_strat [key ] = jnp .size (state_and_discrete_vars [key ])
72
70
else :
73
71
memory_strat [key ] = 1
74
-
75
- mapped = _base_productmap_map (func , state_and_discrete_vars , continous_vars , memory_strat ,params , vf_arr )
76
72
77
- return mapped
73
+ # Get the last entry, this should be a cont. state var.
74
+ # This will be the axis we paralellize over multiple devices.
75
+ * _ , paralell_axis_name = state_and_discrete_vars .keys ()
76
+ paralell_axis = state_and_discrete_vars [paralell_axis_name ]
77
+ state_and_discrete_vars .pop (paralell_axis_name )
78
+ # Construct function for axes that are not paralellized over devices
79
+ mapped = _base_productmap_map (func , state_and_discrete_vars , memory_strat )
80
+ # This is number of devices, could be settable by user
81
+ # Does not work if len(axis) not dividable by num. of devices
82
+ num_of_devices = jax .device_count ()
83
+ mesh = jax .make_mesh ((num_of_devices ,), ('x' ,))
84
+ # Create function so we can use shard_map
85
+ def mapped_paralell (x , params ,vf_arr ):
86
+ res = map (lambda x_i : mapped (** {'params' :params , 'vf_arr' : vf_arr ,paralell_axis_name :x_i }), x , batch_size = memory_strat [paralell_axis_name ])
87
+ return res
88
+ sharded_mapped = shard_map (mapped_paralell ,mesh ,in_specs = (P ('x' ),P (),P ()),out_specs = P ('x' ))
89
+ return partial (sharded_mapped , paralell_axis )
78
90
79
91
80
92
def vmap_1d (
@@ -146,7 +158,6 @@ def vmap_1d(
146
158
147
159
return out
148
160
149
-
150
161
def productmap (func : F , variables : list [str ]) -> F :
151
162
"""Apply vmap such that func is evaluated on the Cartesian product of variables.
152
163
@@ -187,9 +198,57 @@ def productmap(func: F, variables: list[str]) -> F:
187
198
vmapped .__signature__ = signature # type: ignore[attr-defined]
188
199
189
200
return allow_only_kwargs (vmapped )
201
def productmap_scan(
    func: F,
    variables: dict[str, Array],
    memory_restriction: int = 8 * 10**9,
) -> F:
    """Reduce func over the Cartesian product of variables under a memory budget.

    Variables are vmapped for as long as the estimated size of the implied
    output grid stays within the memory budget; any remaining variables are
    eliminated with ``jax.lax.scan`` (via ``_base_productmap_scan``), which
    trades speed for a much smaller peak-memory footprint.

    Args:
        func: The function to be dispatched. It must return a tuple
            ``(utilities, feasibilities)``; the dispatched function reduces
            these to the feasible maximum.
        variables: Dict mapping argument names of ``func`` to the 1d value
            grids over which the Cartesian product should be formed.
        memory_restriction: Rough memory budget (in number of elements /
            bytes — TODO confirm unit against callers) used to decide how many
            variables can be vmapped before switching to scan. Defaults to the
            previously hard-coded 8e9.

    Returns:
        A callable taking the remaining keyword arguments of ``func`` that
        returns the feasible maximum of the utilities over the Cartesian
        product of ``variables``.

    """
    # Split the variables: vmap while the shrinking budget still allows it,
    # scan over the rest.
    budget = memory_restriction
    vmapped_vars = {}
    scanned_vars = {}
    for name, values in variables.items():
        budget = budget / jnp.size(values)
        if budget > 1:
            vmapped_vars[name] = values
        else:
            scanned_vars[name] = values

    # Apply as many vmaps as possible.
    func = _base_productmap(func=func, product_axes=list(vmapped_vars))
    func = allow_only_kwargs(func)

    # Create the ccv function, so we get the feasible maximum over the
    # vmapped variables.
    @functools.wraps(func)
    def compute_ccv(*args, **kwargs):
        u, f = func(*args, **kwargs)
        return u.max(where=f, initial=-jnp.inf)

    def partialled_compute_ccv(*args, **kwargs):
        # Bind the vmapped grids; callers only supply the remaining kwargs.
        return compute_ccv(*args, **kwargs, **vmapped_vars)

    # Apply scan for all variables that could not be vmapped.
    if scanned_vars:
        return _base_productmap_scan(partialled_compute_ccv, scanned_vars)

    return partialled_compute_ccv
+ def _base_productmap_map (func : F , state_and_discrete_vars : dict [str :Array ], strat ) -> F :
193
252
"""Map func over the Cartesian product of state_and_discrete_vars.
194
253
195
254
All arguments needed for the evaluation of func are passed via keyword args upon creation,
@@ -206,17 +265,55 @@ def _base_productmap_map(func: F, state_and_discrete_vars: dict[str:Array], cont
206
265
A callable that maps func over the provided axes.
207
266
208
267
"""
209
- mapped = lambda ** vals : func ( ** continous_vars , vf_arr = vf_arr , params = params , ** vals )
268
+ mapped = func
210
269
def stack_maps (func , var , axis ):
211
270
def one_more (** xs ):
212
271
return map (lambda x_i : func (** xs , ** {var :x_i }), axis , batch_size = strat [var ])
213
272
return one_more
273
+
214
274
for key ,value in reversed (state_and_discrete_vars .items ()):
215
275
mapped = stack_maps (mapped ,key ,value )
216
276
217
-
218
277
return mapped
219
278
279
+
280
+ def _base_productmap_scan (func : F , continous_vars : dict [str :Array ]) -> F :
281
+ """Map func over the Cartesian product of state_and_discrete_vars.
282
+
283
+ All arguments needed for the evaluation of func are passed via keyword args upon creation,
284
+ the returned callable takes no arguments.
285
+
286
+ Args:
287
+ func: The function to be dispatched.
288
+ state_and_discrete_vars: Dict of names and values for each discrete choice axis and state axis.
289
+ continous_vars: Dict of names and values for each continous choice axis.
290
+ params: Parameters of the Model.
291
+ vf_arr: Discretized Value Function from previous period.
292
+
293
+ Returns:
294
+ A callable that maps func over the provided axes.
295
+
296
+ """
297
+ # We want the maximum of all the values in the scan
298
+ def loop (s , ** x_is ):
299
+ u = func (** x_is )
300
+ return jnp .maximum (u , s )
301
+
302
+ # induction case: scan over one argument, eliminating it
303
+ def scan_one_more (loop , var ,axis ):
304
+ def new_loop (s , ** xs ):
305
+ s , _ = scan (lambda s , x_i : (loop (s , ** xs , ** {var :x_i }), None ), s , axis )
306
+ return s
307
+ return new_loop
308
+
309
+ # compose
310
+ for key ,value in continous_vars .items ():
311
+ loop = scan_one_more (loop , key ,value )
312
+
313
+ return partial (loop , s = - jnp .inf )
314
+
315
+
316
+
220
317
def _base_productmap (func : F , product_axes : list [Array ]) -> F :
221
318
"""Map func over the Cartesian product of product_axes.
222
319
@@ -235,7 +332,6 @@ def _base_productmap(func: F, product_axes: list[Array]) -> F:
235
332
parameters = list (signature .parameters )
236
333
237
334
positions = [parameters .index (ax ) for ax in product_axes ]
238
-
239
335
vmap_specs = []
240
336
# We iterate in reverse order such that the output dimensions are in the same order
241
337
# as the input dimensions.
0 commit comments