Skip to content

Commit dc23734

Browse files
Refactor feature calculation in GenericAgent
1 parent a55038e commit dc23734

File tree

2 files changed

+9
-28
lines changed

2 files changed

+9
-28
lines changed

src/active_inference_forager/agents/generic_agent.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -293,7 +293,7 @@ def process_user_input(self, user_input: str) -> np.ndarray:
293293
# Basic text statistics
294294
words = user_input.split()
295295
features[0] = len(words) # Word count
296-
features[1] = len(user_input) / max(len(words), 1) # Average word length
296+
features[1] = sum(len(word) for word in words) / max(len(words), 1) # Average word length
297297
features[2] = user_input.count("?") / max(len(words), 1) # Question mark frequency
298298
features[3] = user_input.count("!") / max(len(words), 1) # Exclamation mark frequency
299299

@@ -304,15 +304,15 @@ def process_user_input(self, user_input: str) -> np.ndarray:
304304

305305
# Keyword detection
306306
keywords = ["help", "explain", "understand", "confused", "clarify"]
307-
features[6] = sum(bool(word.lower() in keywords)
307+
features[6] = sum(word.lower() in keywords for word in words) / max(len(words), 1)
308308

309309
# Complexity indicators
310310
features[7] = len(set(words)) / max(len(words), 1) # Lexical diversity
311311
features[8] = sum(len(word) > 6 for word in words) / max(len(words), 1) # Proportion of long words
312312

313313
# Politeness indicator
314314
polite_words = ["please", "thank", "thanks", "appreciate", "kindly"]
315-
features[9] = sum(bool(word.lower() in polite_words)
315+
features[9] = sum(word.lower() in polite_words for word in words) / max(len(words), 1)
316316

317317
# spaCy processing
318318
doc = self.nlp(user_input)
@@ -341,4 +341,4 @@ def tree_depth(token):
341341
features[15] = len([token for token in doc if token.is_stop]) / max(len(words), 1) # Stop word density
342342
features[16] = len([token for token in doc if token.is_punct]) / max(len(words), 1) # Punctuation density
343343

344-
return features.astype(float)
344+
return features

tests/unit/test_dqn_fep_agent.py

Lines changed: 5 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -4,16 +4,6 @@
44
from active_inference_forager.agents.generic_agent import GenericAgent, ExperienceReplayBuffer, DQN
55

66

7-
@pytest.fixture
8-
def agent():
9-
action_space = [
10-
"ask_question",
11-
"provide_information",
12-
"clarify",
13-
"suggest_action",
14-
"express_empathy",
15-
"end_conversation",
16-
]
177
@pytest.fixture
188
def agent():
199
action_space = [
@@ -112,10 +102,10 @@ def test_process_user_input(agent):
112102
assert isinstance(features, np.ndarray)
113103
assert features.shape == (17,)
114104
assert features[0] == 8 # Word count
115-
assert features[1] > 0 # Average word length
116-
assert features[2] == 1 / 8 # Question mark frequency
117-
assert features[3] == 1 / 8 # Exclamation mark frequency
118-
assert -1 <= features[4] <= 1 # Sentiment polarity
105+
assert features[2] == 0.125 # Question mark frequency
106+
assert features[3] == 0.125 # Exclamation mark frequency
107+
assert 0 < features[1] < 5 # Average word length
108+
assert -1.0 <= features[4] <= 1.0 # Sentiment polarity
119109
assert 0 <= features[5] <= 1 # Subjectivity
120110
assert 0 <= features[6] <= 1 # Keyword detection
121111
assert 0 <= features[7] <= 1 # Lexical diversity
@@ -127,16 +117,7 @@ def test_process_user_input(agent):
127117
assert 0 <= features[13] <= 1 # Main clause density
128118
assert features[14] > 0 # Average parse tree depth
129119
assert 0 <= features[15] <= 1 # Stop word density
130-
def test_process_user_input(agent):
131-
user_input = "Hello, how are you? I'm feeling great today!"
132-
features = agent.process_user_input(user_input)
133-
assert isinstance(features, np.ndarray)
134-
assert features.shape == (17,)
135-
assert features[0] == 8 # Word count
136-
assert features[2] == 0.125 # Question mark frequency
137-
assert features[3] == 0.125 # Exclamation mark frequency
138-
assert 0 < features[1] < 5 # Average word length
139-
assert -0.9 < features[4] < 0.9 # Sentiment polarity
120+
assert 0 <= features[16] <= 1 # Punctuation density
140121

141122

142123
if __name__ == "__main__":

0 commit comments

Comments
 (0)