Skip to content

Commit 9f71b07

Browse files
Generate automatically unique_id for pandas agent set
1 parent dfddd80 commit 9f71b07

File tree

2 files changed

+28
-22
lines changed

2 files changed

+28
-22
lines changed

mesa_frames/concrete/pandas/agentset.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,13 @@ def step(self):
5151
refer to the class docstring.
5252
"""
5353

54+
from collections import defaultdict
5455
from collections.abc import Callable, Collection, Iterable, Iterator, Sequence
56+
import functools
57+
import itertools
58+
from logging import warning
5559
from typing import TYPE_CHECKING
60+
import warnings
5661

5762
import numpy as np
5863
import pandas as pd
@@ -79,6 +84,7 @@ class AgentSetPandas(AgentSetDF, PandasMixin):
7984
8085
"""
8186

87+
_ids = defaultdict(functools.partial(itertools.count, 0))
8288
_agents: pd.DataFrame
8389
_mask: pd.Series
8490
_copy_with_method: dict[str, tuple[str, list[str]]] = {
@@ -113,23 +119,26 @@ def add( # noqa : D102
113119
obj = self._get_obj(inplace)
114120
if isinstance(agents, pd.DataFrame):
115121
new_agents = agents
116-
if "unique_id" != agents.index.name:
117-
try:
118-
new_agents.set_index("unique_id", inplace=True, drop=True)
119-
except KeyError:
120-
raise KeyError("DataFrame must have a unique_id column/index.")
122+
if "unique_id" == agents.index.name:
123+
warnings.warn("DataFrame has 'unique_id' column, which will be ignored.")
124+
new_agents = new_agents.reindex([next(self._ids[self.model]) for _ in agents.index])
121125
elif isinstance(agents, dict):
122-
if "unique_id" not in agents:
123-
raise KeyError("Dictionary must have a unique_id key.")
124-
index = agents.pop("unique_id")
125-
if not isinstance(index, list):
126-
index = [index]
126+
if "unique_id" in agents:
127+
warnings.warn("Dictionary contains a 'unique_id' key, which will be ignored.")
128+
if isinstance(agents["unique_id"], list):
129+
index = [next(self._ids[self.model]) for _ in agents["unique_id"]]
130+
else:
131+
index = [next(self._ids[self.model])]
132+
agents.pop("unique_id")
127133
new_agents = pd.DataFrame(agents, index=pd.Index(index, name="unique_id"))
128134
else:
129135
if len(agents) != len(obj._agents.columns) + 1:
130136
raise ValueError(
131137
"Length of data must match the number of columns in the AgentSet if being added as a Collection."
132138
)
139+
if len(agents) == len(obj._agents.columns):
140+
# we suppose the first element of the list is unique_id
141+
agents[0] = next(self._ids[self.model])
133142
columns = pd.Index(["unique_id"]).append(obj._agents.columns.copy())
134143
new_agents = pd.DataFrame([agents], columns=columns).set_index(
135144
"unique_id", drop=True

tests/pandas/test_agentset_pandas.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111

1212
@tg.typechecked
1313
class ExampleAgentSetPandas(AgentSetPandas):
14-
def __init__(self, model: ModelDF, index: pd.Index):
14+
def __init__(self, model: ModelDF):
1515
super().__init__(model)
16-
self.starting_wealth = pd.Series([1, 2, 3, 4], name="wealth", index=index)
16+
self.starting_wealth = pd.Series([1, 2, 3, 4], name="wealth")
1717

1818
def add_wealth(self, amount: int) -> None:
1919
self.agents["wealth"] += amount
@@ -25,7 +25,7 @@ def step(self) -> None:
2525
@pytest.fixture
2626
def fix1_AgentSetPandas() -> ExampleAgentSetPandas:
2727
model = ModelDF()
28-
agents = ExampleAgentSetPandas(model, pd.Index([0, 1, 2, 3], name="unique_id"))
28+
agents = ExampleAgentSetPandas(model)
2929
agents.add({"unique_id": [0, 1, 2, 3]})
3030
agents["wealth"] = agents.starting_wealth
3131
agents["age"] = [10, 20, 30, 40]
@@ -36,7 +36,7 @@ def fix1_AgentSetPandas() -> ExampleAgentSetPandas:
3636
@pytest.fixture
3737
def fix2_AgentSetPandas() -> ExampleAgentSetPandas:
3838
model = ModelDF()
39-
agents = ExampleAgentSetPandas(model, pd.Index([4, 5, 6, 7], name="unique_id"))
39+
agents = ExampleAgentSetPandas(model)
4040
agents.add({"unique_id": [4, 5, 6, 7]})
4141
agents["wealth"] = agents.starting_wealth + 10
4242
agents["age"] = [100, 200, 300, 400]
@@ -55,7 +55,7 @@ def fix1_AgentSetPandas_with_pos(fix1_AgentSetPandas) -> ExampleAgentSetPandas:
5555
class Test_AgentSetPandas:
5656
def test__init__(self):
5757
model = ModelDF()
58-
agents = ExampleAgentSetPandas(model, pd.Index([0, 1, 2, 3]))
58+
agents = ExampleAgentSetPandas(model)
5959
assert agents.model == model
6060
assert isinstance(agents.agents, pd.DataFrame)
6161
assert agents.agents.index.name == "unique_id"
@@ -86,7 +86,6 @@ def test_add(
8686
# Test with a dict[str, Any]
8787
agents.add({"unique_id": [4, 5], "wealth": [5, 6], "age": [50, 60]})
8888
assert agents.agents.wealth.tolist() == [1, 2, 3, 4, 5, 6]
89-
assert agents.agents.index.tolist() == [0, 1, 2, 3, 4, 5]
9089
assert agents.agents.age.tolist() == [10, 20, 30, 40, 50, 60]
9190
assert agents.agents.index.name == "unique_id"
9291

@@ -287,7 +286,6 @@ def test__add__(
287286

288287
# Test with an AgentSetPandas and a dict
289288
agents3 = agents + {"unique_id": 10, "wealth": 5}
290-
assert agents3.agents.index.tolist() == [0, 1, 2, 3, 10]
291289
assert agents3.agents.wealth.tolist() == [1, 2, 3, 4, 5]
292290

293291
def test__contains__(self, fix1_AgentSetPandas: ExampleAgentSetPandas):
@@ -360,7 +358,6 @@ def test__iadd__(
360358
# Test with an AgentSetPandas and a dict
361359
agents = deepcopy(fix1_AgentSetPandas)
362360
agents += {"unique_id": 10, "wealth": 5}
363-
assert agents.agents.index.tolist() == [0, 1, 2, 3, 10]
364361
assert agents.agents.wealth.tolist() == [1, 2, 3, 4, 5]
365362

366363
def test__iter__(self, fix1_AgentSetPandas: ExampleAgentSetPandas):
@@ -439,24 +436,24 @@ def test_agents(
439436

440437
# Test agents.setter
441438
agents.agents = agents2.agents
442-
assert agents.agents.index.tolist() == [4, 5, 6, 7]
439+
assert len(agents.active_agents) == 4
443440

444441
def test_active_agents(self, fix1_AgentSetPandas: ExampleAgentSetPandas):
445442
agents = fix1_AgentSetPandas
446443

447444
# Test with select
448445
agents.select(agents["wealth"] > 2, inplace=True)
449-
assert agents.active_agents.index.tolist() == [2, 3]
446+
assert len(agents.active_agents) == 2
450447

451448
# Test with active_agents.setter
452449
agents.active_agents = agents.agents.wealth > 2
453-
assert agents.active_agents.index.to_list() == [2, 3]
450+
assert len(agents.active_agents) == 2
454451

455452
def test_inactive_agents(self, fix1_AgentSetPandas: ExampleAgentSetPandas):
456453
agents = fix1_AgentSetPandas
457454

458455
agents.select(agents["wealth"] > 2, inplace=True)
459-
assert agents.inactive_agents.index.to_list() == [0, 1]
456+
assert len(agents.active_agents) == 2
460457

461458
def test_pos(self, fix1_AgentSetPandas_with_pos: ExampleAgentSetPandas):
462459
pos = fix1_AgentSetPandas_with_pos.pos

0 commit comments

Comments
 (0)