from active_inference_forager.agents.base_agent import BaseAgent
from active_inference_forager.agents.belief_node import BeliefNode
- from active_inference_forager.utils.numpy_fields import NumpyArrayField


class ExperienceReplayBuffer:
@@ -41,7 +40,7 @@ def forward(self, x):
        return self.network(x)


- class DQNFEPAgent(BaseAgent):
+ class GenericAgent(BaseAgent):
    # FEP-related parameters
    max_kl: float = Field(default=10.0)
    max_fe: float = Field(default=100.0)
@@ -86,12 +85,10 @@ class DQNFEPAgent(BaseAgent):
    def __init__(self, state_dim: int, action_dim: int, **kwargs):
        super().__init__(state_dim=state_dim, action_dim=action_dim, **kwargs)

-         # Initialize root belief with correct dimensions
        self.root_belief = BeliefNode(
            mean=np.zeros(state_dim), precision=np.eye(state_dim) * 0.1
        )

-         # Initialize DQN components
        self.q_network = self._build_network()
        self.target_network = self._build_network()
        self.target_network.load_state_dict(self.q_network.state_dict())
@@ -127,11 +124,9 @@ def learn(
        batch = self.replay_buffer.sample(self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

-         # Convert lists of numpy arrays to single numpy arrays
        states = np.array(states)
        next_states = np.array(next_states)

-         # Convert numpy arrays to tensors
        states = torch.FloatTensor(states).to(self.device)
        actions = torch.LongTensor(
            [self.action_space.index(action) for action in actions]
@@ -153,20 +148,32 @@ def learn(
        self.soft_update_target_network()
        self.decay_exploration()

+     def interpret_action(self, action: str) -> str:
+         """
+         Interpret the agent's action in a human-readable format.
+         """
+         action_interpretations = {
+             "ask_question": "The agent decides to ask a question to gather more information.",
+             "provide_information": "The agent provides relevant information to the user.",
+             "clarify": "The agent attempts to clarify a point or resolve any confusion.",
+             "suggest_action": "The agent suggests a specific action or solution to the user.",
+             "express_empathy": "The agent expresses empathy or understanding towards the user's situation.",
+             "end_conversation": "The agent determines it's appropriate to end the conversation.",
+         }
+         return action_interpretations.get(action, f"Unknown action: {action}")
+
    def update_belief(self, observation: np.ndarray) -> None:
-         # Validate that observation is numeric
        if not np.issubdtype(observation.dtype, np.number):
            raise ValueError("Observation must be a numeric array.")
-
+
        self._update_belief_recursive(self.root_belief, observation)
        self._regularize_beliefs()

    def _update_belief_recursive(self, node: BeliefNode, observation: np.ndarray):
-         # Ensure observation is a numpy array of floats
        observation = np.asarray(observation)
        if observation.dtype != node.mean.dtype:
            observation = observation.astype(node.mean.dtype)
-
+
        prediction_error = observation - node.mean
        node.precision += (
            np.outer(prediction_error, prediction_error) * self.learning_rate
@@ -190,13 +197,7 @@ def update_free_energy(self):
        self.free_energy = self._calculate_free_energy_recursive(self.root_belief)

    def _build_network(self):
-         return nn.Sequential(
-             nn.Linear(self.state_dim, 128),
-             nn.ReLU(),
-             nn.Linear(128, 128),
-             nn.ReLU(),
-             nn.Linear(128, self.action_dim),
-         )
+         return DQN(self.state_dim, self.action_dim).to(self.device)

    def _calculate_free_energy_recursive(self, node: BeliefNode) -> float:
        kl_divergence = self._kl_divergence(node)
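The new `_build_network` delegates to a `DQN` module that this diff only shows in part (its `forward` simply returns `self.network(x)` in the first hunk). A minimal sketch of what that module could look like, assuming it wraps the same layer stack as the deleted `nn.Sequential`; the real class in the repository may differ:

```python
import torch.nn as nn


class DQN(nn.Module):
    """Sketch of the Q-network referenced by _build_network.

    Hidden-layer sizes mirror the nn.Sequential removed in this diff;
    the actual DQN implementation may differ.
    """

    def __init__(self, state_dim: int, action_dim: int):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim),
        )

    def forward(self, x):
        # Matches the forward shown at the top of this diff.
        return self.network(x)
```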
@@ -272,48 +273,20 @@ def _build_belief_hierarchy(self, node: BeliefNode, level: int):
            node.children[action] = child
            self._build_belief_hierarchy(child, level + 1)

-     def interpret_action(self, action: str) -> str:
-         """
-         Interpret the agent's action in a human-readable format.
-         """
-         action_interpretations = {
-             "ask_question": "The agent decides to ask a question to gather more information.",
-             "provide_information": "The agent provides relevant information to the user.",
-             "clarify": "The agent attempts to clarify a point or resolve any confusion.",
-             "suggest_action": "The agent suggests a specific action or solution to the user.",
-             "express_empathy": "The agent expresses empathy or understanding towards the user's situation.",
-             "end_conversation": "The agent determines it's appropriate to end the conversation.",
-         }
-         return action_interpretations.get(action, f"Unknown action: {action}")
-
    def process_user_input(self, user_input: str) -> np.ndarray:
-         """
-         Simple natural language processing to extract features from user input.
-         """
-         # This is a very basic implementation and can be expanded with more sophisticated NLP techniques
-         features = np.zeros(5)  # Assuming 5 features for simplicity
+         features = np.zeros(5)

        words = user_input.split()
-         features[0] = len(words)  # Number of words
-         features[1] = user_input.count("?") / len(words)  # Question mark ratio
-         features[2] = user_input.count("!") / len(words)  # Exclamation mark ratio
-         features[3] = len(user_input) / 100  # Normalized length of input
+         features[0] = len(words)
+         features[1] = user_input.count("?") / len(words)
+         features[2] = user_input.count("!") / len(words)
+         features[3] = len(user_input) / 100
        features[4] = sum(
            1
            for word in words
            if word.lower() in ["please", "thank", "thanks", "appreciate"]
-         ) / len(
-             words
-         )  # Politeness ratio
+         ) / len(words)

-         # Ensure features are of type float
        features = features.astype(float)

-         # Debug print statements
-         print(f"Debug: Input string: '{user_input}'")
-         print(f"Debug: Word count: {len(words)}")
-         print(f"Debug: Question mark count: {user_input.count('?')}")
-         print(f"Debug: Exclamation mark count: {user_input.count('!')}")
-         print(f"Debug: Features: {features}")
-
        return features
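The trimmed `process_user_input` computes the same five features as before, just without the inline comments and debug prints. A standalone worked example of the arithmetic (the input string is chosen for illustration and does not come from the diff):

```python
import numpy as np

user_input = "thanks for the help"
words = user_input.split()

features = np.zeros(5)
features[0] = len(words)                          # word count -> 4
features[1] = user_input.count("?") / len(words)  # question-mark ratio -> 0.0
features[2] = user_input.count("!") / len(words)  # exclamation-mark ratio -> 0.0
features[3] = len(user_input) / 100               # normalized length -> 0.19
features[4] = sum(                                # politeness ratio -> 0.25
    1 for w in words if w.lower() in ["please", "thank", "thanks", "appreciate"]
) / len(words)

assert np.allclose(features, [4.0, 0.0, 0.0, 0.19, 0.25])
```

As in the diff itself, this assumes a non-empty input, since the ratios divide by `len(words)`.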