Skip to content

Commit a679200

Browse files
Let add agents without specifing unique_id for polars and adapt AgentsDF for handling the unique_id
1 parent da9e095 commit a679200

File tree

7 files changed

+77
-26
lines changed

7 files changed

+77
-26
lines changed

mesa_frames/abstract/agents.py

+17
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,23 @@ def sort(
382382
A new or updated AgentContainer.
383383
"""
384384

385+
@abstractmethod
386+
def shift_indexes(self, first_index: int, inplace: bool = True) -> Self:
387+
"""Shift the indexes of the agents in the AgentContainer by the specified amount.
388+
389+
Parameters
390+
----------
391+
first_index : int
392+
The new first index to be used.
393+
inplace : bool
394+
Whether to shift the indexes in place.
395+
396+
Returns
397+
-------
398+
Self
399+
A new or updated AgentContainer.
400+
"""
401+
385402
def __add__(
386403
self, other: DataFrameInput | AgentSetDF | Collection[AgentSetDF]
387404
) -> Self:

mesa_frames/concrete/agents.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -110,13 +110,11 @@ def add(
110110
other_list = obj._return_agentsets_list(agents)
111111
if obj._check_agentsets_presence(other_list).any():
112112
raise ValueError("Some agentsets are already present in the AgentsDF.")
113-
new_ids = pl.concat(
114-
[obj._ids] + [pl.Series(agentset["unique_id"]) for agentset in other_list]
115-
)
116-
if new_ids.is_duplicated().any():
117-
raise ValueError("Some of the agent IDs are not unique.")
118-
obj._agentsets.extend(other_list)
119-
obj._ids = new_ids
113+
for agentset in other_list:
114+
if len(obj._agentsets) > 0:
115+
agentset.shift_indexes(obj._ids.max() + 1)
116+
obj._agentsets.append(agentset)
117+
obj._ids = pl.concat([obj._ids, pl.Series(agentset["unique_id"])])
120118
return obj
121119

122120
@overload
@@ -607,3 +605,11 @@ def index(self) -> dict[AgentSetDF, Index]:
607605
@property
608606
def pos(self) -> dict[AgentSetDF, DataFrame]:
609607
return {agentset: agentset.pos for agentset in self._agentsets}
608+
609+
def shift_indexes(self, first_index: int, inplace: bool = True) -> Self:
610+
obj = self._get_obj(inplace)
611+
obj._ids += first_index
612+
for agentset in obj._agentsets:
613+
agentset.shift_indexes(first_index)
614+
first_index += len(agentset)
615+
return obj

mesa_frames/concrete/pandas/agentset.py

+5
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,11 @@ def to_polars(self) -> AgentSetPolars:
279279
new_obj._agents = pl.DataFrame(self._agents)
280280
new_obj._mask = pl.Series(self._mask)
281281
return new_obj
282+
283+
def shift_indexes(self, first_index: int, inplace: bool = True):
284+
obj = self._get_obj(inplace)
285+
obj._agents.index = np.arange(first_index, first_index + len(obj._agents))
286+
return obj
282287

283288
def _concatenate_agentsets(
284289
self,

mesa_frames/concrete/polars/agentset.py

+22-6
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def step(self):
5959

6060
from collections.abc import Callable, Collection, Iterable, Iterator, Sequence
6161
from typing import TYPE_CHECKING
62+
import warnings
6263

6364
import polars as pl
6465
from polars._typing import IntoExpr
@@ -120,20 +121,28 @@ def add(
120121
"""
121122
obj = self._get_obj(inplace)
122123
if isinstance(agents, pl.DataFrame):
123-
if "unique_id" not in agents.columns:
124-
raise KeyError("DataFrame must have a unique_id column.")
124+
if "unique_id" in agents.columns:
125+
warnings.warn("Dataframe should not have a unique_id index/column. It will be ignored.")
125126
new_agents = agents
126127
elif isinstance(agents, dict):
127-
if "unique_id" not in agents:
128-
raise KeyError("Dictionary must have a unique_id key.")
128+
if "unique_id" in agents:
129+
warnings.warn("Dictionary should not have a unique_id key. It will be ignored.")
129130
new_agents = pl.DataFrame(agents)
130131
else:
131-
if len(agents) != len(obj._agents.columns):
132+
if len(agents) not in {len(obj._agents.columns), len(obj._agents.columns) + 1}:
132133
raise ValueError(
133134
"Length of data must match the number of columns in the AgentSet if being added as a Collection."
134135
)
136+
if len(agents) == len(obj._agents.columns):
137+
warnings.warn("Length of data should have the number of columns in the AgentSet," +
138+
"we suppose the first element is the unique_id. It will be ignored.")
135139
new_agents = pl.DataFrame([agents], schema=obj._agents.schema)
136140

141+
if len(self.agents) == 0:
142+
unique_ids = pl.arange(len(new_agents))
143+
else:
144+
unique_ids = pl.arange(self.index.max() + 1, self.index.max() + 1 + len(new_agents))
145+
new_agents = new_agents.with_columns(unique_ids.alias("unique_id"))
137146
if new_agents["unique_id"].dtype != pl.Int64:
138147
raise TypeError("unique_id column must be of type int64.")
139148

@@ -293,7 +302,14 @@ def to_pandas(self) -> "AgentSetPandas":
293302
.to_pandas()
294303
)
295304
return new_obj
296-
305+
306+
def shift_indexes(self, first_index: int, inplace: bool = True):
307+
obj = self._get_obj(inplace)
308+
obj._agents = obj._agents.with_columns(
309+
pl.arange(first_index, first_index + len(obj._agents)).alias("unique_id")
310+
)
311+
return obj
312+
297313
def _concatenate_agentsets(
298314
self,
299315
agentsets: Iterable[Self],

tests/pandas/test_agentset_pandas.py

+5
Original file line numberDiff line numberDiff line change
@@ -513,3 +513,8 @@ def test_pos(self, fix1_AgentSetPandas_with_pos: ExampleAgentSetPandas):
513513
assert all(math.isnan(val) for val in pos["dim_0"].tolist()[2:])
514514
assert pos["dim_1"].tolist()[:2] == [0, 1]
515515
assert all(math.isnan(val) for val in pos["dim_1"].tolist()[2:])
516+
517+
def test_shift_indexes(self, fix1_AgentSetPandas_with_unique_id: ExampleAgentSetPandas):
518+
agents = fix1_AgentSetPandas_with_unique_id
519+
agents.shift_indexes(10, inplace=True)
520+
assert agents.agents.index.tolist() == [10, 11, 12, 13]

tests/polars/test_agentset_polars.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def fix2_AgentSetPolars_with_unique_id() -> ExampleAgentSetPolars:
4343
model.agents.add(agents)
4444
space = GridPandas(model, dimensions=[3, 3], capacity=2)
4545
model.space = space
46-
space.place_agents(agents=[4, 5], pos=[[2, 1], [1, 2]])
46+
space.place_agents(agents=[0, 1], pos=[[2, 1], [1, 2]])
4747
return agents
4848

4949

@@ -83,7 +83,7 @@ def test_add_with_unique_id(
8383

8484
# Test with a list (Sequence[Any])
8585
result = agents.add([10, 5, 10], inplace=False)
86-
assert result.agents["unique_id"].to_list() == [0, 1, 2, 3, 10]
86+
assert result.agents["unique_id"].to_list() == [0, 1, 2, 3, 4]
8787
assert result.agents["wealth"].to_list() == [1, 2, 3, 4, 5]
8888
assert result.agents["age"].to_list() == [10, 20, 30, 40, 10]
8989

@@ -290,7 +290,7 @@ def test__add__(
290290

291291
# Test with an AgentSetPolars and a dict
292292
agents3 = agents + {"unique_id": 10, "wealth": 5}
293-
assert agents3.agents["unique_id"].to_list() == [0, 1, 2, 3, 10]
293+
assert agents3.agents["unique_id"].to_list() == [0, 1, 2, 3, 4]
294294
assert agents3.agents["wealth"].to_list() == [1, 2, 3, 4, 5]
295295

296296
def test__contains__(self, fix1_AgentSetPolars_with_unique_id: ExampleAgentSetPolars):
@@ -362,7 +362,7 @@ def test__iadd__(
362362
# Test with an AgentSetPolars and a dict
363363
agents = deepcopy(fix1_AgentSetPolars_with_unique_id)
364364
agents += {"unique_id": 10, "wealth": 5}
365-
assert agents.agents["unique_id"].to_list() == [0, 1, 2, 3, 10]
365+
assert agents.agents["unique_id"].to_list() == [0, 1, 2, 3, 4]
366366
assert agents.agents["wealth"].to_list() == [1, 2, 3, 4, 5]
367367

368368
def test__iter__(self, fix1_AgentSetPolars_with_unique_id: ExampleAgentSetPolars):
@@ -441,7 +441,7 @@ def test_agents(
441441

442442
# Test agents.setter
443443
agents.agents = agents2.agents
444-
assert agents.agents["unique_id"].to_list() == [4, 5, 6, 7]
444+
assert agents.agents["unique_id"].to_list() == [0, 1, 2, 3]
445445

446446
def test_active_agents(self, fix1_AgentSetPolars_with_unique_id: ExampleAgentSetPolars):
447447
agents = fix1_AgentSetPolars_with_unique_id
@@ -467,3 +467,8 @@ def test_pos(self, fix1_AgentSetPolars_with_pos: ExampleAgentSetPolars):
467467
assert pos.columns == ["unique_id", "dim_0", "dim_1"]
468468
assert pos["dim_0"].to_list() == [0, 1, None, None]
469469
assert pos["dim_1"].to_list() == [0, 1, None, None]
470+
471+
def test_shift_indexes(self, fix1_AgentSetPolars_with_unique_id: ExampleAgentSetPolars):
472+
agents = fix1_AgentSetPolars_with_unique_id
473+
agents.shift_indexes(10, inplace=True)
474+
assert agents.agents["unique_id"].to_list() == [10, 11, 12, 13]

tests/test_agents.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,6 @@ def test_add(
7777
+ agentset_polars._agents["unique_id"].to_list()
7878
)
7979

80-
# Test if adding the same AgentSetDF raises ValueError
81-
with pytest.raises(ValueError):
82-
agents.add(agentset_pandas, inplace=False)
83-
8480
def test_contains(
8581
self, fix2_AgentSetPandas_with_unique_id: ExampleAgentSetPandas, fix_AgentsDF: AgentsDF
8682
):
@@ -629,10 +625,6 @@ def test___add__(
629625
+ agentset_polars._agents["unique_id"].to_list()
630626
)
631627

632-
# Test if adding the same AgentSetDF raises ValueError
633-
with pytest.raises(ValueError):
634-
result + agentset_pandas
635-
636628
def test___contains__(
637629
self, fix_AgentsDF: AgentsDF, fix2_AgentSetPandas_with_unique_id: ExampleAgentSetPandas
638630
):
@@ -995,3 +987,8 @@ def test_inactive_agents(self, fix_AgentsDF: AgentsDF):
995987
== agents1._agentsets[1].select(mask1, negate=True).active_agents
996988
)
997989
)
990+
991+
def test_shift_indexes(self, fix_AgentsDF: AgentsDF):
992+
agents = fix_AgentsDF
993+
agents.shift_indexes(20)
994+
assert agents._ids.to_list() == [20, 21, 22, 23, 24, 25, 26, 27]

0 commit comments

Comments
 (0)