Skip to content

Commit 03b0275

Browse files
Warn when passing unique_id while adding new agents
1 parent 0f41ea8 commit 03b0275

File tree

1 file changed

+22
-26
lines changed

1 file changed

+22
-26
lines changed

mesa_frames/concrete/pandas/agentset.py

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,7 @@ def step(self):
5151
refer to the class docstring.
5252
"""
5353

54-
from collections import defaultdict
5554
from collections.abc import Callable, Collection, Iterable, Iterator, Sequence
56-
import functools
57-
import itertools
58-
from logging import warning
5955
from typing import TYPE_CHECKING
6056
import warnings
6157

@@ -84,7 +80,6 @@ class AgentSetPandas(AgentSetDF, PandasMixin):
8480
8581
"""
8682

87-
_ids = defaultdict(functools.partial(itertools.count, 0))
8883
_agents: pd.DataFrame
8984
_mask: pd.Series
9085
_copy_with_method: dict[str, tuple[str, list[str]]] = {
@@ -119,32 +114,33 @@ def add( # noqa : D102
119114
obj = self._get_obj(inplace)
120115
if isinstance(agents, pd.DataFrame):
121116
new_agents = agents
122-
if "unique_id" == agents.index.name:
123-
warnings.warn("DataFrame has 'unique_id' column, which will be ignored.")
124-
new_agents = new_agents.reindex([next(self._ids[self.model]) for _ in agents.index])
117+
if "unique_id" == agents.index.name or "unique_id" in agents.columns:
118+
warnings.warn("Dataframe should not have a unique_id index/column. It will be ignored.")
125119
elif isinstance(agents, dict):
126120
if "unique_id" in agents:
127-
warnings.warn("Dictionary contains a 'unique_id' key, which will be ignored.")
128-
if isinstance(agents["unique_id"], list):
129-
index = [next(self._ids[self.model]) for _ in agents["unique_id"]]
130-
else:
131-
index = [next(self._ids[self.model])]
132-
agents.pop("unique_id")
133-
new_agents = pd.DataFrame(agents, index=pd.Index(index, name="unique_id"))
121+
warnings.warn("Dictionary should not have a unique_id key. It will be ignored.")
122+
if isinstance(next(iter(agents.values())), list):
123+
index = range(len(next(iter(agents.values()))))
124+
else:
125+
index = [0]
126+
new_agents = pd.DataFrame(agents, index=index)
134127
else:
135-
if len(agents) != len(obj._agents.columns) + 1:
128+
if len(agents) not in {len(obj._agents.columns), len(obj._agents.columns) + 1}:
136129
raise ValueError(
137130
"Length of data must match the number of columns in the AgentSet if being added as a Collection."
138131
)
139-
if len(agents) == len(obj._agents.columns):
140-
agents = next(self._ids[self.model]) + agents[1:]
141-
columns = pd.Index(["unique_id"]).append(obj._agents.columns.copy())
142-
new_agents = pd.DataFrame([agents], columns=columns).set_index(
143-
"unique_id", drop=True
144-
)
145-
146-
if new_agents.index.dtype != "int64":
147-
new_agents.index = new_agents.index.astype("int64")
132+
if len(agents) == len(obj._agents.columns) + 1:
133+
warnings.warn("Length of data should have the number of columns in the AgentSet," +
134+
"we suppose the first element is the unique_id. It will be ignored.")
135+
agents = agents[1:]
136+
new_agents = pd.DataFrame([agents], columns=obj._agents.columns.copy())
137+
138+
new_agents.drop("unique_id", errors="ignore", inplace=True)
139+
if len(self.agents) == 0:
140+
new_agents["unique_id"] = np.arange(len(new_agents))
141+
else:
142+
new_agents["unique_id"] = np.arange(self.index.max() + 1, self.index.max() + 1 + len(new_agents))
143+
new_agents.set_index("unique_id", inplace=True, drop=True)
148144

149145
if not obj._agents.index.intersection(new_agents.index).empty:
150146
raise KeyError("Some IDs already exist in the agent set.")
@@ -223,7 +219,7 @@ def set( # noqa : D102
223219
"Either attr_names must be a dictionary with columns as keys and values or values must be provided."
224220
)
225221

226-
non_masked_df = obj._agents[~b_mask]
222+
non_masked_df = obj._agents[~b_mask] if len(b_mask) > 0 else pd.DataFrame()
227223
original_index = obj._agents.index
228224
obj._agents = pd.concat([non_masked_df, masked_df])
229225
obj._agents = obj._agents.reindex(original_index)

0 commit comments

Comments
 (0)