4
4
from active_inference_forager .agents .generic_agent import GenericAgent , ExperienceReplayBuffer , DQN
5
5
6
6
7
- @pytest .fixture
8
- def agent ():
9
- action_space = [
10
- "ask_question" ,
11
- "provide_information" ,
12
- "clarify" ,
13
- "suggest_action" ,
14
- "express_empathy" ,
15
- "end_conversation" ,
16
- ]
17
7
@pytest .fixture
18
8
def agent ():
19
9
action_space = [
@@ -112,10 +102,10 @@ def test_process_user_input(agent):
112
102
assert isinstance (features , np .ndarray )
113
103
assert features .shape == (17 ,)
114
104
assert features [0 ] == 8 # Word count
115
- assert features [1 ] > 0 # Average word length
116
- assert features [2 ] == 1 / 8 # Question mark frequency
117
- assert features [3 ] == 1 / 8 # Exclamation mark frequency
118
- assert - 1 <= features [4 ] <= 1 # Sentiment polarity
105
+ assert features [2 ] == 0.125 # Question mark frequency
106
+ assert features [3 ] == 0.125 # Exclamation mark frequency
107
+ assert 0 < features [1 ] < 5 # Average word length
108
+ assert - 1.0 <= features [4 ] <= 1.0 # Sentiment polarity
119
109
assert 0 <= features [5 ] <= 1 # Subjectivity
120
110
assert 0 <= features [6 ] <= 1 # Keyword detection
121
111
assert 0 <= features [7 ] <= 1 # Lexical diversity
@@ -127,16 +117,7 @@ def test_process_user_input(agent):
127
117
assert 0 <= features [13 ] <= 1 # Main clause density
128
118
assert features [14 ] > 0 # Average parse tree depth
129
119
assert 0 <= features [15 ] <= 1 # Stop word density
130
- def test_process_user_input (agent ):
131
- user_input = "Hello, how are you? I'm feeling great today!"
132
- features = agent .process_user_input (user_input )
133
- assert isinstance (features , np .ndarray )
134
- assert features .shape == (17 ,)
135
- assert features [0 ] == 8 # Word count
136
- assert features [2 ] == 0.125 # Question mark frequency
137
- assert features [3 ] == 0.125 # Exclamation mark frequency
138
- assert 0 < features [1 ] < 5 # Average word length
139
- assert - 0.9 < features [4 ] < 0.9 # Sentiment polarity
120
+ assert 0 <= features [16 ] <= 1 # Punctuation density
140
121
141
122
142
123
if __name__ == "__main__" :
0 commit comments