Skip to content

Commit 0f41ea8

Browse files
Fix assertions for index
1 parent d2483b9 commit 0f41ea8

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

tests/pandas/test_agentset_pandas.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
@tg.typechecked
1313
class ExampleAgentSetPandas(AgentSetPandas):
14-
def __init__(self, model: ModelDF):
14+
def __init__(self, model: ModelDF, index: pd.Index):
1515
super().__init__(model)
1616
self.starting_wealth = pd.Series([1, 2, 3, 4], name="wealth")
1717

@@ -25,7 +25,7 @@ def step(self) -> None:
2525
@pytest.fixture
2626
def fix1_AgentSetPandas() -> ExampleAgentSetPandas:
2727
model = ModelDF()
28-
agents = ExampleAgentSetPandas(model)
28+
agents = ExampleAgentSetPandas(model, pd.Index([0, 1, 2, 3], name="unique_id"))
2929
agents.add({"unique_id": [0, 1, 2, 3]})
3030
agents["wealth"] = agents.starting_wealth
3131
agents["age"] = [10, 20, 30, 40]
@@ -36,7 +36,7 @@ def fix1_AgentSetPandas() -> ExampleAgentSetPandas:
3636
@pytest.fixture
3737
def fix2_AgentSetPandas() -> ExampleAgentSetPandas:
3838
model = ModelDF()
39-
agents = ExampleAgentSetPandas(model)
39+
agents = ExampleAgentSetPandas(model, pd.Index([4, 5, 6, 7], name="unique_id"))
4040
agents.add({"unique_id": [4, 5, 6, 7]})
4141
agents["wealth"] = agents.starting_wealth + 10
4242
agents["age"] = [100, 200, 300, 400]
@@ -55,7 +55,7 @@ def fix1_AgentSetPandas_with_pos(fix1_AgentSetPandas) -> ExampleAgentSetPandas:
5555
class Test_AgentSetPandas:
5656
def test__init__(self):
5757
model = ModelDF()
58-
agents = ExampleAgentSetPandas(model)
58+
agents = ExampleAgentSetPandas(model, pd.Index([0, 1, 2, 3]))
5959
assert agents.model == model
6060
assert isinstance(agents.agents, pd.DataFrame)
6161
assert agents.agents.index.name == "unique_id"
@@ -78,14 +78,15 @@ def test_add(
7878

7979
# Test with a list (Sequence[Any])
8080
result = agents.add([10, 5, 10], inplace=False)
81-
assert result.agents.index.to_list() == [0, 1, 2, 3, 10]
81+
assert result.agents.index.to_list() == [0, 1, 2, 3, 4]
8282
assert result.agents.wealth.to_list() == [1, 2, 3, 4, 5]
8383
assert result.agents.age.to_list() == [10, 20, 30, 40, 10]
8484
assert agents.agents.index.name == "unique_id"
8585

8686
# Test with a dict[str, Any]
8787
agents.add({"unique_id": [4, 5], "wealth": [5, 6], "age": [50, 60]})
8888
assert agents.agents.wealth.tolist() == [1, 2, 3, 4, 5, 6]
89+
assert agents.agents.index.tolist() == [0, 1, 2, 3, 4, 5]
8990
assert agents.agents.age.tolist() == [10, 20, 30, 40, 50, 60]
9091
assert agents.agents.index.name == "unique_id"
9192

@@ -286,6 +287,7 @@ def test__add__(
286287

287288
# Test with an AgentSetPandas and a dict
288289
agents3 = agents + {"unique_id": 10, "wealth": 5}
290+
assert agents3.agents.index.tolist() == [0, 1, 2, 3, 4]
289291
assert agents3.agents.wealth.tolist() == [1, 2, 3, 4, 5]
290292

291293
def test__contains__(self, fix1_AgentSetPandas: ExampleAgentSetPandas):
@@ -358,6 +360,7 @@ def test__iadd__(
358360
# Test with an AgentSetPandas and a dict
359361
agents = deepcopy(fix1_AgentSetPandas)
360362
agents += {"unique_id": 10, "wealth": 5}
363+
assert agents.agents.index.tolist() == [0, 1, 2, 3, 4]
361364
assert agents.agents.wealth.tolist() == [1, 2, 3, 4, 5]
362365

363366
def test__iter__(self, fix1_AgentSetPandas: ExampleAgentSetPandas):
@@ -436,24 +439,25 @@ def test_agents(
436439

437440
# Test agents.setter
438441
agents.agents = agents2.agents
439-
assert len(agents.active_agents) == 4
442+
assert agents.agents.wealth.tolist() == [11, 12, 13, 14]
443+
assert agents.agents.age.tolist() == [100, 200, 300, 400]
440444

441445
def test_active_agents(self, fix1_AgentSetPandas: ExampleAgentSetPandas):
442446
agents = fix1_AgentSetPandas
443447

444448
# Test with select
445449
agents.select(agents["wealth"] > 2, inplace=True)
446-
assert len(agents.active_agents) == 2
450+
assert agents.active_agents.index.tolist() == [2, 3]
447451

448452
# Test with active_agents.setter
449453
agents.active_agents = agents.agents.wealth > 2
450-
assert len(agents.active_agents) == 2
454+
assert agents.active_agents.index.to_list() == [2, 3]
451455

452456
def test_inactive_agents(self, fix1_AgentSetPandas: ExampleAgentSetPandas):
453457
agents = fix1_AgentSetPandas
454458

455459
agents.select(agents["wealth"] > 2, inplace=True)
456-
assert len(agents.active_agents) == 2
460+
assert agents.inactive_agents.index.to_list() == [0, 1]
457461

458462
def test_pos(self, fix1_AgentSetPandas_with_pos: ExampleAgentSetPandas):
459463
pos = fix1_AgentSetPandas_with_pos.pos

0 commit comments

Comments
 (0)