from pydantic import Field, ConfigDict
from collections import deque
import random
+ import re
+ from textblob import TextBlob
+ import spacy

from active_inference_forager.agents.base_agent import BaseAgent
from active_inference_forager.agents.belief_node import BeliefNode
@@ -80,9 +83,13 @@ class GenericAgent(BaseAgent):
    episode_lengths: List[int] = Field(default_factory=list)
    total_steps: int = Field(default=0)

+    # NLP model
+    nlp: spacy.language.Language = Field(default_factory=lambda: spacy.load("en_core_web_sm"))
+
    model_config = ConfigDict(arbitrary_types_allowed=True)

-    def __init__(self, state_dim: int, action_dim: int, **kwargs):
+    def __init__(self, action_dim: int, **kwargs):
+        state_dim = 17  # Updated to match the environment's state dimension
        super().__init__(state_dim=state_dim, action_dim=action_dim, **kwargs)

        self.root_belief = BeliefNode(
@@ -94,12 +101,19 @@ def __init__(self, state_dim: int, action_dim: int, **kwargs):
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.learning_rate)

-        self.initialize_belief_and_action_space()
+        self.action_space = [
+            "ask_question",
+            "provide_information",
+            "clarify",
+            "suggest_action",
+            "express_empathy",
+            "end_conversation",
+        ]
        self.exploration_rate = self.epsilon_start

    def take_action(self, state: np.ndarray) -> str:
        if np.random.rand() < self.exploration_rate:
-            return np.random.choice(self.action_space)
+            return random.choice(self.action_space)
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        q_values = self.q_network(state_tensor)
        return self.action_space[q_values.argmax().item()]
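Illustrative note (not part of the diff): a minimal sketch of how the reworked epsilon-greedy selection is exercised. The constructor call and the zero state vector below are assumptions for illustration only; other hyperparameters are left at whatever defaults BaseAgent provides.

    agent = GenericAgent(action_dim=6)   # hypothetical instantiation
    state = np.zeros(17)                 # state vector sized to the new state_dim
    action = agent.take_action(state)    # one of the six strings in agent.action_space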
@@ -274,19 +288,57 @@ def _build_belief_hierarchy(self, node: BeliefNode, level: int):
        self._build_belief_hierarchy(child, level + 1)

    def process_user_input(self, user_input: str) -> np.ndarray:
-        features = np.zeros(5)
+        features = np.zeros(17)  # Updated to match the environment's state dimension

+        # Basic text statistics
        words = user_input.split()
-        features[0] = len(words)
-        features[1] = user_input.count("?") / len(words)
-        features[2] = user_input.count("!") / len(words)
-        features[3] = len(user_input) / 100
-        features[4] = sum(
-            1
-            for word in words
-            if word.lower() in ["please", "thank", "thanks", "appreciate"]
-        ) / len(words)
-
-        features = features.astype(float)
+        features[0] = len(words)  # Word count
+        features[1] = sum(len(word) for word in words) / max(len(words), 1)  # Average word length
+        features[2] = user_input.count("?") / max(len(words), 1)  # Question mark frequency
+        features[3] = user_input.count("!") / max(len(words), 1)  # Exclamation mark frequency
+
+        # Sentiment analysis
+        blob = TextBlob(user_input)
+        features[4] = blob.sentiment.polarity  # Sentiment polarity (-1 to 1)
+        features[5] = blob.sentiment.subjectivity  # Subjectivity (0 to 1)
+
+        # Keyword detection
+        keywords = ["help", "explain", "understand", "confused", "clarify"]
+        features[6] = sum(word.lower() in keywords for word in words) / max(len(words), 1)
+
+        # Complexity indicators
+        features[7] = len(set(words)) / max(len(words), 1)  # Lexical diversity
+        features[8] = sum(len(word) > 6 for word in words) / max(len(words), 1)  # Proportion of long words
+
+        # Politeness indicator
+        polite_words = ["please", "thank", "thanks", "appreciate", "kindly"]
+        features[9] = sum(word.lower() in polite_words for word in words) / max(len(words), 1)
+
+        # spaCy processing
+        doc = self.nlp(user_input)
+
+        # Named Entity Recognition
+        features[10] = len(doc.ents) / max(len(words), 1)  # Named entity density
+
+        # Part-of-speech tagging
+        pos_counts = {pos: 0 for pos in ['NOUN', 'VERB', 'ADJ', 'ADV']}
+        for token in doc:
+            if token.pos_ in pos_counts:
+                pos_counts[token.pos_] += 1
+        features[11] = pos_counts['NOUN'] / max(len(words), 1)  # Noun density
+        features[12] = pos_counts['VERB'] / max(len(words), 1)  # Verb density
+
+        # Dependency parsing
+        features[13] = len([token for token in doc if token.dep_ == 'ROOT']) / max(len(words), 1)  # Main clause density
+
+        # Sentence complexity (using dependency parse tree depth)
+        def tree_depth(token):
+            return 1 + max((tree_depth(child) for child in token.children), default=0)
+
+        features[14] = sum(tree_depth(sent.root) for sent in doc.sents) / max(len(list(doc.sents)), 1)  # Average parse tree depth
+
+        # Additional features to match the environment's state dimension
+        features[15] = len([token for token in doc if token.is_stop]) / max(len(words), 1)  # Stop word density
+        features[16] = len([token for token in doc if token.is_punct]) / max(len(words), 1)  # Punctuation density

        return features
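Illustrative note (not part of the diff): a minimal usage sketch of the expanded feature extractor. It assumes the spaCy model referenced by the `nlp` field default is installed (e.g. `python -m spacy download en_core_web_sm`) and that the agent can be constructed as below; both are assumptions, not shown in this commit.

    agent = GenericAgent(action_dim=6)                                      # hypothetical instantiation
    features = agent.process_user_input("Could you please explain this?")   # runs TextBlob + spaCy feature extraction
    assert features.shape == (17,)                                          # counts, sentiment, POS densities, parse depth, ...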