Skip to content

Commit 5539599

Browse files
committed
Remove nearly all mentions of sparse variables from state-choice-space creation
1 parent 609bc04 commit 5539599

File tree

6 files changed

+23
-47
lines changed

6 files changed

+23
-47
lines changed

src/lcm/discrete_problem.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -196,27 +196,19 @@ def _segment_logsumexp(a, segment_info):
196196
def _determine_dense_discrete_choice_axes(
197197
variable_info: pd.DataFrame,
198198
) -> tuple[int, ...] | None:
199-
"""Get axes of a state choice space that correspond to dense discrete choices.
200-
201-
Note: The dense choice axes determine over which axes we reduce the conditional
202-
continuation values using a non-segmented operation. The axes ordering of the
203-
conditional continuation value array is given by [sparse_variable, dense_variables].
204-
The dense continuous choice dimension is already reduced as we are working with
205-
the conditional continuation values.
199+
"""Get axes of a state-choice-space that correspond to discrete choices.
206200
207201
Args:
208-
variable_info (pd.DataFrame): DataFrame with information about the variables.
202+
variable_info: DataFrame with information about the variables.
209203
210204
Returns:
211-
tuple[int, ...] | None: A tuple of indices representing the axes in the value
212-
function that correspond to discrete choices. Returns None if there are no
213-
discrete choice axes.
205+
tuple[int, ...] | None: A tuple of indices representing the axes positions in
206+
the value function that correspond to discrete choices. Returns None if
207+
there are no discrete choice axes.
214208
215209
"""
216210
# List of dense variables excluding continuous choice variables.
217-
axes = variable_info.query(
218-
"~(is_choice & is_continuous)",
219-
).index.tolist()
211+
axes = variable_info.query("~(is_choice & is_continuous)").index.tolist()
220212

221213
choice_vars = set(variable_info.query("is_choice").index.tolist())
222214

src/lcm/entry_point.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,14 @@ def get_lcm_function(
9494

9595
# call state space creation function, append trivial items to their lists
9696
# ==============================================================================
97-
sc_space, space_info, state_indexer, segments = create_state_choice_space(
97+
sc_space, space_info = create_state_choice_space(
9898
model=_mod,
9999
is_last_period=is_last_period,
100100
)
101101

102+
state_indexer = {}
103+
segments = None
104+
102105
state_choice_spaces.append(sc_space)
103106
choice_segments.append(segments)
104107

src/lcm/state_space.py

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,13 @@
33
from lcm.interfaces import InternalModel, Space, SpaceInfo
44

55

6-
def create_state_choice_space(model: InternalModel, *, is_last_period: bool):
7-
"""Create a state choice space for the model.
6+
def create_state_choice_space(
7+
model: InternalModel, *, is_last_period: bool
8+
) -> tuple[Space, SpaceInfo]:
9+
"""Create a state-choice-space for the model.
810
9-
A state_choice_space is a compressed representation of all feasible states and the
10-
feasible discrete choices within that state. We currently use the following
11-
compressions:
12-
13-
We distinguish between dense and sparse variables (dense_vars and sparse_vars).
14-
Dense state or choice variables are those whose set of feasible values does not
15-
depend on any other state or choice variables. Sparse state or choice variables are
16-
all other state variables. For dense state variables it is thus enough to store the
17-
grid of feasible values (value_grid), whereas for sparse variables all feasible
18-
combinations (combination_grid) have to be stored.
19-
20-
Note:
21-
-----
22-
- We only use the filter mask, not the forward mask (yet).
11+
A state-choice-space is a compressed representation of all feasible states and the
12+
feasible discrete choices within that state.
2313
2414
Args:
2515
model (Model): A processed model.
@@ -30,9 +20,6 @@ def create_state_choice_space(model: InternalModel, *, is_last_period: bool):
3020
to execute a function on an entire space.
3121
SpaceInfo: A SpaceInfo object that contains all information needed to work with
3222
the output of a function evaluated on the space.
33-
dict: Dictionary containing state indexer arrays.
34-
jnp.ndarray: Jax array containing the choice segments needed for the emax
35-
calculations.
3623
3724
"""
3825
# ==================================================================================
@@ -54,12 +41,6 @@ def create_state_choice_space(model: InternalModel, *, is_last_period: bool):
5441
sparse_vars={},
5542
dense_vars=_value_grid,
5643
)
57-
# ==================================================================================
58-
# create indexers and segments
59-
# ==================================================================================
60-
choice_segments = None
61-
62-
state_indexers = {} # type: ignore[var-annotated]
6344

6445
# ==================================================================================
6546
# create state space info
@@ -85,7 +66,7 @@ def create_state_choice_space(model: InternalModel, *, is_last_period: bool):
8566
indexer_infos=indexer_infos,
8667
)
8768

88-
return state_choice_space, space_info, state_indexers, choice_segments
69+
return state_choice_space, space_info
8970

9071

9172
def _create_value_grid(grids, subset):

tests/test_entry_point.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def test_create_compute_conditional_continuation_value():
182182
},
183183
}
184184

185-
_, space_info, _, _ = create_state_choice_space(
185+
_, space_info = create_state_choice_space(
186186
model=model,
187187
is_last_period=False,
188188
)
@@ -228,7 +228,7 @@ def test_create_compute_conditional_continuation_value_with_discrete_model():
228228
},
229229
}
230230

231-
_, space_info, _, _ = create_state_choice_space(
231+
_, space_info = create_state_choice_space(
232232
model=model,
233233
is_last_period=False,
234234
)
@@ -279,7 +279,7 @@ def test_create_compute_conditional_continuation_policy():
279279
},
280280
}
281281

282-
_, space_info, _, _ = create_state_choice_space(
282+
_, space_info = create_state_choice_space(
283283
model=model,
284284
is_last_period=False,
285285
)
@@ -326,7 +326,7 @@ def test_create_compute_conditional_continuation_policy_with_discrete_model():
326326
},
327327
}
328328

329-
_, space_info, _, _ = create_state_choice_space(
329+
_, space_info = create_state_choice_space(
330330
model=model,
331331
is_last_period=False,
332332
)

tests/test_model_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def test_get_utility_and_feasibility_function():
5656
},
5757
}
5858

59-
_, space_info, _, _ = create_state_choice_space(
59+
_, space_info = create_state_choice_space(
6060
model=model,
6161
is_last_period=False,
6262
)

tests/test_simulate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def simulate_inputs():
4040
model_config = get_model_config("iskhakov_et_al_2017_stripped_down", n_periods=1)
4141
model = process_model(model_config)
4242

43-
_, space_info, _, _ = create_state_choice_space(
43+
_, space_info = create_state_choice_space(
4444
model=model,
4545
is_last_period=False,
4646
)

0 commit comments

Comments
 (0)