Commit 9f7d565

Merge pull request #43 from leonvanbokhorst/Refactor-PhilosophyTutorAgent-to-inherit-from-GenericAgent
Refactor PhilosophyTutorAgent to inherit from GenericAgent
2 parents a87b03d + 4ee6f7f commit 9f7d565

3 files changed: +70 -66 lines

src/active_inference_forager/agents/dqn_fep_agent.py renamed to src/active_inference_forager/agents/generic_agent.py

Lines changed: 24 additions & 51 deletions
@@ -9,7 +9,6 @@
 
 from active_inference_forager.agents.base_agent import BaseAgent
 from active_inference_forager.agents.belief_node import BeliefNode
-from active_inference_forager.utils.numpy_fields import NumpyArrayField
 
 
 class ExperienceReplayBuffer:
@@ -41,7 +40,7 @@ def forward(self, x):
         return self.network(x)
 
 
-class DQNFEPAgent(BaseAgent):
+class GenericAgent(BaseAgent):
     # FEP-related parameters
     max_kl: float = Field(default=10.0)
     max_fe: float = Field(default=100.0)
@@ -86,12 +85,10 @@ class DQNFEPAgent(BaseAgent):
     def __init__(self, state_dim: int, action_dim: int, **kwargs):
         super().__init__(state_dim=state_dim, action_dim=action_dim, **kwargs)
 
-        # Initialize root belief with correct dimensions
         self.root_belief = BeliefNode(
             mean=np.zeros(state_dim), precision=np.eye(state_dim) * 0.1
         )
 
-        # Initialize DQN components
         self.q_network = self._build_network()
         self.target_network = self._build_network()
         self.target_network.load_state_dict(self.q_network.state_dict())
@@ -127,11 +124,9 @@ def learn(
         batch = self.replay_buffer.sample(self.batch_size)
         states, actions, rewards, next_states, dones = zip(*batch)
 
-        # Convert lists of numpy arrays to single numpy arrays
         states = np.array(states)
         next_states = np.array(next_states)
 
-        # Convert numpy arrays to tensors
         states = torch.FloatTensor(states).to(self.device)
         actions = torch.LongTensor(
             [self.action_space.index(action) for action in actions]
@@ -153,20 +148,32 @@ def learn(
         self.soft_update_target_network()
         self.decay_exploration()
 
+    def interpret_action(self, action: str) -> str:
+        """
+        Interpret the agent's action in a human-readable format.
+        """
+        action_interpretations = {
+            "ask_question": "The agent decides to ask a question to gather more information.",
+            "provide_information": "The agent provides relevant information to the user.",
+            "clarify": "The agent attempts to clarify a point or resolve any confusion.",
+            "suggest_action": "The agent suggests a specific action or solution to the user.",
+            "express_empathy": "The agent expresses empathy or understanding towards the user's situation.",
+            "end_conversation": "The agent determines it's appropriate to end the conversation.",
+        }
+        return action_interpretations.get(action, f"Unknown action: {action}")
+
     def update_belief(self, observation: np.ndarray) -> None:
-        # Validate that observation is numeric
         if not np.issubdtype(observation.dtype, np.number):
             raise ValueError("Observation must be a numeric array.")
-
+
         self._update_belief_recursive(self.root_belief, observation)
         self._regularize_beliefs()
 
     def _update_belief_recursive(self, node: BeliefNode, observation: np.ndarray):
-        # Ensure observation is a numpy array of floats
         observation = np.asarray(observation)
         if observation.dtype != node.mean.dtype:
             observation = observation.astype(node.mean.dtype)
-
+
         prediction_error = observation - node.mean
         node.precision += (
             np.outer(prediction_error, prediction_error) * self.learning_rate
@@ -190,13 +197,7 @@ def update_free_energy(self):
         self.free_energy = self._calculate_free_energy_recursive(self.root_belief)
 
     def _build_network(self):
-        return nn.Sequential(
-            nn.Linear(self.state_dim, 128),
-            nn.ReLU(),
-            nn.Linear(128, 128),
-            nn.ReLU(),
-            nn.Linear(128, self.action_dim),
-        )
+        return DQN(self.state_dim, self.action_dim).to(self.device)
 
     def _calculate_free_energy_recursive(self, node: BeliefNode) -> float:
         kl_divergence = self._kl_divergence(node)
@@ -272,48 +273,20 @@ def _build_belief_hierarchy(self, node: BeliefNode, level: int):
             node.children[action] = child
             self._build_belief_hierarchy(child, level + 1)
 
-    def interpret_action(self, action: str) -> str:
-        """
-        Interpret the agent's action in a human-readable format.
-        """
-        action_interpretations = {
-            "ask_question": "The agent decides to ask a question to gather more information.",
-            "provide_information": "The agent provides relevant information to the user.",
-            "clarify": "The agent attempts to clarify a point or resolve any confusion.",
-            "suggest_action": "The agent suggests a specific action or solution to the user.",
-            "express_empathy": "The agent expresses empathy or understanding towards the user's situation.",
-            "end_conversation": "The agent determines it's appropriate to end the conversation.",
-        }
-        return action_interpretations.get(action, f"Unknown action: {action}")
-
     def process_user_input(self, user_input: str) -> np.ndarray:
-        """
-        Simple natural language processing to extract features from user input.
-        """
-        # This is a very basic implementation and can be expanded with more sophisticated NLP techniques
-        features = np.zeros(5)  # Assuming 5 features for simplicity
+        features = np.zeros(5)
 
         words = user_input.split()
-        features[0] = len(words)  # Number of words
-        features[1] = user_input.count("?") / len(words)  # Question mark ratio
-        features[2] = user_input.count("!") / len(words)  # Exclamation mark ratio
-        features[3] = len(user_input) / 100  # Normalized length of input
+        features[0] = len(words)
+        features[1] = user_input.count("?") / len(words)
+        features[2] = user_input.count("!") / len(words)
+        features[3] = len(user_input) / 100
         features[4] = sum(
            1
            for word in words
            if word.lower() in ["please", "thank", "thanks", "appreciate"]
-        ) / len(
-            words
-        )  # Politeness ratio
+        ) / len(words)
 
-        # Ensure features are of type float
         features = features.astype(float)
 
-        # Debug print statements
-        print(f"Debug: Input string: '{user_input}'")
-        print(f"Debug: Word count: {len(words)}")
-        print(f"Debug: Question mark count: {user_input.count('?')}")
-        print(f"Debug: Exclamation mark count: {user_input.count('!')}")
-        print(f"Debug: Features: {features}")
-
         return features
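
After the rename, GenericAgent is the concrete DQN/FEP agent that domain-specific agents build on. A minimal usage sketch, assuming only the constructor arguments and methods exercised in tests/unit/test_dqn_fep_agent.py (the utterance below is illustrative, not from the repository):

import numpy as np

from active_inference_forager.agents.generic_agent import GenericAgent

action_space = [
    "ask_question",
    "provide_information",
    "clarify",
    "suggest_action",
    "express_empathy",
    "end_conversation",
]

# Same constructor call as the pytest fixture in the test file below.
agent = GenericAgent(
    state_dim=5, action_dim=len(action_space), action_space=action_space
)

# Map a user utterance to the 5-dimensional feature vector, fold it into the
# belief hierarchy, then choose an action and render it as text.
features = agent.process_user_input("Could you explain active inference, please?")
agent.update_belief(features)
action = agent.take_action(features)
print(agent.interpret_action(action))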

src/active_inference_forager/agents/philosophy_tutor_agent.py

Lines changed: 2 additions & 2 deletions
@@ -1,10 +1,10 @@
 import numpy as np
 from typing import Dict, List
 from pydantic import Field
-from active_inference_forager.agents.dqn_fep_agent import DQNFEPAgent
+from active_inference_forager.agents.generic_agent import GenericAgent
 
 
-class PhilosophyTutorAgent(DQNFEPAgent):
+class PhilosophyTutorAgent(GenericAgent):
     knowledge_base: Dict[str, Dict] = Field(default_factory=dict)
 
     def __init__(self, state_dim: int, action_dim: int, **kwargs):
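
Because the DQN/FEP machinery now lives under a generic name, any other domain agent can inherit from it the same way PhilosophyTutorAgent does. A hedged sketch of that pattern, mirroring the field declaration and __init__ signature in the hunk above (MathTutorAgent and its exercise_bank field are hypothetical, not part of this commit):

from typing import Dict

from pydantic import Field

from active_inference_forager.agents.generic_agent import GenericAgent


class MathTutorAgent(GenericAgent):
    # Hypothetical domain-specific state, analogous to PhilosophyTutorAgent's knowledge_base.
    exercise_bank: Dict[str, Dict] = Field(default_factory=dict)

    def __init__(self, state_dim: int, action_dim: int, **kwargs):
        # The generic base class handles the belief hierarchy and DQN setup.
        super().__init__(state_dim=state_dim, action_dim=action_dim, **kwargs)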

tests/unit/test_dqn_fep_agent.py

Lines changed: 44 additions & 13 deletions
@@ -1,88 +1,119 @@
 import pytest
 import numpy as np
 import torch
-from active_inference_forager.agents.dqn_fep_agent import DQNFEPAgent
+from active_inference_forager.agents.generic_agent import GenericAgent
+
 
 @pytest.fixture
 def agent():
-    action_space = ['ask_question', 'provide_information', 'clarify', 'suggest_action', 'express_empathy', 'end_conversation']
-    return DQNFEPAgent(state_dim=5, action_dim=len(action_space), action_space=action_space)
+    action_space = [
+        "ask_question",
+        "provide_information",
+        "clarify",
+        "suggest_action",
+        "express_empathy",
+        "end_conversation",
+    ]
+    return GenericAgent(
+        state_dim=5, action_dim=len(action_space), action_space=action_space
+    )
+
 
 def test_agent_initialization(agent):
-    assert isinstance(agent, DQNFEPAgent)
+    assert isinstance(agent, GenericAgent)
     assert agent.state_dim == 5
-    assert np.array_equal(agent.action_space, ['ask_question', 'provide_information', 'clarify', 'suggest_action', 'express_empathy', 'end_conversation'])
+    assert np.array_equal(
+        agent.action_space,
+        [
+            "ask_question",
+            "provide_information",
+            "clarify",
+            "suggest_action",
+            "express_empathy",
+            "end_conversation",
+        ],
+    )
     assert isinstance(agent.q_network, torch.nn.Module)
     assert isinstance(agent.target_network, torch.nn.Module)
     assert isinstance(agent.optimizer, torch.optim.Adam)
     assert agent.exploration_rate == agent.epsilon_start
 
+
 def test_take_action(agent):
     state = np.random.rand(5)
     action = agent.take_action(state)
     assert action in agent.action_space
 
+
 def test_learn(agent):
     state = np.random.rand(5)
-    action = 'ask_question'
+    action = "ask_question"
     next_state = np.random.rand(5)
     reward = 1.0
     done = False
 
     initial_total_steps = agent.total_steps
-
+
     agent.learn(state, action, next_state, reward, done)
-
+
     assert agent.total_steps > initial_total_steps
 
+
 def test_update_belief(agent):
     observation = np.random.rand(5)
     initial_mean = agent.root_belief.mean.copy()
     initial_precision = agent.root_belief.precision.copy()
-
+
     agent.update_belief(observation)
-
+
     assert not np.array_equal(initial_mean, agent.root_belief.mean)
     assert not np.array_equal(initial_precision, agent.root_belief.precision)
 
+
 def test_update_free_energy(agent):
     initial_free_energy = agent.free_energy
     agent.update_free_energy()
     assert agent.free_energy != initial_free_energy
 
+
 def test_update_reward_buffer(agent):
     initial_buffer_length = len(agent.reward_buffer)
     agent.update_reward_buffer(1.0)
     assert len(agent.reward_buffer) == initial_buffer_length + 1
 
+
 def test_decay_exploration(agent):
     initial_exploration_rate = agent.exploration_rate
     agent.decay_exploration()
     assert agent.exploration_rate < initial_exploration_rate
 
+
 def test_reset(agent):
     agent.free_energy = 10.0
     agent.exploration_rate = 0.1
     agent.reset()
     assert agent.free_energy == 0.0
     assert agent.exploration_rate == agent.epsilon_start
 
+
 def test_interpret_action(agent):
-    action = 'ask_question'
+    action = "ask_question"
     interpretation = agent.interpret_action(action)
     assert isinstance(interpretation, str)
     assert "ask a question" in interpretation.lower()
 
+
 def test_process_user_input(agent):
     user_input = "Hello, how are you? I'm feeling great today!"
     features = agent.process_user_input(user_input)
     assert isinstance(features, np.ndarray)
     assert features.shape == (5,)
     assert features[0] == 8  # Number of words
-    assert features[1] == 1/8  # Question mark ratio
-    assert features[2] == 1/8  # Exclamation mark ratio
+    assert features[1] == 1 / 8  # Question mark ratio
+    assert features[2] == 1 / 8  # Exclamation mark ratio
     assert 0 < features[3] < 1  # Normalized length
     assert features[4] == 0  # Politeness ratio
 
+
 if __name__ == "__main__":
     pytest.main()
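
The assertions in test_process_user_input follow directly from the feature definitions in process_user_input. A quick worked check for the test sentence (variable names below are just for illustration):

user_input = "Hello, how are you? I'm feeling great today!"
words = user_input.split()                     # 8 tokens            -> features[0] == 8
question_ratio = user_input.count("?") / len(words)   # 1 / 8        -> features[1]
exclamation_ratio = user_input.count("!") / len(words)  # 1 / 8      -> features[2]
normalized_length = len(user_input) / 100      # 44 chars -> 0.44, so 0 < features[3] < 1
# None of "please", "thank", "thanks", "appreciate" appear           -> features[4] == 0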
